From 4ba63b193c1ac292493e06343d9d618c12c5ef3f Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Tue, 13 Sep 2016 17:04:51 +0200
Subject: [PATCH 001/306] [SPARK-17142][SQL] Complex query triggers binding error in HashAggregateExec

## What changes were proposed in this pull request?

In the `ReorderAssociativeOperator` rule, we extract foldable expressions from Add/Multiply arithmetic and replace them with the evaluated literal. For example, `(a + 1) + (b + 2)` is optimized to `(a + b + 3)` by this rule.
For an aggregate operator, output expressions should be derived from groupingExpressions; the current implementation of the `ReorderAssociativeOperator` rule may break this promise. An instance could be:
```
SELECT ((t1.a + 1) + (t2.a + 2)) AS out_col
FROM testdata2 AS t1 INNER JOIN testdata2 AS t2
ON (t1.a = t2.a)
GROUP BY (t1.a + 1), (t2.a + 2)
```
`((t1.a + 1) + (t2.a + 2))` is optimized to `(t1.a + t2.a + 3)`, which could not be derived from `ExpressionSet((t1.a + 1), (t2.a + 2))`.
Maybe we should improve the rule of `ReorderAssociativeOperator` by adding a GroupingExpressionSet to keep Aggregate.groupingExpressions, and respect these expressions during the optimize stage.

## How was this patch tested?

Add new test case in `ReorderAssociativeOperatorSuite`.

Author: jiangxingbo

Closes #14917 from jiangxb1987/rao.
---
 .../sql/catalyst/optimizer/expressions.scala  | 31 ++++++++++++++-----
 .../ReorderAssociativeOperatorSuite.scala     | 16 +++++++++-
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 82ab111aa2259..b7458910da13e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -57,20 +57,37 @@ object ConstantFolding extends Rule[LogicalPlan] {
  * Reorder associative integral-type operators and fold all constants into one.
  */
 object ReorderAssociativeOperator extends Rule[LogicalPlan] {
-  private def flattenAdd(e: Expression): Seq[Expression] = e match {
-    case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
+  private def flattenAdd(
+      expression: Expression,
+      groupSet: ExpressionSet): Seq[Expression] = expression match {
+    case expr @ Add(l, r) if !groupSet.contains(expr) =>
+      flattenAdd(l, groupSet) ++ flattenAdd(r, groupSet)
     case other => other :: Nil
   }
 
-  private def flattenMultiply(e: Expression): Seq[Expression] = e match {
-    case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
+  private def flattenMultiply(
+      expression: Expression,
+      groupSet: ExpressionSet): Seq[Expression] = expression match {
+    case expr @ Multiply(l, r) if !groupSet.contains(expr) =>
+      flattenMultiply(l, groupSet) ++ flattenMultiply(r, groupSet)
     case other => other :: Nil
   }
 
+  private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet = plan match {
+    case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+      ExpressionSet.apply(groupingExpressions)
+    case _ => ExpressionSet(Seq())
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsDown {
+    case q: LogicalPlan =>
+      // We have to respect aggregate expressions which exists in grouping expressions when plan
+      // is an Aggregate operator, otherwise the optimized expression could not be derived from
+      // grouping expressions.
+      val groupingExpressionSet = collectGroupingExpressions(q)
+      q transformExpressionsDown {
       case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
-        val (foldables, others) = flattenAdd(a).partition(_.foldable)
+        val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
         if (foldables.size > 1) {
           val foldableExpr = foldables.reduce((x, y) => Add(x, y))
           val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
@@ -79,7 +96,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
           a
         }
       case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
-        val (foldables, others) = flattenMultiply(m).partition(_.foldable)
+        val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
         if (foldables.size > 1) {
           val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
           val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
index 05e15e9ec4728..a1ab0a834474f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
@@ -60,4 +60,18 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("nested expression with aggregate operator") {
+    val originalQuery =
+      testRelation.as("t1")
+        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
+        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
+          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }

From 72edc7e958271cedb01932880550cfc2c0631204 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 13 Sep 2016 15:11:55 -0700
Subject: [PATCH 002/306] [SPARK-17531] Don't initialize Hive Listeners for the Execution Client

## What changes were proposed in this pull request?

If a user provides listeners inside the Hive Conf, the configuration for these listeners is passed to the Hive Execution Client as well. This may cause issues for two reasons:

1. The Execution Client will actually generate garbage
2. The listener class needs to be both in the Spark Classpath and Hive Classpath

This PR empties the listener configurations in `HiveUtils.newTemporaryConfiguration` so that the execution client will not contain the listener confs, but the metadata client will.

## How was this patch tested?

Unit tests

Author: Burak Yavuz

Closes #15086 from brkyvz/null-listeners.
---
 .../org/apache/spark/sql/hive/HiveUtils.scala |  7 ++++
 .../spark/sql/hive/HiveUtilsSuite.scala       | 36 +++++++++++++++++++
 2 files changed, 43 insertions(+)
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
index bdec611453b2d..39d71e164bf51 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
@@ -394,6 +394,13 @@ private[spark] object HiveUtils extends Logging {
     // hive.metastore.uris is not set.
     propMap.put(ConfVars.METASTOREURIS.varname, "")
 
+    // The execution client will generate garbage events, therefore the listeners that are generated
+    // for the execution clients are useless. In order to not output garbage, we don't generate
+    // these listeners.
+    propMap.put(ConfVars.METASTORE_PRE_EVENT_LISTENERS.varname, "")
+    propMap.put(ConfVars.METASTORE_EVENT_LISTENERS.varname, "")
+    propMap.put(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname, "")
+
     propMap.toMap
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala
new file mode 100644
index 0000000000000..667a7ddd8bb61
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.QueryTest
+
+class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+  test("newTemporaryConfiguration overwrites listener configurations") {
+    Seq(true, false).foreach { useInMemoryDerby =>
+      val conf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby)
+      assert(conf(ConfVars.METASTORE_PRE_EVENT_LISTENERS.varname) === "")
+      assert(conf(ConfVars.METASTORE_EVENT_LISTENERS.varname) === "")
+      assert(conf(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname) === "")
+    }
+  }
+}

From 37b93f54e89332b6b77bb02c1c2299614338fd7c Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 14 Sep 2016 00:37:42 +0200
Subject: [PATCH 003/306] [SPARK-17530][SQL] Add Statistics into DESCRIBE FORMATTED

### What changes were proposed in this pull request?

Statistics is missing in the output of `DESCRIBE FORMATTED`. This PR is to add it.
After the PR, the output will be like: ``` +----------------------------+----------------------------------------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +----------------------------+----------------------------------------------------------------------------------------------------------------------+-------+ |key |string |null | |value |string |null | | | | | |# Detailed Table Information| | | |Database: |default | | |Owner: |xiaoli | | |Create Time: |Tue Sep 13 14:36:57 PDT 2016 | | |Last Access Time: |Wed Dec 31 16:00:00 PST 1969 | | |Location: |file:/private/var/folders/4b/sgmfldk15js406vk7lw5llzw0000gn/T/warehouse-9982e1db-df17-4376-a140-dbbee0203d83/texttable| | |Table Type: |MANAGED | | |Statistics: |sizeInBytes=5812, rowCount=500, isBroadcastable=false | | |Table Parameters: | | | | rawDataSize |-1 | | | numFiles |1 | | | transient_lastDdlTime |1473802620 | | | totalSize |5812 | | | COLUMN_STATS_ACCURATE |false | | | numRows |-1 | | | | | | |# Storage Information | | | |SerDe Library: |org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe | | |InputFormat: |org.apache.hadoop.mapred.TextInputFormat | | |OutputFormat: |org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat | | |Compressed: |No | | |Storage Desc Parameters: | | | | serialization.format |1 | | +----------------------------+----------------------------------------------------------------------------------------------------------------------+-------+ ``` Also improve the output of statistics in `DESCRIBE EXTENDED` by removing duplicate `Statistics`. Below is the example after the PR: ``` +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+ |key |string |null | |value |string |null | | | | | |# Detailed Table Information|CatalogTable( Table: `default`.`texttable` Owner: xiaoli 
Created: Tue Sep 13 14:38:43 PDT 2016 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Schema: [StructField(key,StringType,true), StructField(value,StringType,true)] Provider: hive Properties: [rawDataSize=-1, numFiles=1, transient_lastDdlTime=1473802726, totalSize=5812, COLUMN_STATS_ACCURATE=false, numRows=-1] Statistics: sizeInBytes=5812, rowCount=500, isBroadcastable=false Storage(Location: file:/private/var/folders/4b/sgmfldk15js406vk7lw5llzw0000gn/T/warehouse-8ea5c5a0-5680-4778-91cb-c6334cf8a708/texttable, InputFormat: org.apache.hadoop.mapred.TextInputFormat, OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, Serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Properties: [serialization.format=1]))| | +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+ ``` ### How was this patch tested? Manually tested. Author: gatorsmile Closes #15083 from gatorsmile/descFormattedStats. --- .../spark/sql/catalyst/catalog/interface.scala | 2 +- .../sql/catalyst/plans/logical/Statistics.scala | 15 ++++++++------- .../spark/sql/execution/command/tables.scala | 1 + 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e74fa6e638a0b..e52251f960ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -191,7 +191,7 @@ case class CatalogTable( viewText.map("View: " + _).getOrElse(""), comment.map("Comment: " + _).getOrElse(""), if (properties.nonEmpty) s"Properties: $tableProperties" else "", - if (stats.isDefined) s"Statistics: ${stats.get}" else "", + if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", s"$storage") output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 58fa537a18e3e..3cf20385dd712 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -38,12 +38,13 @@ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, isBroadcastable: Boolean = false) { - override def toString: String = { - val output = - Seq(s"sizeInBytes=$sizeInBytes", - if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", - s"isBroadcastable=$isBroadcastable" - ) - 
output.filter(_.nonEmpty).mkString("Statistics(", ", ", ")") + override def toString: String = "Statistics(" + simpleString + ")" + + /** Readable string representation for the Statistics. */ + def simpleString: String = { + Seq(s"sizeInBytes=$sizeInBytes", + if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", + s"isBroadcastable=$isBroadcastable" + ).filter(_.nonEmpty).mkString("", ", ", "") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 027f3588e2922..9fbcd48b4a911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -468,6 +468,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "") append(buffer, "Table Type:", table.tableType.name, "") + table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) append(buffer, "Table Parameters:", "", "") table.properties.foreach { case (key, value) => From a454a4d86bbed1b6988da0a0e23b3e87a1a16340 Mon Sep 17 00:00:00 2001 From: junyangq Date: Tue, 13 Sep 2016 21:01:03 -0700 Subject: [PATCH 004/306] [SPARK-17317][SPARKR] Add SparkR vignette ## What changes were proposed in this pull request? This PR tries to add a SparkR vignette, which works as a friendly guidance going through the functionality provided by SparkR. ## How was this patch tested? Manual test. Author: junyangq Author: Shivaram Venkataraman Author: Junyang Qian Closes #14980 from junyangq/SPARKR-vignette. --- R/create-docs.sh | 11 +- R/pkg/vignettes/sparkr-vignettes.Rmd | 861 +++++++++++++++++++++++++++ 2 files changed, 870 insertions(+), 2 deletions(-) create mode 100644 R/pkg/vignettes/sparkr-vignettes.Rmd diff --git a/R/create-docs.sh b/R/create-docs.sh index d2ae160b50021..0dfba22463396 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -17,11 +17,13 @@ # limitations under the License. # -# Script to create API docs for SparkR -# This requires `devtools` and `knitr` to be installed on the machine. +# Script to create API docs and vignettes for SparkR +# This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +# The vignettes can be found in +# $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html set -o pipefail set -e @@ -43,4 +45,9 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd +# render creates SparkR vignettes +Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' + +find pkg/vignettes/. -not -name '.' 
-not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete + popd diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd new file mode 100644 index 0000000000000..aea52db8b8556 --- /dev/null +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -0,0 +1,861 @@ +--- +title: "SparkR - Practical Guide" +output: + html_document: + theme: united + toc: true + toc_depth: 4 + toc_float: true + highlight: textmate +--- + +## Overview + +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). + +## Getting Started + +We begin with an example running on the local machine and provide an overview of the use of SparkR: data ingestion, data processing and machine learning. + +First, let's load and attach the package. +```{r, message=FALSE} +library(SparkR) +``` + +`SparkSession` is the entry point into SparkR which connects your R program to a Spark cluster. You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any Spark packages depended on, etc. + +We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). + +```{r, message=FALSE} +sparkR.session() +``` + +The operations in SparkR are centered around an R class called `SparkDataFrame`. It is a distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R, but with richer optimizations under the hood. + +`SparkDataFrame` can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing local R data frames. For example, we create a `SparkDataFrame` from a local R data frame, + +```{r} +cars <- cbind(model = rownames(mtcars), mtcars) +carsDF <- createDataFrame(cars) +``` + +We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` function. +```{r} +head(carsDF) +``` + +Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "hp") +carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) +head(carsSubDF) +``` + +SparkR can use many common aggregation functions after grouping. + +```{r} +carsGPDF <- summarize(groupBy(carsDF, carsDF$gear), count = n(carsDF$gear)) +head(carsGPDF) +``` + +The results `carsDF` and `carsSubDF` are `SparkDataFrame` objects. To convert back to R `data.frame`, we can use `collect`. **Caution**: This can cause your interactive environment to run out of memory, though, because `collect()` fetches the entire distributed `DataFrame` to your client, which is acting as a Spark driver. +```{r} +carsGP <- collect(carsGPDF) +class(carsGP) +``` + +SparkR supports a number of commonly used machine learning algorithms. Under the hood, SparkR uses MLlib to train the model. Users can call `summary` to print a summary of the fitted model, `predict` to make predictions on new data, and `write.ml`/`read.ml` to save/load fitted models. + +SparkR supports a subset of R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. 
We use linear regression as an example. +```{r} +model <- spark.glm(carsDF, mpg ~ wt + cyl) +``` + +The result matches that returned by R `glm` function applied to the corresponding `data.frame` `mtcars` of `carsDF`. In fact, for Generalized Linear Model, we specifically expose `glm` for `SparkDataFrame` as well so that the above is equivalent to `model <- glm(mpg ~ wt + cyl, data = carsDF)`. + +```{r} +summary(model) +``` + +The model can be saved by `write.ml` and loaded back using `read.ml`. +```{r, eval=FALSE} +write.ml(model, path = "/HOME/tmp/mlModel/glmModel") +``` + +In the end, we can stop Spark Session by running +```{r, eval=FALSE} +sparkR.session.stop() +``` + +## Setup + +### Installation + +Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. + +If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. + +```{r, eval=FALSE} +install.spark() +``` + +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. + +```{r, eval=FALSE} +sparkR.session(sparkHome = "/HOME/spark") +``` + +### Spark Session {#SetupSparkSession} + + +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). + +In particular, the following Spark driver properties can be set in `sparkConfig`. + +Property Name | Property group | spark-submit equivalent +---------------- | ------------------ | ---------------------- +spark.driver.memory | Application Properties | --driver-memory +spark.driver.extraClassPath | Runtime Environment | --driver-class-path +spark.driver.extraJavaOptions | Runtime Environment | --driver-java-options +spark.driver.extraLibraryPath | Runtime Environment | --driver-library-path + +**For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. + +```{r, eval=FALSE} +spark_warehouse_path <- file.path(path.expand('~'), "spark-warehouse") +sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) +``` + + +#### Cluster Mode +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. + +When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with +```{r, echo=FALSE, tidy = TRUE} +paste("Spark", packageVersion("SparkR")) +``` +It should be used both on the local computer and on the remote cluster. + +To connect, pass the URL of the master node to `sparkR.session`. 
A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +For example, to connect to a local standalone Spark master, we can call + +```{r, eval=FALSE} +sparkR.session(master = "spark://local:7077") +``` + +For YARN cluster, SparkR supports the client mode with the master set as "yarn". +```{r, eval=FALSE} +sparkR.session(master = "yarn") +``` +Yarn cluster mode is not supported in the current version. + +## Data Import + +### Local Data Frame +The simplest way is to convert a local R data frame into a `SparkDataFrame`. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a `SparkDataFrame`. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +```{r} +df <- as.DataFrame(faithful) +head(df) +``` + +### Data Sources +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session'.` + +```{r, eval=FALSE} +sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +``` + +We can see how to use data sources using an example CSV input file. For more information please refer to SparkR [read.df](https://spark.apache.org/docs/latest/api/R/read.df.html) API documentation. +```{r, eval=FALSE} +df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "NA") +``` + +The data sources API natively supports JSON formatted input files. Note that the file that is used here is not a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +Let's take a look at the first two lines of the raw JSON file used here. + +```{r} +filePath <- paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json") +readLines(filePath, n = 2L) +``` + +We use `read.df` to read that into a `SparkDataFrame`. + +```{r} +people <- read.df(filePath, "json") +count(people) +head(people) +``` + +SparkR automatically infers the schema from the JSON file. +```{r} +printSchema(people) +``` + +If we want to read multiple JSON files, `read.json` can be used. +```{r} +people <- read.json(paste0(Sys.getenv("SPARK_HOME"), + c("/examples/src/main/resources/people.json", + "/examples/src/main/resources/people.json"))) +count(people) +``` + +The data sources API can also be used to save out `SparkDataFrames` into multiple file formats. For example we can save the `SparkDataFrame` from the previous example to a Parquet file using `write.df`. +```{r, eval=FALSE} +write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite") +``` + +### Hive Tables +You can also create SparkDataFrames from Hive tables. 
To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). + +```{r, eval=FALSE} +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + +txtPath <- paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/kv1.txt") +sqlCMD <- sprintf("LOAD DATA LOCAL INPATH '%s' INTO TABLE src", txtPath) +sql(sqlCMD) + +results <- sql("FROM src SELECT key, value") + +# results is now a SparkDataFrame +head(results) +``` + + +## Data Processing + +**To dplyr users**: SparkR has similar interface as dplyr in data processing. However, some noticeable differences are worth mentioning in the first place. We use `df` to represent a `SparkDataFrame` and `col` to represent the name of column here. + +1. indicate columns. SparkR uses either a character string of the column name or a Column object constructed with `$` to indicate a column. For example, to select `col` in `df`, we can write `select(df, "col")` or `select(df, df$col)`. + +2. describe conditions. In SparkR, the Column object representation can be inserted into the condition directly, or we can use a character string to describe the condition, without referring to the `SparkDataFrame` used. For example, to select rows with value > 1, we can write `filter(df, df$col > 1)` or `filter(df, "col > 1")`. + +Here are more concrete examples. + +dplyr | SparkR +-------- | --------- +`select(mtcars, mpg, hp)` | `select(carsDF, "mpg", "hp")` +`filter(mtcars, mpg > 20, hp > 100)` | `filter(carsDF, carsDF$mpg > 20, carsDF$hp > 100)` + +Other differences will be mentioned in the specific methods. + +We use the `SparkDataFrame` `carsDF` created above. We can get basic information about the `SparkDataFrame`. +```{r} +carsDF +``` + +Print out the schema in tree format. +```{r} +printSchema(carsDF) +``` + +### SparkDataFrame Operations + +#### Selecting rows, columns + +SparkDataFrames support a number of functions to do structured data processing. Here we include some basic examples and a complete list can be found in the [API](https://spark.apache.org/docs/latest/api/R/index.html) docs: + +You can also pass in column name as strings. +```{r} +head(select(carsDF, "mpg")) +``` + +Filter the SparkDataFrame to only retain rows with mpg less than 20 miles/gallon. +```{r} +head(filter(carsDF, carsDF$mpg < 20)) +``` + +#### Grouping, Aggregation + +A common flow of grouping and aggregation is + +1. Use `groupBy` or `group_by` with respect to some grouping variables to create a `GroupedData` object + +2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. + +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. + +For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. 
+ +```{r} +numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) +head(numCyl) +``` + +#### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +```{r} +carsDF_km <- carsDF +carsDF_km$kmpg <- carsDF_km$mpg * 1.61 +head(select(carsDF_km, "model", "mpg", "kmpg")) +``` + + +### Window Functions +A window function is a variation of aggregation function. In simple words, + +* aggregation function: `n` to `1` mapping - returns a single value for a group of entries. Examples include `sum`, `count`, `max`. + +* window function: `n` to `n` mapping - returns one value for each entry in the group, but the value may depend on all the entries of the *group*. Examples include `rank`, `lead`, `lag`. + +Formally, the *group* mentioned above is called the *frame*. Every input row can have a unique frame associated with it and the output of the window function on that row is based on the rows confined in that frame. + +Window functions are often used in conjunction with the following functions: `windowPartitionBy`, `windowOrderBy`, `partitionBy`, `orderBy`, `over`. To illustrate this we next look at an example. + +We still use the `mtcars` dataset. The corresponding `SparkDataFrame` is `carsDF`. Suppose for each number of cylinders, we want to calculate the rank of each car in `mpg` within the group. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "cyl") +ws <- orderBy(windowPartitionBy("cyl"), "mpg") +carsRank <- withColumn(carsSubDF, "rank", over(rank(), ws)) +head(carsRank, n = 20L) +``` + +We explain in detail the above steps. + +* `windowPartitionBy` creates a window specification object `WindowSpec` that defines the partition. It controls which rows will be in the same partition as the given row. In this case, rows with the same value in `cyl` will be put in the same partition. `orderBy` further defines the ordering - the position a given row is in the partition. The resulting `WindowSpec` is returned as `ws`. + +More window specification methods include `rangeBetween`, which can define boundaries of the frame by value, and `rowsBetween`, which can define the boundaries by row indices. + +* `withColumn` appends a Column called `rank` to the `SparkDataFrame`. `over` returns a windowing column. The first argument is usually a Column returned by window function(s) such as `rank()`, `lead(carsDF$wt)`. That calculates the corresponding values according to the partitioned-and-ordered table. + +### User-Defined Function + +In SparkR, we support several kinds of user-defined functions (UDFs). + +#### Apply by Partition + +`dapply` can apply a function to each partition of a `SparkDataFrame`. The function to be applied to each partition of the `SparkDataFrame` should have only one parameter, a `data.frame` corresponding to a partition, and the output should be a `data.frame` as well. Schema specifies the row format of the resulting a `SparkDataFrame`. It must match to data types of returned value. See [here](#DataTypes) for mapping between R and Spark. + +We convert `mpg` to `kmpg` (kilometers per gallon). `carsSubDF` is a `SparkDataFrame` with a subset of `carsDF` columns. 
+ +```{r} +carsSubDF <- select(carsDF, "model", "mpg") +schema <- structType(structField("model", "string"), structField("mpg", "double"), + structField("kmpg", "double")) +out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) +head(collect(out)) +``` + +Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +out <- dapplyCollect( + carsSubDF, + function(x) { + x <- cbind(x, "kmpg" = x$mpg * 1.61) + }) +head(out, 3) +``` + +#### Apply by Group +`gapply` can apply a function to each group of a `SparkDataFrame`. The function is to be applied to each group of the `SparkDataFrame` and should have only two parameters: grouping key and R `data.frame` corresponding to that key. The groups are chosen from `SparkDataFrames` column(s). The output of function should be a `data.frame`. Schema specifies the row format of the resulting `SparkDataFrame`. It must represent R function’s output schema on the basis of Spark data types. The column names of the returned `data.frame` are set by user. See [here](#DataTypes) for mapping between R and Spark. + +```{r} +schema <- structType(structField("cyl", "double"), structField("max_mpg", "double")) +result <- gapply( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + }, + schema) +head(arrange(result, "max_mpg", decreasing = TRUE)) +``` + +Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +result <- gapplyCollect( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + colnames(y) <- c("cyl", "max_mpg") + y + }) +head(result[order(result$max_mpg, decreasing = TRUE), ]) +``` + +#### Distribute Local Functions + +Similar to `lapply` in native R, `spark.lapply` runs a function over a list of elements and distributes the computations with Spark. `spark.lapply` works in a manner that is similar to `doParallel` or `lapply` to elements of a list. The results of all the computations should fit in a single machine. If that is not the case you can do something like `df <- createDataFrame(list)` and then use `dapply`. + +We use `svm` in package `e1071` as an example. We use all default settings except for varying costs of constraints violation. `spark.lapply` can train those different models in parallel. + +```{r} +costs <- exp(seq(from = log(1), to = log(1000), length.out = 5)) +train <- function(cost) { + stopifnot(requireNamespace("e1071", quietly = TRUE)) + model <- e1071::svm(Species ~ ., data = iris, cost = cost) + summary(model) +} +``` + +Return a list of model's summaries. +```{r} +model.summaries <- spark.lapply(costs, train) +``` + +```{r} +class(model.summaries) +``` + + +To avoid lengthy display, we only present the result of the second fitted model. You are free to inspect other models as well. +```{r} +print(model.summaries[[2]]) +``` + + +### SQL Queries +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. 
The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. + +```{r} +people <- read.df(paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json"), "json") +``` + +Register this SparkDataFrame as a temporary view. + +```{r} +createOrReplaceTempView(people, "people") +``` + +SQL statements can be run by using the sql method. +```{r} +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +``` + + +## Machine Learning + +SparkR supports the following machine learning models and algorithms. + +* Generalized Linear Model (GLM) + +* Naive Bayes Model + +* $k$-means Clustering + +* Accelerated Failure Time (AFT) Survival Model + +* Gaussian Mixture Model (GMM) + +* Latent Dirichlet Allocation (LDA) + +* Multilayer Perceptron Model + +* Collaborative Filtering with Alternating Least Squares (ALS) + +* Isotonic Regression Model + +More will be added in the future. + +### R Formula + +For most above, SparkR supports **R formula operators**, including `~`, `.`, `:`, `+` and `-` for model fitting. This makes it a similar experience as using R functions. + +### Training and Test Sets + +We can easily split `SparkDataFrame` into random training and test sets by the `randomSplit` function. It returns a list of split `SparkDataFrames` with provided `weights`. We use `carsDF` as an example and want to have about $70%$ training data and $30%$ test data. +```{r} +splitDF_list <- randomSplit(carsDF, c(0.7, 0.3), seed = 0) +carsDF_train <- splitDF_list[[1]] +carsDF_test <- splitDF_list[[2]] +``` + +```{r} +count(carsDF_train) +head(carsDF_train) +``` + +```{r} +count(carsDF_test) +head(carsDF_test) +``` + + +### Models and Algorithms + +#### Generalized Linear Model + +The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. + +Family | Link Function +------ | --------- +gaussian | identity, log, inverse +binomial | logit, probit, cloglog (complementary log-log) +poisson | log, identity, sqrt +gamma | inverse, identity, log + +There are three ways to specify the `family` argument. + +* Family name as a character string, e.g. `family = "gaussian"`. + +* Family function, e.g. `family = binomial`. + +* Result returned by a family function, e.g. `family = poisson(link = log)` + +For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). + +We use the `mtcars` dataset as an illustration. The corresponding `SparkDataFrame` is `carsDF`. After fitting the model, we print out a summary and see the fitted values by making predictions on the original dataset. We can also pass into a new `SparkDataFrame` of same schema to predict on new data. + +```{r} +gaussianGLM <- spark.glm(carsDF, mpg ~ wt + hp) +summary(gaussianGLM) +``` +When doing prediction, a new column called `prediction` will be appended. Let's look at only a subset of columns here. +```{r} +gaussianFitted <- predict(gaussianGLM, carsDF) +head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) +``` + +#### Naive Bayes Model + +Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. 
+ +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### k-Means Clustering + +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. + +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` + +#### AFT Survival Model +Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. + +Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. +```{r} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + +#### Gaussian Mixture Model + +(Coming in 2.1.0) + +`spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. + +We use a simulated example to demostrate the usage. +```{r} +X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) +X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) +data <- rbind(X1, X2) +df <- createDataFrame(data) +gmmModel <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) +summary(gmmModel) +gmmFitted <- predict(gmmModel, df) +head(select(gmmFitted, "V1", "V2", "prediction")) +``` + + +#### Latent Dirichlet Allocation + +(Coming in 2.1.0) + +`spark.lda` fits a [Latent Dirichlet Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model on a `SparkDataFrame`. It is often used in topic modeling in which topics are inferred from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. + +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). 
+ +* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. + +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: + +* character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. + +* libSVM: Each entry is a collection of words and will be processed directly. + +There are several parameters LDA takes for fitting the model. + +* `k`: number of topics (default 10). + +* `maxIter`: maximum iterations (default 20). + +* `optimizer`: optimizer to train an LDA model, "online" (default) uses [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf). "em" uses [expectation-maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm). + +* `subsamplingRate`: For `optimizer = "online"`. Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1] (default 0.05). + +* `topicConcentration`: concentration parameter (commonly named beta or eta) for the prior placed on topic distributions over terms, default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective topicConcentration. Only 1-size numeric is accepted. + +* `docConcentration`: concentration parameter (commonly named alpha) for the prior placed on documents distributions over topics (theta), default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective docConcentration. Only 1-size or k-size numeric is accepted. + +* `maxVocabSize`: maximum vocabulary size, default 1 << 18. + +Two more functions are provided for the fitted model. + +* `spark.posterior` returns a `SparkDataFrame` containing a column of posterior probabilities vectors named "topicDistribution". + +* `spark.perplexity` returns the log perplexity of given `SparkDataFrame`, or the log perplexity of the training data if missing argument `data`. + +For more information, see the help document `?spark.lda`. + +Let's look an artificial example. +```{r} +corpus <- data.frame(features = c( + "1 2 6 0 2 3 1 1 0 0 3", + "1 3 0 1 3 0 0 2 0 0 1", + "1 4 1 0 0 4 9 0 1 2 0", + "2 1 0 3 0 0 5 0 2 3 9", + "3 1 1 9 3 0 2 0 0 1 3", + "4 2 0 3 4 5 1 1 1 4 0", + "2 1 0 3 0 0 5 0 2 2 9", + "1 1 1 9 2 1 2 0 0 1 3", + "4 4 0 3 4 2 1 3 0 0 0", + "2 8 2 0 3 0 2 0 2 7 2", + "1 1 1 9 0 2 2 0 0 3 3", + "4 1 0 0 4 5 1 3 0 1 0")) +corpusDF <- createDataFrame(corpus) +model <- spark.lda(data = corpusDF, k = 5, optimizer = "em") +summary(model) +``` + +```{r} +posterior <- spark.posterior(model, corpusDF) +head(posterior) +``` + +```{r} +perplexity <- spark.perplexity(model, corpusDF) +perplexity +``` + + +#### Multilayer Perceptron + +(Coming in 2.1.0) + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. 
This can be written in matrix form for MLPC with $K+1$ layers as follows: +$$ +y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). +$$ + +Nodes in intermediate layers use sigmoid (logistic) function: +$$ +f(z_i) = \frac{1}{1+e^{-z_i}}. +$$ + +Nodes in the output layer use softmax function: +$$ +f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. +$$ + +The number of nodes $N$ in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. + +`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. According to the description above, there are several additional parameters that can be set: + +* `layers`: integer vector containing the number of nodes for each layer. + +* `solver`: solver parameter, supported options: `"gd"` (minibatch gradient descent) or `"l-bfgs"`. + +* `maxIter`: maximum iteration number. + +* `tol`: convergence tolerance of iterations. + +* `stepSize`: step size for `"gd"`. + +* `seed`: seed parameter for weights initialization. + +#### Collaborative Filtering + +(Coming in 2.1.0) + +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). + +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. + +```{r} +ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), + list(2, 1, 1.0), list(2, 2, 5.0)) +df <- createDataFrame(ratings, c("user", "item", "rating")) +model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegative = TRUE) +``` + +Extract latent factors. +```{r} +stats <- summary(model) +userFactors <- stats$userFactors +itemFactors <- stats$itemFactors +head(userFactors) +head(itemFactors) +``` + +Make predictions. + +```{r} +predicted <- predict(model, df) +head(predicted) +``` + +#### Isotonic Regression Model + +(Coming in 2.1.0) + +`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize +$$ +\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. +$$ + +There are a few more arguments that may be useful. + +* `weightCol`: a character string specifying the weight column. + +* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). + +* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. + +We use an artificial example to show the use. 
+ +```{r} +y <- c(3.0, 6.0, 8.0, 5.0, 7.0) +x <- c(1.0, 2.0, 3.5, 3.0, 4.0) +w <- rep(1.0, 5) +data <- data.frame(y = y, x = x, w = w) +df <- createDataFrame(data) +isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") +isoregFitted <- predict(isoregModel, df) +head(select(isoregFitted, "x", "y", "prediction")) +``` + +In the prediction stage, based on the fitted monotone piecewise function, the rules are: + +* If the prediction input exactly matches a training feature then associated prediction is returned. In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. + +* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. + +* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. + +For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. + +```{r} +newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) +head(predict(isoregModel, newDF)) +``` + +#### What's More? +We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in the next version 2.1.0. + +### Model Persistence +The following example shows how to save/load an ML model by SparkR. +```{r} +irisDF <- suppressWarnings(createDataFrame(iris)) +gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Save and then load a fitted MLlib model +modelPath <- tempfile(pattern = "ml", fileext = ".tmp") +write.ml(gaussianGLM, modelPath) +gaussianGLM2 <- read.ml(modelPath) + +# Check model summary +summary(gaussianGLM2) + +# Check model prediction +gaussianPredictions <- predict(gaussianGLM2, irisDF) +head(gaussianPredictions) + +unlink(modelPath) +``` + + +## Advanced Topics + +### SparkR Object Classes + +There are three main object classes in SparkR you may be working with. + +* `SparkDataFrame`: the central component of SparkR. It is an S4 class representing distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R. It has two slots `sdf` and `env`. + + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + + `env` saves the meta-information of the object such as `isCached`. + +It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + +* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. + +It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. 
+
+* `GroupedData`: an S4 class representing grouped data, created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend.
+
+This is often an intermediate object, holding group information, that is followed by aggregation operations.
+
+### Architecture
+
+A complete description of the architecture can be found in the references, in particular the paper *SparkR: Scaling R Programs with Spark*.
+
+Under the hood of SparkR is the Spark SQL engine. This avoids the overhead of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a number of optimizations that speed up the computation.
+
+The main method calls of the actual computation happen in the Spark JVM of the driver. We have a socket-based SparkR API that allows us to invoke functions on the JVM from R. We use a SparkR JVM backend that listens on a Netty-based socket server.
+
+Two kinds of RPCs are supported in the SparkR JVM backend: method invocation and creating new objects. Method invocation can be done in two ways.
+
+* `sparkR.invokeJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method.
+
+* `sparkR.invokeJStatic` takes a class name for a static method and a list of arguments to be passed on to the method.
+
+The arguments are serialized using our custom wire format, which is then deserialized on the JVM side. We then use Java reflection to invoke the appropriate method.
+
+To create objects, `sparkR.newJObject` is used, and then similarly the appropriate constructor is invoked with the provided arguments.
+
+Finally, we use a new R class `jobj` that refers to a Java object existing in the backend. These references are tracked on the Java side and are automatically garbage collected when they go out of scope on the R side.
+
+## Appendix
+
+### R and Spark Data Types {#DataTypes}
+
+R | Spark
+----------- | -------------
+byte | byte
+integer | integer
+float | float
+double | double
+numeric | double
+character | string
+string | string
+binary | binary
+raw | binary
+logical | boolean
+POSIXct | timestamp
+POSIXlt | timestamp
+Date | date
+array | array
+list | array
+env | map
+
+## References
+
+* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html)
+
+* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html)
+
+* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html)
+
+* [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016.
+
+```{r, echo=FALSE}
+sparkR.session.stop()
+```

From def7c265f539f3e119f068b6e9050300d05b14a4 Mon Sep 17 00:00:00 2001
From: Jagadeesan
Date: Wed, 14 Sep 2016 09:03:16 +0100
Subject: [PATCH 005/306] =?UTF-8?q?[SPARK-17449][DOCUMENTATION]=20Relation?= =?UTF-8?q?=20between=20heartbeatInterval=20and=E2=80=A6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

The relation between spark.network.timeout and spark.executor.heartbeatInterval should be mentioned in the document.

… network timeout]

Author: Jagadeesan

Closes #15042 from jagadeesanas2/SPARK-17449.
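For illustration only (this snippet is not part of the patch, and the concrete values are hypothetical), a PySpark session configured so that the heartbeat interval stays well below the network timeout:

```python
from pyspark.sql import SparkSession

# Hypothetical values: keep spark.executor.heartbeatInterval (10s) significantly
# below spark.network.timeout (120s), as the documentation change recommends.
spark = (SparkSession.builder
         .appName("heartbeat-config-example")
         .config("spark.executor.heartbeatInterval", "10s")
         .config("spark.network.timeout", "120s")
         .getOrCreate())
```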
--- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 1 + docs/configuration.md | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index c3764ac671afb..5242ab6f55235 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -32,6 +32,7 @@ import org.apache.spark.util._ * A heartbeat from executors to the driver. This is a shared message used by several internal * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. + * spark.executor.heartbeatInterval should be significantly less than spark.network.timeout. */ private[spark] case class Heartbeat( executorId: String, diff --git a/docs/configuration.md b/docs/configuration.md index ebd0aa796db08..8aea74505e28b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -987,7 +987,8 @@ Apart from these, the following properties are also available, and may be useful 10s Interval between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress - tasks. + tasks. spark.executor.heartbeatInterval should be significantly less than + spark.network.timeout spark.files.fetchTimeout From b5bfcddbfbc2e79d3d0fbd43942716946e6c4ba3 Mon Sep 17 00:00:00 2001 From: Sami Jaktholm Date: Wed, 14 Sep 2016 09:38:30 +0100 Subject: [PATCH 006/306] [SPARK-17525][PYTHON] Remove SparkContext.clearFiles() from the PySpark API as it was removed from the Scala API prior to Spark 2.0.0 ## What changes were proposed in this pull request? This pull request removes the SparkContext.clearFiles() method from the PySpark API as the method was removed from the Scala API in 8ce645d4eeda203cf5e100c4bdba2d71edd44e6a. Using that method in PySpark leads to an exception as PySpark tries to call the non-existent method on the JVM side. ## How was this patch tested? Existing tests (though none of them tested this particular method). Author: Sami Jaktholm Closes #15081 from sjakthol/pyspark-sc-clearfiles. --- python/pyspark/context.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6e9f24ef1026b..2744bb9ec04e5 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -787,14 +787,6 @@ def addFile(self, path): """ self._jsc.sc().addFile(path) - def clearFiles(self): - """ - Clear the job's list of files added by L{addFile} or L{addPyFile} so - that they do not get downloaded to any new nodes. - """ - # TODO: remove added .py or .zip files from the PYTHONPATH? - self._jsc.sc().clearFiles() - def addPyFile(self, path): """ Add a .py or .zip dependency for all tasks to be executed on this From 18b4f035f40359b3164456d0dab52dbc762ea3b4 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 14 Sep 2016 09:49:15 +0100 Subject: [PATCH 007/306] [CORE][DOC] remove redundant comment ## What changes were proposed in this pull request? In the comment, there is redundant `the estimated`. This PR simply remove the redundant comment and adjusts format. Author: wm624@hotmail.com Closes #15091 from wangmiao1981/comment. 
--- .../spark/storage/memory/MemoryStore.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 1a3bf2bb672c6..baa3fde2d05f1 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -169,12 +169,12 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated the estimated size of the stored data. In case of - * failure, return an iterator containing the values of the block. The returned iterator - * will be backed by the combination of the partially-unrolled block and the remaining - * elements of the original input iterator. The caller must either fully consume this - * iterator or call `close()` on it in order to free the storage memory consumed by the - * partially-unrolled block. + * @return in case of success, the estimated size of the stored data. In case of failure, return + * an iterator containing the values of the block. The returned iterator will be backed + * by the combination of the partially-unrolled block and the remaining elements of the + * original input iterator. The caller must either fully consume this iterator or call + * `close()` on it in order to free the storage memory consumed by the partially-unrolled + * block. */ private[storage] def putIteratorAsValues[T]( blockId: BlockId, @@ -298,9 +298,9 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated the estimated size of the stored data. In case of - * failure, return a handle which allows the caller to either finish the serialization - * by spilling to disk or to deserialize the partially-serialized block and reconstruct + * @return in case of success, the estimated size of the stored data. In case of failure, + * return a handle which allows the caller to either finish the serialization by + * spilling to disk or to deserialize the partially-serialized block and reconstruct * the original input iterator. The caller must either fully consume this result * iterator or call `discard()` on it in order to free the storage memory consumed by the * partially-unrolled block. From 4cea9da2ae88b40a5503111f8f37051e2372163e Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Wed, 14 Sep 2016 09:51:14 +0100 Subject: [PATCH 008/306] [SPARK-17480][SQL] Improve performance by removing or caching List.length which is O(n) ## What changes were proposed in this pull request? Scala's List.length method is O(N) and it makes the gatherCompressibilityStats function O(N^2). Eliminate the List.length calls by writing it in Scala way. https://github.com/scala/scala/blob/2.10.x/src/library/scala/collection/LinearSeqOptimized.scala#L36 As suggested. Extended the fix to HiveInspectors and AggregationIterator classes as well. ## How was this patch tested? Profiled a Spark job and found that CompressibleColumnBuilder is using 39% of the CPU. Out of this 39% CompressibleColumnBuilder->gatherCompressibilityStats is using 23% of it. 
6.24% of the CPU is spend on List.length which is called inside gatherCompressibilityStats. After this change we started to save 6.24% of the CPU. Author: Ergin Seyfe Closes #15032 from seyfe/gatherCompressibilityStats. --- .../sql/execution/aggregate/AggregationIterator.scala | 7 ++++--- .../columnar/compression/CompressibleColumnBuilder.scala | 6 +----- .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 6 ++++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index dfed084fe64a2..f335912ba2c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -73,9 +73,10 @@ abstract class AggregationIterator( startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](expressions.length) + val expressionsLength = expressions.length + val functions = new Array[AggregateFunction](expressionsLength) var i = 0 - while (i < expressions.length) { + while (i < expressionsLength) { val func = expressions(i).aggregateFunction val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => @@ -171,7 +172,7 @@ abstract class AggregationIterator( case PartialMerge | Final => (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } - } + }.toArray // This projection is used to merge buffer values for all expression-based aggregates. 
val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 63eae1b8685ac..0f4680e502781 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -66,11 +66,7 @@ private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] } private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { - var i = 0 - while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(row, ordinal) - i += 1 - } + compressionEncoders.foreach(_.gatherCompressibilityStats(row, ordinal)) } abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index bf5cc17a68f57..4e74452f6cd12 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -756,7 +756,8 @@ private[hive] trait HiveInspectors { cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - while (i < inspectors.length) { + val length = inspectors.length + while (i < length) { cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } @@ -769,7 +770,8 @@ private[hive] trait HiveInspectors { cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - while (i < inspectors.length) { + val length = inspectors.length + while (i < length) { cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } From dc0a4c916151c795dc41b5714e9d23b4937f4636 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 14 Sep 2016 10:10:16 +0100 Subject: [PATCH 009/306] [SPARK-17445][DOCS] Reference an ASF page as the main place to find third-party packages ## What changes were proposed in this pull request? Point references to spark-packages.org to https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects This will be accompanied by a parallel change to the spark-website repo, and additional changes to this wiki. ## How was this patch tested? Jenkins tests. Author: Sean Owen Closes #15075 from srowen/SPARK-17445. --- CONTRIBUTING.md | 2 +- R/pkg/R/sparkR.R | 4 ++-- docs/_layouts/global.html | 2 +- docs/index.md | 2 +- docs/sparkr.md | 3 ++- docs/streaming-programming-guide.md | 2 +- .../spark/sql/execution/datasources/DataSource.scala | 7 ++++--- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++------ .../spark/sql/sources/ResolvedDataSourceSuite.scala | 6 +++--- 9 files changed, 18 insertions(+), 19 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f10d7e277eea3..1a8206abe3838 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a package on http://spark-packages.org ? 
+- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 15afe01c24ed2..06015362e6bc1 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -100,7 +100,7 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes -#' @param sparkPackages Character vector of packages from spark-packages.org +#' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated #' @export @@ -327,7 +327,7 @@ sparkRHive.init <- function(jsc = NULL) { #' @param sparkHome Spark Home directory. #' @param sparkConfig named list of Spark configuration to set on worker nodes. #' @param sparkJars character vector of jar files to pass to the worker nodes. -#' @param sparkPackages character vector of packages from spark-packages.org +#' @param sparkPackages character vector of package coordinates #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d3bf082aa751a..ad5b5c9adfac8 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -114,7 +114,7 @@
  • Building Spark
  • Contributing to Spark
-  • Supplemental Projects
+  • Third Party Projects
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5392b4a9bcf4b..43f1cf3e31871 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2382,7 +2382,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Spark Packages](https://spark-packages.org/) +* Third-party DStream data sources can be found in [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 71807b771a95f..825c01365dd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -142,12 +142,13 @@ case class DataSource( } else if (provider.toLowerCase == "avro" || provider == "com.databricks.spark.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider.toLowerCase}. Please use Spark " + - "package http://spark-packages.org/package/databricks/spark-avro") + s"Failed to find data source: ${provider.toLowerCase}. Please find an Avro " + + "package at " + + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider. Please find packages at " + - "http://spark-packages.org", + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects", error) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a2164f9ae3d3e..3cc3b319f5a57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1645,21 +1645,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { e = intercept[AnalysisException] { sql(s"select id from `com.databricks.spark.avro`.`file_path`") } - assert(e.message.contains("Failed to find data source: com.databricks.spark.avro. " + - "Please use Spark package http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) // data source type is case insensitive e = intercept[AnalysisException] { sql(s"select id from Avro.`file_path`") } - assert(e.message.contains("Failed to find data source: avro. Please use Spark package " + - "http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: avro.")) e = intercept[AnalysisException] { sql(s"select id from avro.`file_path`") } - assert(e.message.contains("Failed to find data source: avro. 
Please use Spark package " + - "http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: avro.")) e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 5ea1f32433699..76ffb949f1293 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -74,16 +74,16 @@ class ResolvedDataSourceSuite extends SparkFunSuite { val error1 = intercept[AnalysisException] { getProvidingClass("avro") } - assert(error1.getMessage.contains("spark-packages")) + assert(error1.getMessage.contains("Failed to find data source: avro.")) val error2 = intercept[AnalysisException] { getProvidingClass("com.databricks.spark.avro") } - assert(error2.getMessage.contains("spark-packages")) + assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) val error3 = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("spark-packages")) + assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } From 52738d4e099a19466ef909b77c24cab109548706 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Sep 2016 23:10:20 +0800 Subject: [PATCH 010/306] [SPARK-17409][SQL] Do Not Optimize Query in CTAS More Than Once ### What changes were proposed in this pull request? As explained in https://github.com/apache/spark/pull/14797: >Some analyzer rules have assumptions on logical plans, optimizer may break these assumption, we should not pass an optimized query plan into QueryExecution (will be analyzed again), otherwise we may some weird bugs. For example, we have a rule for decimal calculation to promote the precision before binary operations, use PromotePrecision as placeholder to indicate that this rule should not apply twice. But a Optimizer rule will remove this placeholder, that break the assumption, then the rule applied twice, cause wrong result. We should not optimize the query in CTAS more than once. For example, ```Scala spark.range(99, 101).createOrReplaceTempView("tab1") val sqlStmt = "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") checkAnswer(spark.table("tab2"), sql(sqlStmt)) ``` Before this PR, the results do not match ``` == Results == !== Correct Answer - 2 == == Spark Answer - 2 == ![100,100.000000000000000000] [100,null] [99,99.000000000000000000] [99,99.000000000000000000] ``` After this PR, the results match. ``` +---+----------------------+ |id |num | +---+----------------------+ |99 |99.000000000000000000 | |100|100.000000000000000000| +---+----------------------+ ``` In this PR, we do not treat the `query` in CTAS as a child. Thus, the `query` will not be optimized when optimizing CTAS statement. However, we still need to analyze it for normalizing and verifying the CTAS in the Analyzer. Thus, we do it in the analyzer rule `PreprocessDDL`, because so far only this rule needs the analyzed plan of the `query`. ### How was this patch tested? Added a test Author: gatorsmile Closes #15048 from gatorsmile/ctasOptimized. 
--- .../sql/catalyst/plans/logical/Command.scala | 7 +++++- .../analysis/UnsupportedOperationsSuite.scala | 5 +--- .../sql/execution/command/SetCommand.scala | 2 -- .../spark/sql/execution/command/cache.scala | 7 ------ .../sql/execution/command/commands.scala | 4 +--- .../sql/execution/command/databases.scala | 2 -- .../spark/sql/execution/command/ddl.scala | 6 ----- .../spark/sql/execution/datasources/ddl.scala | 12 +++++----- .../sql/execution/datasources/rules.scala | 24 ++++++++++++++----- .../spark/sql/internal/SessionState.scala | 2 +- .../sources/CreateTableAsSelectSuite.scala | 12 ++++++++++ .../spark/sql/hive/HiveSessionState.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 6 ++--- 13 files changed, 49 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 75a5b10d9ed04..64f57835c8898 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.expressions.Attribute + /** * A logical node that represents a non-query command to be executed by the system. For example, * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command +trait Command extends LeafNode { + final override def children: Seq[LogicalPlan] = Seq.empty + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 6df47acaba85b..ff1bb126f463d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,10 +31,7 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType /** A dummy command for testing unsupported operations. 
*/ -case class DummyCommand() extends LogicalPlan with Command { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil -} +case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index b0e2d03af070d..af6def52d07d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -129,6 +129,4 @@ case object ResetCommand extends RunnableCommand with Logging { sparkSession.sessionState.conf.clear() Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 697e2ff21159b..c31f4dc9aba4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -47,8 +46,6 @@ case class CacheTableCommand( Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } @@ -58,8 +55,6 @@ case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableComm sparkSession.catalog.uncacheTable(tableIdent.quotedString) Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } /** @@ -71,6 +66,4 @@ case object ClearCacheCommand extends RunnableCommand { sparkSession.catalog.clearCache() Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 424a962b5eb1c..698c625d617fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -35,9 +35,7 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. 
*/ -trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty - final override def children: Seq[LogicalPlan] = Seq.empty +trait RunnableCommand extends logical.Command { def run(sparkSession: SparkSession): Seq[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index 597ec27ce6698..e5a6a5f60b8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { sparkSession.sessionState.catalog.setCurrentDatabase(databaseName) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bc1c4f85e3315..dcda2f8d1c52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -70,8 +70,6 @@ case class CreateDatabaseCommand( ifNotExists) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } @@ -101,8 +99,6 @@ case class DropDatabaseCommand( sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** @@ -126,8 +122,6 @@ case class AlterDatabasePropertiesCommand( Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1b1e2123b7c47..fa95af2648cf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.types._ -case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[LogicalPlan]) - extends LogicalPlan with Command { +case class CreateTable( + tableDesc: CatalogTable, + mode: SaveMode, + query: Option[LogicalPlan]) extends Command { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { @@ -35,9 +37,7 @@ case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[Lo "create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.") } - override def output: Seq[Attribute] = Seq.empty[Attribute] - - override def children: Seq[LogicalPlan] = query.toSeq + override def innerChildren: Seq[QueryPlan[_]] = query.toSeq } case class CreateTempViewUsing( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 
fbf4063ff63b8..bd6eb6e0535ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -66,9 +66,10 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { } /** - * Preprocess some DDL plans, e.g. [[CreateTable]], to do some normalization and checking. + * Analyze [[CreateTable]] and do some normalization and checking. + * For CREATE TABLE AS SELECT, the SELECT query is also analyzed. */ -case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { +case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // When we CREATE TABLE without specifying the table schema, we should fail the query if @@ -95,9 +96,19 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { // * can't use all table columns as partition columns. // * partition columns' type must be AtomicType. // * sort columns' type must be orderable. - case c @ CreateTable(tableDesc, mode, query) if c.childrenResolved => - val schema = if (query.isDefined) query.get.schema else tableDesc.schema - val columnNames = if (conf.caseSensitiveAnalysis) { + case c @ CreateTable(tableDesc, mode, query) => + val analyzedQuery = query.map { q => + // Analyze the query in CTAS and then we can do the normalization and checking. + val qe = sparkSession.sessionState.executePlan(q) + qe.assertAnalyzed() + qe.analyzed + } + val schema = if (analyzedQuery.isDefined) { + analyzedQuery.get.schema + } else { + tableDesc.schema + } + val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { schema.map(_.name) } else { schema.map(_.name.toLowerCase) @@ -106,7 +117,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { val partitionColsChecked = checkPartitionColumns(schema, tableDesc) val bucketColsChecked = checkBucketColumns(schema, partitionColsChecked) - c.copy(tableDesc = bucketColsChecked) + c.copy(tableDesc = bucketColsChecked, query = analyzedQuery) } private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = { @@ -176,6 +187,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { colName: String, colType: String): String = { val tableCols = schema.map(_.name) + val conf = sparkSession.sessionState.conf tableCols.find(conf.resolver(_, colName)).getOrElse { failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " + s"defined table columns are: ${tableCols.mkString(", ")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 8fdbd0f2c6dab..c899773b6b36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreprocessDDL(conf) :: + AnalyzeCreateTable(sparkSession) :: PreprocessTableInsertion(conf) :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis(conf) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 
729c9fdda543e..344d4aa6cfea4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -236,4 +236,16 @@ class CreateTableAsSelectSuite assert(e.contains("Expected positive number of buckets, but got `0`")) } } + + test("CTAS of decimal calculation") { + withTable("tab2") { + withTempView("tab1") { + spark.range(99, 101).createOrReplaceTempView("tab1") + val sqlStmt = + "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" + sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") + checkAnswer(spark.table("tab2"), sql(sqlStmt)) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 15e1255653f88..eb10c11382e83 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -60,7 +60,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override val extendedResolutionRules = catalog.ParquetConversions :: catalog.OrcConversions :: - PreprocessDDL(conf) :: + AnalyzeCreateTable(sparkSession) :: PreprocessTableInsertion(conf) :: DataSourceAnalysis(conf) :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 98afd99a203ac..f9751e3d5f2eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -77,7 +77,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempView("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) spark.read.json(rdd).createOrReplaceTempView("jt") @@ -98,8 +98,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(!outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should not contain Subquery since it's eliminated by optimizer") + assert(outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should contain SubqueryAlias since the query should not be optimized") } } From 6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Sep 2016 10:10:01 -0700 Subject: [PATCH 011/306] [SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python ## What changes were proposed in this pull request? In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing. 
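For example (an illustrative sketch, not part of this patch; `spark` is assumed to be an existing `SparkSession`):

```python
df = spark.range(0, 1000000, numPartitions=100)

# Both return a single row, but they are planned very differently in PySpark:
df.take(1)             # single-stage job that computes only one partition
df.limit(1).collect()  # computes all partitions and runs a two-stage job
```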
The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd..toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan; this, in turn, ends up being executed inefficiently because limits in the middle of plans are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this was done in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.). In order to fix this performance problem I think that we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch modifies `DataFrame.collect()` to first collect all results to the driver and then pass them to Python, allowing this query to be planned using Spark's `CollectLimit` optimizations. ## How was this patch tested? Added a regression test in `sql/tests.py` which asserts that the expected number of jobs, stages, and tasks are run for both queries. Author: Josh Rosen Closes #15068 from JoshRosen/pyspark-collect-limit. --- python/pyspark/sql/dataframe.py | 5 +---- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 8 ++++++-- .../sql/execution/python/EvaluatePython.scala | 13 +------------ 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e5eac918a93a0..0f7d8fba3bd54 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -357,10 +357,7 @@ def take(self, num): >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ - with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( - self._jdf, num) - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + return self.limit(num).collect() @since(1.3) def foreach(self, f): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 769e4540720e7..1be0b72304ae8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1862,6 +1862,24 @@ def test_collect_functions(self): sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), ["1", "2", "2", "2"]) + def test_limit_and_take(self): + df = self.spark.range(1, 1000, numPartitions=10) + + def assert_runs_only_one_job_stage_and_task(job_group_name, f): + tracker = self.sc.statusTracker() + self.sc.setJobGroup(job_group_name, description="") + f() + jobs = tracker.getJobIdsForGroup(job_group_name) + self.assertEqual(1, len(jobs)) + stages = tracker.getJobInfo(jobs[0]).stageIds + self.assertEqual(1, len(stages)) + self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) + + # Regression test for SPARK-10731: take should delegate to Scala implementation + assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) + # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) + assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect()) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3b3cb820788a2..9cfbdffd02582 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ @@ -2567,8 +2567,12 @@ class Dataset[T] private[sql]( } private[sql] def collectToPython(): Int = { + EvaluatePython.registerPicklers() withNewExecutionId { - PythonRDD.collectAndServe(javaToPython.rdd) + val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) + val iter = new SerDeUtil.AutoBatchedPickler( + queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-DataFrame") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index cf68ed4ec36a8..724025b4647f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -24,9 +24,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -34,16 +33,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String object EvaluatePython { - def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() - df.withNewExecutionId { - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") - } - } def needConversionInPython(dt: DataType): Boolean = dt match { case DateType | TimestampType => true From a79838bdeeb12cec4d50da3948bd8a33777e53a6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Sep 2016 01:33:56 +0800 Subject: [PATCH 012/306] [MINOR][SQL] Add missing functions for some options in SQLConf and use them where applicable ## What changes were proposed in this pull request? I first thought they are missing because they are kind of hidden options but it seems they are just missing. For example, `spark.sql.parquet.mergeSchema` is documented in [sql-programming-guide.md](https://github.com/apache/spark/blob/master/docs/sql-programming-guide.md) but this function is missing whereas many options such as `spark.sql.join.preferSortMergeJoin` are not documented but have its own function individually. So, this PR suggests making them consistent by adding the missing functions for some options in `SQLConf` and use them where applicable, in order to make them more readable. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #14678 from HyukjinKwon/sqlconf-cleanup. 
--- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../execution/datasources/DataSource.scala | 2 +- .../InsertIntoHadoopFsRelationCommand.scala | 2 +- .../PartitioningAwareFileCatalog.scala | 2 +- .../parquet/ParquetFileFormat.scala | 8 ++-- .../datasources/parquet/ParquetOptions.scala | 2 +- .../streaming/FileStreamSinkLog.scala | 6 +-- .../execution/streaming/StreamExecution.scala | 2 +- .../streaming/state/StateStoreConf.scala | 6 +-- .../apache/spark/sql/internal/SQLConf.scala | 42 ++++++++++++++----- .../sql/streaming/StreamingQueryManager.scala | 4 +- 12 files changed, 49 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 53d732403f979..6c3fe07709fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -313,7 +313,7 @@ class RelationalGroupedDataset protected[sql]( */ def pivot(pivotColumn: String): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sparkSession.conf.get(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index d4845637be049..383b3a233fc27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -55,7 +55,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } def assertSupported(): Unit = { - if (sparkSession.sessionState.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 825c01365dd1e..93154bd2ca69c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -231,7 +231,7 @@ case class DataSource( } } - val isSchemaInferenceEnabled = sparkSession.conf.get(SQLConf.STREAMING_SCHEMA_INFERENCE) + val isSchemaInferenceEnabled = sparkSession.sessionState.conf.streamingSchemaInference val isTextSource = providingClass == classOf[text.TextFileFormat] // If the schema inference is disabled, only text sources require schema to be specified if (!isSchemaInferenceEnabled && !isTextSource && userSpecifiedSchema.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 02ce7fab64729..99ca3df673568 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -131,7 +131,7 @@ case class InsertIntoHadoopFsRelationCommand( dataColumns = dataColumns, inputSchema = query.output, PartitioningUtils.DEFAULT_PARTITION_NAME, - sparkSession.conf.get(SQLConf.PARTITION_MAX_FILES), + sparkSession.sessionState.conf.partitionMaxFiles, isAppend) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala index cef9d4d9c7f1b..d2d5b56c82946 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -126,7 +126,7 @@ abstract class PartitioningAwareFileCatalog( PartitioningUtils.parsePartitions( leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled(), + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 9208c82179d8d..e7c3545630fea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -151,7 +151,7 @@ class ParquetFileFormat // Should we merge schemas from all Parquet part-files? val shouldMergeSchemas = parquetOptions.mergeSchema - val mergeRespectSummaries = sparkSession.conf.get(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + val mergeRespectSummaries = sparkSession.sessionState.conf.isParquetSchemaRespectSummaries val filesByType = splitFiles(files) @@ -308,14 +308,14 @@ class ParquetFileFormat // Sets flags for `CatalystSchemaConverter` hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, - sparkSession.conf.get(SQLConf.PARQUET_BINARY_AS_STRING)) + sparkSession.sessionState.conf.isParquetBinaryAsString) hadoopConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sparkSession.conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) // Try to push down filters when filter push-down is enabled. val pushed = - if (sparkSession.conf.get(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { + if (sparkSession.sessionState.conf.parquetFilterPushDown) { filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). 
That's why a `flatMap` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 3eec582714e15..615731889dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -52,7 +52,7 @@ private[parquet] class ParquetOptions( val mergeSchema: Boolean = parameters .get(MERGE_SCHEMA) .map(_.toBoolean) - .getOrElse(sqlConf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + .getOrElse(sqlConf.isParquetSchemaMergingEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 7520163522027..6f9f7c18c4dc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -93,11 +93,11 @@ class FileStreamSinkLog(sparkSession: SparkSession, path: String) * a live lock may happen if the compaction happens too frequently: one processing keeps deleting * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. */ - private val fileCleanupDelayMs = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + private val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay - private val isDeletingExpiredLog = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_DELETION) + private val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion - private val compactInterval = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) + private val compactInterval = sparkSession.sessionState.conf.fileSinkLogCompatInterval require(compactInterval > 0, s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + "to a positive value.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 5e1e5eeb50936..a1aae61107baf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -58,7 +58,7 @@ class StreamExecution( import org.apache.spark.sql.streaming.StreamingQueryListener._ - private val pollingDelayMs = sparkSession.conf.get(SQLConf.STREAMING_POLLING_DELAY) + private val pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay /** * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index e55f63a6c8db8..de72f1cf2723d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -24,11 +24,9 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex def this() = this(new SQLConf) - import SQLConf._ + val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot - val minDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) - - val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) + val minVersionsToRetain = conf.stateStoreMinVersionsToRetain } private[streaming] object StateStoreConf { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d6ca5a965cbf..428032b1fba83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -338,11 +338,6 @@ object SQLConf { .intConf .createWithDefault(4000) - val PARTITION_DISCOVERY_ENABLED = SQLConfigBuilder("spark.sql.sources.partitionDiscovery.enabled") - .doc("When true, automatically discover data partitions.") - .booleanConf - .createWithDefault(true) - val PARTITION_COLUMN_TYPE_INFERENCE = SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled") .doc("When true, automatically infer the data types for partitioned columns.") @@ -391,8 +386,10 @@ object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold") - .doc("The degree of parallelism for schema merging and partition discovery of " + - "Parquet data sources.") + .doc("The maximum number of files allowed for listing files at driver side. If the number " + + "of detected files exceeds this value during partition discovery, it tries to list the " + + "files with another Spark distributed job. 
This applies to Parquet, ORC, CSV, JSON and " + + "LibSVM data sources.") .intConf .createWithDefault(32) @@ -592,8 +589,24 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + + def stateStoreMinVersionsToRetain: Int = getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) + def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) + + def fileSinkLogDeletion: Boolean = getConf(FILE_SINK_LOG_DELETION) + + def fileSinkLogCompatInterval: Int = getConf(FILE_SINK_LOG_COMPACT_INTERVAL) + + def fileSinkLogCleanupDelay: Long = getConf(FILE_SINK_LOG_CLEANUP_DELAY) + + def streamingSchemaInference: Boolean = getConf(STREAMING_SCHEMA_INFERENCE) + + def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) @@ -657,6 +670,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, Long.MaxValue) + def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED) + + def isParquetSchemaRespectSummaries: Boolean = getConf(PARQUET_SCHEMA_RESPECT_SUMMARIES) + + def parquetOutputCommitterClass: String = getConf(PARQUET_OUTPUT_COMMITTER_CLASS) + def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) @@ -673,12 +692,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def convertCTAS: Boolean = getConf(CONVERT_CTAS) - def partitionDiscoveryEnabled(): Boolean = - getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) - - def partitionColumnTypeInferenceEnabled(): Boolean = + def partitionColumnTypeInferenceEnabled: Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + def partitionMaxFiles: Int = getConf(PARTITION_MAX_FILES) + def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) @@ -695,6 +713,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) + override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index bae7f56a23f81..bba7bc753eea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -204,7 +204,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => new Path(userSpecified).toUri.toString }.orElse { - df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location => + df.sparkSession.sessionState.conf.checkpointLocation.map { location => new Path(location, name).toUri.toString } }.getOrElse { @@ 
-232,7 +232,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { val analyzedPlan = df.queryExecution.analyzed df.queryExecution.assertAnalyzed() - if (sparkSession.conf.get(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) } From 040e46979d5f90edc7f9be3cbedd87e8986e8053 Mon Sep 17 00:00:00 2001 From: Xin Wu Date: Wed, 14 Sep 2016 21:14:29 +0200 Subject: [PATCH 013/306] [SPARK-10747][SQL] Support NULLS FIRST|LAST clause in ORDER BY ## What changes were proposed in this pull request? Currently, ORDER BY clause returns nulls value according to sorting order (ASC|DESC), considering null value is always smaller than non-null values. However, SQL2003 standard support NULLS FIRST or NULLS LAST to allow users to specify whether null values should be returned first or last, regardless of sorting order (ASC|DESC). This PR is to support this new feature. ## How was this patch tested? New test cases are added to test NULLS FIRST|LAST for regular select queries and windowing queries. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: Xin Wu Closes #14842 from xwu0226/SPARK-10747. --- .../unsafe/sort/PrefixComparators.java | 58 +++- .../unsafe/sort/UnsafeInMemorySorter.java | 11 +- .../unsafe/sort/RadixSortSuite.scala | 27 +- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 65 ++++- .../codegen/GenerateOrdering.scala | 16 +- .../sql/catalyst/expressions/ordering.scala | 6 +- .../sql/catalyst/parser/AstBuilder.scala | 14 +- .../spark/sql/execution/SortPrefixUtils.scala | 68 ++++- .../spark/sql/execution/SparkPlan.scala | 2 +- .../inputs/orderby-nulls-ordering.sql | 83 ++++++ .../results/orderby-nulls-ordering.sql.out | 254 ++++++++++++++++++ .../spark/sql/execution/SortSuite.scala | 3 +- sql/hive/src/test/resources/sqlgen/agg2.sql | 2 +- sql/hive/src/test/resources/sqlgen/agg3.sql | 2 +- .../sqlgen/broadcast_join_subquery.sql | 2 +- .../sqlgen/generate_with_other_1.sql | 2 +- .../sqlgen/generate_with_other_2.sql | 2 +- .../resources/sqlgen/grouping_sets_2_1.sql | 2 +- .../resources/sqlgen/grouping_sets_2_2.sql | 2 +- .../resources/sqlgen/grouping_sets_2_3.sql | 2 +- .../resources/sqlgen/grouping_sets_2_4.sql | 2 +- .../resources/sqlgen/grouping_sets_2_5.sql | 2 +- .../test/resources/sqlgen/rollup_cube_6_1.sql | 2 +- .../test/resources/sqlgen/rollup_cube_6_2.sql | 2 +- .../test/resources/sqlgen/rollup_cube_6_3.sql | 2 +- .../test/resources/sqlgen/rollup_cube_6_4.sql | 2 +- .../resources/sqlgen/sort_asc_nulls_last.sql | 4 + .../resources/sqlgen/sort_by_after_having.sql | 2 +- .../sqlgen/sort_desc_nulls_first.sql | 4 + .../resources/sqlgen/subquery_in_having_1.sql | 2 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../test/resources/sqlgen/window_basic_2.sql | 2 +- .../test/resources/sqlgen/window_basic_3.sql | 2 +- .../sqlgen/window_basic_asc_nulls_last.sql | 5 + .../sqlgen/window_basic_desc_nulls_first.sql | 5 + .../resources/sqlgen/window_with_join.sql | 2 +- .../window_with_the_same_window_with_agg.sql | 2 +- ...w_with_the_same_window_with_agg_filter.sql | 2 +- ...ith_the_same_window_with_agg_functions.sql | 2 +- ...w_with_the_same_window_with_agg_having.sql | 2 +- 
.../catalyst/ExpressionSQLBuilderSuite.scala | 6 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 24 ++ 46 files changed, 639 insertions(+), 80 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out create mode 100644 sql/hive/src/test/resources/sqlgen/sort_asc_nulls_last.sql create mode 100644 sql/hive/src/test/resources/sqlgen/sort_desc_nulls_first.sql create mode 100644 sql/hive/src/test/resources/sqlgen/window_basic_asc_nulls_last.sql create mode 100644 sql/hive/src/test/resources/sqlgen/window_basic_desc_nulls_first.sql diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index c44630fbbc2f0..116c84943e855 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -29,12 +29,23 @@ private PrefixComparators() {} public static final PrefixComparator STRING = new UnsignedPrefixComparator(); public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator STRING_NULLS_LAST = new UnsignedPrefixComparatorNullsLast(); + public static final PrefixComparator STRING_DESC_NULLS_FIRST = new UnsignedPrefixComparatorDescNullsFirst(); + public static final PrefixComparator BINARY = new UnsignedPrefixComparator(); public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator BINARY_NULLS_LAST = new UnsignedPrefixComparatorNullsLast(); + public static final PrefixComparator BINARY_DESC_NULLS_FIRST = new UnsignedPrefixComparatorDescNullsFirst(); + public static final PrefixComparator LONG = new SignedPrefixComparator(); public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc(); + public static final PrefixComparator LONG_NULLS_LAST = new SignedPrefixComparatorNullsLast(); + public static final PrefixComparator LONG_DESC_NULLS_FIRST = new SignedPrefixComparatorDescNullsFirst(); + public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator(); public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator DOUBLE_NULLS_LAST = new UnsignedPrefixComparatorNullsLast(); + public static final PrefixComparator DOUBLE_DESC_NULLS_FIRST = new UnsignedPrefixComparatorDescNullsFirst(); public static final class StringPrefixComparator { public static long computePrefix(UTF8String value) { @@ -74,6 +85,9 @@ public abstract static class RadixSortSupport extends PrefixComparator { /** @return Whether the sort should take into account the sign bit. */ public abstract boolean sortSigned(); + + /** @return Whether the sort should put nulls first or last. 
*/ + public abstract boolean nullsFirst(); } // @@ -83,16 +97,34 @@ public abstract static class RadixSortSupport extends PrefixComparator { public static final class UnsignedPrefixComparator extends RadixSortSupport { @Override public boolean sortDescending() { return false; } @Override public boolean sortSigned() { return false; } - @Override + @Override public boolean nullsFirst() { return true; } + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class UnsignedPrefixComparatorNullsLast extends RadixSortSupport { + @Override public boolean sortDescending() { return false; } + @Override public boolean sortSigned() { return false; } + @Override public boolean nullsFirst() { return false; } public int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } } + public static final class UnsignedPrefixComparatorDescNullsFirst extends RadixSortSupport { + @Override public boolean sortDescending() { return true; } + @Override public boolean sortSigned() { return false; } + @Override public boolean nullsFirst() { return true; } + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport { @Override public boolean sortDescending() { return true; } @Override public boolean sortSigned() { return false; } - @Override + @Override public boolean nullsFirst() { return false; } public int compare(long bPrefix, long aPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } @@ -101,16 +133,34 @@ public int compare(long bPrefix, long aPrefix) { public static final class SignedPrefixComparator extends RadixSortSupport { @Override public boolean sortDescending() { return false; } @Override public boolean sortSigned() { return true; } - @Override + @Override public boolean nullsFirst() { return true; } + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + + public static final class SignedPrefixComparatorNullsLast extends RadixSortSupport { + @Override public boolean sortDescending() { return false; } + @Override public boolean sortSigned() { return true; } + @Override public boolean nullsFirst() { return false; } public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } } + public static final class SignedPrefixComparatorDescNullsFirst extends RadixSortSupport { + @Override public boolean sortDescending() { return true; } + @Override public boolean sortSigned() { return true; } + @Override public boolean nullsFirst() { return true; } + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + public static final class SignedPrefixComparatorDesc extends RadixSortSupport { @Override public boolean sortDescending() { return true; } @Override public boolean sortSigned() { return true; } - @Override + @Override public boolean nullsFirst() { return false; } public int compare(long b, long a) { return (a < b) ? -1 : (a > b) ? 
1 : 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 30d0f3006a04e..be382955c0d42 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -333,17 +333,18 @@ public UnsafeSorterIterator getSortedIterator() { if (nullBoundaryPos > 0) { assert radixSortSupport != null : "Nulls are only stored separately with radix sort"; LinkedList queue = new LinkedList<>(); - if (radixSortSupport.sortDescending()) { - // Nulls are smaller than non-nulls - queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); + + // The null order is either LAST or FIRST, regardless of sorting direction (ASC|DESC) + if (radixSortSupport.nullsFirst()) { queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); + queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); } else { - queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); + queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); } return new UnsafeExternalSorter.ChainedIterator(queue); } else { return new SortedIterator(pos / 2, offset); } } -} +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 2c13806410192..366ffda7788d3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -40,23 +40,38 @@ class RadixSortSuite extends SparkFunSuite with Logging { case class RadixSortType( name: String, referenceComparator: PrefixComparator, - startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean) + startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean, nullsFirst: Boolean) val SORT_TYPES_TO_TEST = Seq( - RadixSortType("unsigned binary data asc", PrefixComparators.BINARY, 0, 7, false, false), - RadixSortType("unsigned binary data desc", PrefixComparators.BINARY_DESC, 0, 7, true, false), - RadixSortType("twos complement asc", PrefixComparators.LONG, 0, 7, false, true), - RadixSortType("twos complement desc", PrefixComparators.LONG_DESC, 0, 7, true, true), + RadixSortType("unsigned binary data asc nulls first", + PrefixComparators.BINARY, 0, 7, false, false, true), + RadixSortType("unsigned binary data asc nulls last", + PrefixComparators.BINARY_NULLS_LAST, 0, 7, false, false, false), + RadixSortType("unsigned binary data desc nulls last", + PrefixComparators.BINARY_DESC_NULLS_FIRST, 0, 7, true, false, false), + RadixSortType("unsigned binary data desc nulls first", + PrefixComparators.BINARY_DESC, 0, 7, true, false, true), + + RadixSortType("twos complement asc nulls first", + PrefixComparators.LONG, 0, 7, false, true, true), + RadixSortType("twos complement asc nulls last", + PrefixComparators.LONG_NULLS_LAST, 0, 7, false, true, false), + RadixSortType("twos complement desc nulls last", + PrefixComparators.LONG_DESC, 0, 7, true, true, false), + RadixSortType("twos complement desc nulls first", + PrefixComparators.LONG_DESC_NULLS_FIRST, 0, 7, true, true, true), + RadixSortType( "binary data partial", new PrefixComparators.RadixSortSupport { override 
def sortDescending = false override def sortSigned = false + override def nullsFirst = true override def compare(a: Long, b: Long): Int = { return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L) } }, - 2, 4, false, false)) + 2, 4, false, false, true)) private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](size) { i => rand } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9a643465a9994..b475abdce2da9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -324,7 +324,7 @@ queryPrimary ; sortItem - : expression ordering=(ASC | DESC)? + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? ; querySpecification @@ -641,7 +641,8 @@ number nonReserved : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST + | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS @@ -729,6 +730,8 @@ UNBOUNDED: 'UNBOUNDED'; PRECEDING: 'PRECEDING'; FOLLOWING: 'FOLLOWING'; CURRENT: 'CURRENT'; +FIRST: 'FIRST'; +LAST: 'LAST'; ROW: 'ROW'; WITH: 'WITH'; VALUES: 'VALUES'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 18f814d6cdfd4..92bf8e0536fc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -714,9 +714,9 @@ class Analyzer( case s @ Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction) + SortOrder(child.output(index - 1), direction, nullOrdering) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 6d8dc8628229a..af0a565f73ae9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _) => + case order @ SortOrder(ordinal @ Literal(index: Int, 
IntegerType), _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8549187a66369..66e52ca68af19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,8 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) def desc: SortOrder = SortOrder(expr, Descending) - + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index de779ed3702d3..d015125baccaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -21,26 +21,43 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator -import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { def sql: String + def defaultNullOrdering: NullOrdering +} + +abstract sealed class NullOrdering { + def sql: String } case object Ascending extends SortDirection { override def sql: String = "ASC" + override def defaultNullOrdering: NullOrdering = NullsFirst } case object Descending extends SortDirection { override def sql: String = "DESC" + override def defaultNullOrdering: NullOrdering = NullsLast +} + +case object NullsFirst extends NullOrdering{ + override def sql: String = "NULLS FIRST" +} + +case object NullsLast extends NullOrdering{ + override def sql: String = "NULLS LAST" } /** * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) +case class SortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. 
*/ @@ -57,12 +74,18 @@ case class SortOrder(child: Expression, direction: SortDirection) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - override def toString: String = s"$child ${direction.sql}" - override def sql: String = child.sql + " " + direction.sql + override def toString: String = s"$child ${direction.sql} ${nullOrdering.sql}" + override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql def isAscending: Boolean = direction == Ascending } +object SortOrder { + def apply(child: Expression, direction: SortDirection): SortOrder = { + new SortOrder(child, direction, direction.defaultNullOrdering) + } +} + /** * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort. @@ -71,14 +94,35 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { case BooleanType | DateType | TimestampType | _: IntegralType => - Long.MinValue + if (nullAsSmallest) { + Long.MinValue + } else { + Long.MaxValue + } case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - Long.MinValue + if (nullAsSmallest) { + Long.MinValue + } else { + Long.MaxValue + } case _: DecimalType => - DoublePrefixComparator.computePrefix(Double.NegativeInfinity) - case _ => 0L + if (nullAsSmallest) { + DoublePrefixComparator.computePrefix(Double.NegativeInfinity) + } else { + DoublePrefixComparator.computePrefix(Double.NaN) + } + case _ => + if (nullAsSmallest) { + 0L + } else { + -1L + } } + private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) || + (!child.isAscending && child.nullOrdering == NullsLast) + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -86,6 +130,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + val StringPrefixCmp = classOf[StringPrefixComparator].getName val prefixCode = child.child.dataType match { case BooleanType => s"$input ? 
1L : 0L" @@ -95,7 +140,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { s"(long) $input" case FloatType | DoubleType => s"$DoublePrefixCmp.computePrefix((double)$input)" - case StringType => s"$input.getPrefix()" + case StringType => s"$StringPrefixCmp.computePrefix($input)" case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)" case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => if (dt.precision <= Decimal.MAX_LONG_DIGITS) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f4d35d232e691..e7df95e1142ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -63,7 +63,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR */ def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { - case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + case(dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) } genComparisons(ctx, ordering) } @@ -74,7 +74,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => val eval = order.child.genCode(ctx) - val asc = order.direction == Ascending + val asc = order.isAscending val isNullA = ctx.freshName("isNullA") val primitiveA = ctx.freshName("primitiveA") val isNullB = ctx.freshName("isNullB") @@ -99,9 +99,17 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR if ($isNullA && $isNullB) { // Nothing } else if ($isNullA) { - return ${if (order.direction == Ascending) "-1" else "1"}; + return ${ + order.nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + }}; } else if ($isNullB) { - return ${if (order.direction == Ascending) "1" else "-1"}; + return ${ + order.nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + }}; } else { int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6112259fed619..79d2052c38a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -39,9 +39,9 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow if (left == null && right == null) { // Both null, continue looking. 
} else if (left == null) { - return if (order.direction == Ascending) -1 else 1 + return if (order.nullOrdering == NullsFirst) -1 else 1 } else if (right == null) { - return if (order.direction == Ascending) 1 else -1 + return if (order.nullOrdering == NullsFirst) 1 else -1 } else { val comparison = order.dataType match { case dt: AtomicType if order.direction == Ascending => @@ -76,7 +76,7 @@ object InterpretedOrdering { */ def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { new InterpretedOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bbbb14df88f8c..69d68fa6f92ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1206,11 +1206,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[SortOrder]] expression. */ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { - if (ctx.DESC != null) { - SortOrder(expression(ctx.expression), Descending) + val direction = if (ctx.DESC != null) { + Descending } else { - SortOrder(expression(ctx.expression), Ascending) + Ascending } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 940467e74d597..c6665d273fd27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -40,22 +40,70 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => - if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC - case BinaryType => - if (sortOrder.isAscending) PrefixComparators.BINARY else PrefixComparators.BINARY_DESC + case StringType => stringPrefixComparator(sortOrder) + case BinaryType => binaryPrefixComparator(sortOrder) case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType => - if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC + longPrefixComparator(sortOrder) case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC - case FloatType | DoubleType => - if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC - case dt: DecimalType => - if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC + longPrefixComparator(sortOrder) + case FloatType | DoubleType => doublePrefixComparator(sortOrder) + case dt: DecimalType => doublePrefixComparator(sortOrder) case _ => NoOpPrefixComparator } } + private def stringPrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.direction match { + case Ascending 
if (sortOrder.nullOrdering == NullsLast) => + PrefixComparators.STRING_NULLS_LAST + case Ascending => + PrefixComparators.STRING + case Descending if (sortOrder.nullOrdering == NullsFirst) => + PrefixComparators.STRING_DESC_NULLS_FIRST + case Descending => + PrefixComparators.STRING_DESC + } + } + + private def binaryPrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.direction match { + case Ascending if (sortOrder.nullOrdering == NullsLast) => + PrefixComparators.BINARY_NULLS_LAST + case Ascending => + PrefixComparators.BINARY + case Descending if (sortOrder.nullOrdering == NullsFirst) => + PrefixComparators.BINARY_DESC_NULLS_FIRST + case Descending => + PrefixComparators.BINARY_DESC + } + } + + private def longPrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.direction match { + case Ascending if (sortOrder.nullOrdering == NullsLast) => + PrefixComparators.LONG_NULLS_LAST + case Ascending => + PrefixComparators.LONG + case Descending if (sortOrder.nullOrdering == NullsFirst) => + PrefixComparators.LONG_DESC_NULLS_FIRST + case Descending => + PrefixComparators.LONG_DESC + } + } + + private def doublePrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.direction match { + case Ascending if (sortOrder.nullOrdering == NullsLast) => + PrefixComparators.DOUBLE_NULLS_LAST + case Ascending => + PrefixComparators.DOUBLE + case Descending if (sortOrder.nullOrdering == NullsFirst) => + PrefixComparators.DOUBLE_DESC_NULLS_FIRST + case Descending => + PrefixComparators.DOUBLE_DESC + } + } + /** * Creates the prefix comparator for the first field in the given schema, in ascending order. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6a2d97c9b1797..6aeefa6eddafe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -368,7 +368,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = { val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) } newOrdering(order, Seq.empty) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql b/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql new file mode 100644 index 0000000000000..f7637b444b9fe --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql @@ -0,0 +1,83 @@ +-- Q1. testing window functions with order by +create table spark_10747(col1 int, col2 int, col3 int) using parquet; + +-- Q2. insert to tables +INSERT INTO spark_10747 VALUES (6, 12, 10), (6, 11, 4), (6, 9, 10), (6, 15, 8), +(6, 15, 8), (6, 7, 4), (6, 7, 8), (6, 13, null), (6, 10, null); + +-- Q3. windowing with order by DESC NULLS LAST +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 desc nulls last, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2; + +-- Q4. 
windowing with order by DESC NULLS FIRST +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 desc nulls first, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2; + +-- Q5. windowing with order by ASC NULLS LAST +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 asc nulls last, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2; + +-- Q6. windowing with order by ASC NULLS FIRST +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 asc nulls first, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2; + +-- Q7. Regular query with ORDER BY ASC NULLS FIRST +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 ASC NULLS FIRST, COL2; + +-- Q8. Regular query with ORDER BY ASC NULLS LAST +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 NULLS LAST, COL2; + +-- Q9. Regular query with ORDER BY DESC NULLS FIRST +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 DESC NULLS FIRST, COL2; + +-- Q10. Regular query with ORDER BY DESC NULLS LAST +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 DESC NULLS LAST, COL2; + +-- drop the test table +drop table spark_10747; + +-- Q11. mix datatype for ORDER BY NULLS FIRST|LAST +create table spark_10747_mix( +col1 string, +col2 int, +col3 double, +col4 decimal(10,2), +col5 decimal(20,1)) +using parquet; + +-- Q12. Insert to the table +INSERT INTO spark_10747_mix VALUES +('b', 2, 1.0, 1.00, 10.0), +('d', 3, 2.0, 3.00, 0.0), +('c', 3, 2.0, 2.00, 15.1), +('d', 3, 0.0, 3.00, 1.0), +(null, 3, 0.0, 3.00, 1.0), +('d', 3, null, 4.00, 1.0), +('a', 1, 1.0, 1.00, null), +('c', 3, 2.0, 2.00, null); + +-- Q13. Regular query with 2 NULLS LAST columns +select * from spark_10747_mix order by col1 nulls last, col5 nulls last; + +-- Q14. Regular query with 2 NULLS FIRST columns +select * from spark_10747_mix order by col1 desc nulls first, col5 desc nulls first; + +-- Q15. 
Regular query with mixed NULLS FIRST|LAST +select * from spark_10747_mix order by col5 desc nulls first, col3 desc nulls last; + +-- drop the test table +drop table spark_10747_mix; + + diff --git a/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out b/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out new file mode 100644 index 0000000000000..c1b63dfb8caef --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out @@ -0,0 +1,254 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +create table spark_10747(col1 int, col2 int, col3 int) using parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +INSERT INTO spark_10747 VALUES (6, 12, 10), (6, 11, 4), (6, 9, 10), (6, 15, 8), +(6, 15, 8), (6, 7, 4), (6, 7, 8), (6, 13, null), (6, 10, null) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 desc nulls last, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2 +-- !query 2 schema +struct +-- !query 2 output +6 9 10 28 +6 13 NULL 34 +6 10 NULL 41 +6 12 10 43 +6 15 8 55 +6 15 8 56 +6 11 4 56 +6 7 8 58 +6 7 4 58 + + +-- !query 3 +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 desc nulls first, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2 +-- !query 3 schema +struct +-- !query 3 output +6 10 NULL 32 +6 11 4 33 +6 13 NULL 44 +6 7 4 48 +6 9 10 51 +6 15 8 55 +6 12 10 56 +6 15 8 56 +6 7 8 58 + + +-- !query 4 +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 asc nulls last, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2 +-- !query 4 schema +struct +-- !query 4 output +6 7 4 25 +6 13 NULL 35 +6 11 4 40 +6 10 NULL 44 +6 7 8 55 +6 15 8 57 +6 15 8 58 +6 12 10 59 +6 9 10 61 + + +-- !query 5 +select col1, col2, col3, sum(col2) + over (partition by col1 + order by col3 asc nulls first, col2 + rows between 2 preceding and 2 following ) as sum_col2 +from spark_10747 where col1 = 6 order by sum_col2 +-- !query 5 schema +struct +-- !query 5 output +6 10 NULL 30 +6 12 10 36 +6 13 NULL 41 +6 7 4 48 +6 9 10 51 +6 11 4 53 +6 7 8 55 +6 15 8 57 +6 15 8 58 + + +-- !query 6 +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 ASC NULLS FIRST, COL2 +-- !query 6 schema +struct +-- !query 6 output +6 10 NULL +6 13 NULL +6 7 4 +6 11 4 +6 7 8 +6 15 8 +6 15 8 +6 9 10 +6 12 10 + + +-- !query 7 +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 NULLS LAST, COL2 +-- !query 7 schema +struct +-- !query 7 output +6 7 4 +6 11 4 +6 7 8 +6 15 8 +6 15 8 +6 9 10 +6 12 10 +6 10 NULL +6 13 NULL + + +-- !query 8 +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 DESC NULLS FIRST, COL2 +-- !query 8 schema +struct +-- !query 8 output +6 10 NULL +6 13 NULL +6 9 10 +6 12 10 +6 7 8 +6 15 8 +6 15 8 +6 7 4 +6 11 4 + + +-- !query 9 +SELECT COL1, COL2, COL3 FROM spark_10747 ORDER BY COL3 DESC NULLS LAST, COL2 +-- !query 9 schema +struct +-- !query 9 output +6 9 10 +6 12 10 +6 7 8 +6 15 8 +6 15 8 +6 7 4 +6 11 4 +6 10 NULL +6 13 NULL + + +-- !query 10 +drop table spark_10747 +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +create table spark_10747_mix( +col1 string, +col2 int, +col3 double, +col4 decimal(10,2), +col5 
decimal(20,1)) +using parquet +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +INSERT INTO spark_10747_mix VALUES +('b', 2, 1.0, 1.00, 10.0), +('d', 3, 2.0, 3.00, 0.0), +('c', 3, 2.0, 2.00, 15.1), +('d', 3, 0.0, 3.00, 1.0), +(null, 3, 0.0, 3.00, 1.0), +('d', 3, null, 4.00, 1.0), +('a', 1, 1.0, 1.00, null), +('c', 3, 2.0, 2.00, null) +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +select * from spark_10747_mix order by col1 nulls last, col5 nulls last +-- !query 13 schema +struct +-- !query 13 output +a 1 1.0 1 NULL +b 2 1.0 1 10 +c 3 2.0 2 15.1 +c 3 2.0 2 NULL +d 3 2.0 3 0 +d 3 0.0 3 1 +d 3 NULL 4 1 +NULL 3 0.0 3 1 + + +-- !query 14 +select * from spark_10747_mix order by col1 desc nulls first, col5 desc nulls first +-- !query 14 schema +struct +-- !query 14 output +NULL 3 0.0 3 1 +d 3 0.0 3 1 +d 3 NULL 4 1 +d 3 2.0 3 0 +c 3 2.0 2 NULL +c 3 2.0 2 15.1 +b 2 1.0 1 10 +a 1 1.0 1 NULL + + +-- !query 15 +select * from spark_10747_mix order by col5 desc nulls first, col3 desc nulls last +-- !query 15 schema +struct +-- !query 15 output +c 3 2.0 2 NULL +a 1 1.0 1 NULL +c 3 2.0 2 15.1 +b 2 1.0 1 10 +d 3 0.0 3 1 +NULL 3 0.0 3 1 +d 3 NULL 4 1 +d 3 2.0 3 0 + + +-- !query 16 +drop table spark_10747_mix +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index ba3fa3732d0df..a7bbe34f4eedb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -101,7 +101,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + sortOrder <- + Seq('a.asc :: Nil, 'a.asc_nullsLast :: Nil, 'a.desc :: Nil, 'a.desc_nullsFirst :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { diff --git a/sql/hive/src/test/resources/sqlgen/agg2.sql b/sql/hive/src/test/resources/sqlgen/agg2.sql index 65d71714fe850..adbfdb7e79d64 100644 --- a/sql/hive/src/test/resources/sqlgen/agg2.sql +++ b/sql/hive/src/test/resources/sqlgen/agg2.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. 
SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key) -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_2`) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_2` ORDER BY `gen_attr_1` ASC) AS gen_subquery_1) AS gen_subquery_2 +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_2`) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_2` ORDER BY `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/agg3.sql b/sql/hive/src/test/resources/sqlgen/agg3.sql index 14b19392cdce3..207542d226e23 100644 --- a/sql/hive/src/test/resources/sqlgen/agg3.sql +++ b/sql/hive/src/test/resources/sqlgen/agg3.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key) -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_3` AS `gen_attr_1`, max(`gen_attr_3`) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_3` ORDER BY `gen_attr_1` ASC, `gen_attr_2` ASC) AS gen_subquery_1) AS gen_subquery_2 +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_3` AS `gen_attr_1`, max(`gen_attr_3`) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_3` ORDER BY `gen_attr_1` ASC NULLS FIRST, `gen_attr_2` ASC NULLS FIRST) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql index ec881a216e0b0..3de4f8a059965 100644 --- a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql +++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql @@ -5,4 +5,4 @@ FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) ORDER BY subq.key1, z.value -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 
+SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql index 805197a4ea11b..ab444d0c70936 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql @@ -5,4 +5,4 @@ WHERE id > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC LIMIT 5) AS parquet_t3 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql index ef9a596197b8b..42a2369f34d1c 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql @@ -7,4 +7,4 @@ WHERE val > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC LIMIT 5) AS gen_subquery_1 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS 
`gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql index b2c426c660d80..245b52341658f 100644 --- a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`), (`gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`), (`gen_attr_6`)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql index 96ee8e85951e8..1505dea11ec68 100644 --- a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. 
SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql index 9b8b230c879c2..281add6aabb64 100644 --- a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_6`)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql index c35db74a5c5b5..f8d64742b11e3 100644 --- a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. 
SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS(()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS(()) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql index e47f6d5dcf465..09e6ec2a5f8c9 100644 --- a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql @@ -2,4 +2,4 @@ SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS ((), (a), (a, b)) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((), (`gen_attr_5`), (`gen_attr_5`, `gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((), (`gen_attr_5`), (`gen_attr_5`, `gen_attr_6`)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql index 22df578518ef3..c364c32dd5c55 100644 --- a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. 
SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql index f44b652343acb..36c0223fceced 100644 --- a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), (`gen_attr_6`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), (`gen_attr_6`), ()) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql index 40f6924913765..ed33f2a1de3cf 100644 --- a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. 
SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), ()) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql index 608e644dee6d0..e0e40241480da 100644 --- a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/sort_asc_nulls_last.sql b/sql/hive/src/test/resources/sqlgen/sort_asc_nulls_last.sql new file mode 100644 index 0000000000000..da4e3678a33b9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/sort_asc_nulls_last.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key nulls last, MAX(key) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_3` AS `gen_attr_1`, max(`gen_attr_3`) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_3` ORDER BY `gen_attr_1` ASC NULLS LAST, `gen_attr_2` ASC NULLS FIRST) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql b/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql index da60204297a21..a4f3ddc761f30 100644 --- a/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql +++ b/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_1`) AS `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING (`gen_attr_2` > CAST(0 AS BIGINT))) AS gen_subquery_1 SORT BY `gen_attr_1` ASC) AS gen_subquery_2) AS gen_subquery_3 +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_1`) AS `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING (`gen_attr_2` > CAST(0 AS BIGINT))) AS gen_subquery_1 SORT BY `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/sort_desc_nulls_first.sql b/sql/hive/src/test/resources/sqlgen/sort_desc_nulls_first.sql new file mode 100644 index 0000000000000..d995e3bdfad5c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/sort_desc_nulls_first.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key desc nulls first,MAX(key) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_3` AS `gen_attr_1`, max(`gen_attr_3`) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_3` ORDER BY `gen_attr_1` DESC NULLS FIRST, `gen_attr_2` ASC NULLS FIRST) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql index 9894f5ab39c76..25882147463b9 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql @@ -5,4 +5,4 @@ group by key having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key) order by key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST('90' AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST('90' AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS src diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index c3a122aa889b9..de0116a4dcbaf 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, 
`gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_2.sql b/sql/hive/src/test/resources/sqlgen/window_basic_2.sql index ec55d4b7146f2..0e2a9a54731fc 100644 --- a/sql/hive/src/test/resources/sqlgen/window_basic_2.sql +++ b/sql/hive/src/test/resources/sqlgen/window_basic_2.sql @@ -2,4 +2,4 @@ SELECT key, value, ROUND(AVG(key) OVER (), 2) FROM parquet_t1 ORDER BY key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `round(avg(key) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 2)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, round(`gen_attr_3`, 2) AS `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, avg(`gen_attr_0`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2 ORDER BY `gen_attr_0` ASC) AS parquet_t1 +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `round(avg(key) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 2)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, round(`gen_attr_3`, 2) AS `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, avg(`gen_attr_0`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_3.sql b/sql/hive/src/test/resources/sqlgen/window_basic_3.sql index c0ac9541e67ee..d727caa583e61 100644 --- a/sql/hive/src/test/resources/sqlgen/window_basic_3.sql +++ b/sql/hive/src/test/resources/sqlgen/window_basic_3.sql @@ -2,4 +2,4 @@ SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max FROM parquet_t1 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `value`, `gen_attr_1` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_2`, gen_subquery_1.`gen_attr_3`, gen_subquery_1.`gen_attr_4`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_4` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT 
ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, (`gen_attr_5` + CAST(1 AS BIGINT)) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, (`gen_attr_5` % CAST(7 AS BIGINT)) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_0` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 +SELECT `gen_attr_0` AS `value`, `gen_attr_1` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_2`, gen_subquery_1.`gen_attr_3`, gen_subquery_1.`gen_attr_4`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_4` ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, (`gen_attr_5` + CAST(1 AS BIGINT)) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, (`gen_attr_5` % CAST(7 AS BIGINT)) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_0` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_asc_nulls_last.sql b/sql/hive/src/test/resources/sqlgen/window_basic_asc_nulls_last.sql new file mode 100644 index 0000000000000..4739f05808daf --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_basic_asc_nulls_last.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, value, ROUND(AVG(key) OVER (), 2) +FROM parquet_t1 ORDER BY key nulls last +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `round(avg(key) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 2)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, round(`gen_attr_3`, 2) AS `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, avg(`gen_attr_0`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2 ORDER BY `gen_attr_0` ASC NULLS LAST) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_desc_nulls_first.sql b/sql/hive/src/test/resources/sqlgen/window_basic_desc_nulls_first.sql new file mode 100644 index 0000000000000..1b9db2993b09d --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_basic_desc_nulls_first.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT key, value, ROUND(AVG(key) OVER (), 2) +FROM parquet_t1 ORDER BY key desc nulls first +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `round(avg(key) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 2)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, round(`gen_attr_3`, 2) AS `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, avg(`gen_attr_0`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2 ORDER BY `gen_attr_0` DESC NULLS FIRST) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_join.sql b/sql/hive/src/test/resources/sqlgen/window_with_join.sql index 030a4c0907a1c..43d5b47be8fba 100644 --- a/sql/hive/src/test/resources/sqlgen/window_with_join.sql +++ b/sql/hive/src/test/resources/sqlgen/window_with_join.sql @@ -2,4 +2,4 @@ SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `max(key) OVER (PARTITION BY (key % CAST(5 AS BIGINT)) ORDER BY key ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_2.`gen_attr_0`, gen_subquery_2.`gen_attr_2`, gen_subquery_2.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_2`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`)) AS gen_subquery_2) AS gen_subquery_3) AS x +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `max(key) OVER (PARTITION BY (key % CAST(5 AS BIGINT)) ORDER BY key ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_2.`gen_attr_0`, gen_subquery_2.`gen_attr_2`, gen_subquery_2.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_2`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`)) AS gen_subquery_2) AS gen_subquery_3) AS x diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql index 7b99539a05480..33a8e83750be0 100644 --- a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql @@ -4,4 +4,4 @@ DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, COUNT(key) FROM parquet_t1 GROUP BY key, value 
-------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `count(key)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, count(`gen_attr_0`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `count(key)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, count(`gen_attr_0`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql index 591a654a3888e..e01bc034d3d12 100644 --- a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql @@ -4,4 +4,4 @@ DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca FROM parquet_t1 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `ca` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2`, count(`gen_attr_0`) OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `ca` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2`, count(`gen_attr_0`) OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` 
FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql index d9169eab6e46a..dbfa408fa517e 100644 --- a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql @@ -3,4 +3,4 @@ SELECT key, value, MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max FROM parquet_t1 GROUP BY key, value -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql index f0a820811ee0a..6f5741b946262 100644 --- a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql @@ -3,4 +3,4 @@ SELECT key, value, MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max FROM parquet_t1 GROUP BY key, value HAVING key > 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1` HAVING (`gen_attr_0` > CAST(5 AS BIGINT))) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT 
`gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1` HAVING (`gen_attr_0` > CAST(5 AS BIGINT))) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d2b2f38fa1f71..ce5efe853ca4f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -106,17 +106,17 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), - s"(ORDER BY `a` ASC $frame)" + s"(ORDER BY `a` ASC NULLS FIRST $frame)" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), - s"(ORDER BY `a` ASC, `b` DESC $frame)" + s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST $frame)" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), - s"(PARTITION BY `a`, `b` ORDER BY `c` ASC, `d` DESC $frame)" + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index d80f894c22dd8..7fa5c29dc5b8f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -235,6 +235,16 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)", "agg3") } + test("order by asc nulls last") { + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key nulls last, MAX(key)", + "sort_asc_nulls_last") + } + + test("order by desc nulls first") { + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key desc nulls first," + + "MAX(key)", "sort_desc_nulls_first") + } + test("type widening in union") { checkSQL("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0", "type_widening") @@ -697,6 +707,20 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |FROM parquet_t1 """.stripMargin, "window_basic_3") + + checkSQL( + """ + |SELECT key, value, ROUND(AVG(key) OVER (), 2) + |FROM parquet_t1 ORDER BY key nulls last + """.stripMargin, + "window_basic_asc_nulls_last") + + checkSQL( + """ + |SELECT key, value, ROUND(AVG(key) OVER (), 2) + |FROM parquet_t1 ORDER BY key desc nulls first + """.stripMargin, + "window_basic_desc_nulls_first") } test("multiple window functions in one expression") { From ff6e4cbdc80e2ad84c5d70ee07f323fad9374e3e Mon Sep 17 00:00:00 2001 From: Kishor Patil Date: Wed, 14 Sep 2016 14:19:35 -0500 Subject: [PATCH 014/306] [SPARK-17511] Yarn Dynamic Allocation: Avoid marking released container as Failed ## What 
changes were proposed in this pull request? Due to race conditions, the ` assert(numExecutorsRunning <= targetNumExecutors)` can fail causing `AssertionError`. So removed the assertion, instead moved the conditional check before launching new container: ``` java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:156) at org.apache.spark.deploy.yarn.YarnAllocator$$anonfun$runAllocatedContainers$1.org$apache$spark$deploy$yarn$YarnAllocator$$anonfun$$updateInternalState$1(YarnAllocator.scala:489) at org.apache.spark.deploy.yarn.YarnAllocator$$anonfun$runAllocatedContainers$1$$anon$1.run(YarnAllocator.scala:519) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? This was manually tested using a large ForkAndJoin job with Dynamic Allocation enabled to validate the failing job succeeds, without any such exception. Author: Kishor Patil Closes #15069 from kishorvpatil/SPARK-17511. --- .../spark/deploy/yarn/YarnAllocator.scala | 62 ++++++++++--------- .../deploy/yarn/YarnAllocatorSuite.scala | 19 ++++++ 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 2f4b498b3ca74..0b66d1cf08eac 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -496,7 +496,6 @@ private[yarn] class YarnAllocator( def updateInternalState(): Unit = synchronized { numExecutorsRunning += 1 - assert(numExecutorsRunning <= targetNumExecutors) executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -506,36 +505,41 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (launchContainers) { - launcherPool.execute(new Runnable { - override def run(): Unit = { - try { - new ExecutorRunnable( - Some(container), - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores, - appAttemptId.getApplicationId.toString, - securityMgr, - localResources - ).run() - updateInternalState() - } catch { - case NonFatal(e) => - logError(s"Failed to launch executor $executorId on container $containerId", e) - // Assigned container should be released immediately to avoid unnecessary resource - // occupation. - amClient.releaseAssignedContainer(containerId) + if (numExecutorsRunning < targetNumExecutors) { + if (launchContainers) { + launcherPool.execute(new Runnable { + override def run(): Unit = { + try { + new ExecutorRunnable( + Some(container), + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + appAttemptId.getApplicationId.toString, + securityMgr, + localResources + ).run() + updateInternalState() + } catch { + case NonFatal(e) => + logError(s"Failed to launch executor $executorId on container $containerId", e) + // Assigned container should be released immediately to avoid unnecessary resource + // occupation. 
+ amClient.releaseAssignedContainer(containerId) + } } - } - }) + }) + } else { + // For test only + updateInternalState() + } } else { - // For test only - updateInternalState() + logInfo(("Skip launching executorRunnable as runnning Excecutors count: %d " + + "reached target Executors count: %d.").format(numExecutorsRunning, targetNumExecutors)) } } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 696e552c35d12..994dc75d34c30 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -136,6 +136,25 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter size should be (0) } + test("container should not be created if requested number if met") { + // request a single container and receive it + val handler = createAllocator(1) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (1) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container2)) + handler.getNumExecutorsRunning should be (1) + } + test("some containers allocated") { // request a few containers and receive some of them val handler = createAllocator(4) From e33bfaed3b160fbc617c878067af17477a0044f5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 14 Sep 2016 13:33:51 -0700 Subject: [PATCH 015/306] [SPARK-17463][CORE] Make CollectionAccumulator and SetAccumulator's value can be read thread-safely ## What changes were proposed in this pull request? Make CollectionAccumulator and SetAccumulator's value can be read thread-safely to fix the ConcurrentModificationException reported in [JIRA](https://issues.apache.org/jira/browse/SPARK-17463). ## How was this patch tested? Existing tests. Author: Shixiong Zhu Closes #15063 from zsxwing/SPARK-17463. --- .../apache/spark/executor/TaskMetrics.scala | 41 ++++++++++++------- .../org/apache/spark/util/AccumulatorV2.scala | 7 +++- .../org/apache/spark/util/JsonProtocol.scala | 11 ++--- .../apache/spark/util/JsonProtocolSuite.scala | 3 +- .../spark/sql/execution/debug/package.scala | 24 +++++++---- 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dd149a919fe55..52a349919e336 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,9 @@ package org.apache.spark.executor +import java.util.{ArrayList, Collections} + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} import org.apache.spark._ @@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. 
*/ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + _updatedBlockStatuses.value.asScala + } // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = @@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit = _updatedBlockStatuses.add(v) - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v.asJava) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging { val name = info.name.get val value = info.update.get if (name == UPDATED_BLOCK_STATUSES) { - tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]]) + tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]]) } else { tm.nameToAccums.get(name).foreach( _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long]) @@ -299,8 +308,8 @@ private[spark] object TaskMetrics extends Logging { private[spark] class BlockStatusesAccumulator - extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] { - private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] + extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] { + private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]()) override def isZero(): Boolean = _seq.isEmpty @@ -308,25 +317,27 @@ private[spark] class BlockStatusesAccumulator override def copy(): BlockStatusesAccumulator = { val newAcc = new BlockStatusesAccumulator - newAcc._seq = _seq.clone() + newAcc._seq.addAll(_seq) newAcc } override def reset(): Unit = _seq.clear() - override def add(v: (BlockId, BlockStatus)): Unit = _seq += v + override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v) - override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]]) - : Unit = other match { - case o: BlockStatusesAccumulator => _seq ++= o.value - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + override def merge( + other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = { + other match { + case o: BlockStatusesAccumulator => _seq.addAll(o.value) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } } - override def value: Seq[(BlockId, BlockStatus)] = _seq + override def value: java.util.List[(BlockId, BlockStatus)] = _seq - def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = { + def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = { _seq.clear() - _seq ++= newValue + _seq.addAll(newValue) } } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala 
b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d130a37db5b5d..470d912ecff13 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream -import java.util.ArrayList +import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata( /** * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of * type `OUT`. + * + * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely + * (e.g., synchronized collections) because it will be read from other threads. */ abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ @@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { * @since 2.0.0 */ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { - private val _list: java.util.List[T] = new ArrayList[T] + private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) override def isZero: Boolean = _list.isEmpty diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 022b226894105..41d947c4428ad 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -310,11 +310,12 @@ private[spark] object JsonProtocol { case v: Int => JInt(v) case v: Long => JInt(v) // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be - // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]` + // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]` case v => - JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map { + case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) }) } } else { @@ -743,7 +744,7 @@ private[spark] object JsonProtocol { val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") (id, status) - } + }.asJava case _ => throw new IllegalArgumentException(s"unexpected json value $value for " + "accumulator " + name.get) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 85ca9d39d4a3f..c89be22a34c9d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite { }) testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) - testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, 
blocksJson) // For anything else, we just cast the value to a string testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 082f97a8808fa..d321f4cd76877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.HashSet +import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -107,18 +109,20 @@ package object debug { case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output - class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] { - private val _set = new HashSet[T]() + class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] { + private val _set = Collections.synchronizedSet(new java.util.HashSet[T]()) override def isZero: Boolean = _set.isEmpty - override def copy(): AccumulatorV2[T, HashSet[T]] = { + override def copy(): AccumulatorV2[T, java.util.Set[T]] = { val newAcc = new SetAccumulator[T]() - newAcc._set ++= _set + newAcc._set.addAll(_set) newAcc } override def reset(): Unit = _set.clear() - override def add(v: T): Unit = _set += v - override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value - override def value: HashSet[T] = _set + override def add(v: T): Unit = _set.add(v) + override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = { + _set.addAll(other.value) + } + override def value: java.util.Set[T] = _set } /** @@ -138,7 +142,9 @@ package object debug { debugPrint(s"== ${child.simpleString} ==") debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => - val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}") debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } From dbfc7aa4d0d5457bc92e1e66d065c6088d476843 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 14 Sep 2016 13:37:35 -0700 Subject: [PATCH 016/306] [SPARK-17472] [PYSPARK] Better error message for serialization failures of large objects in Python ## What changes were proposed in this pull request? For large objects, pickle does not raise useful error messages. 
However, we can wrap them to be slightly more user friendly: Example 1: ``` def run(): import numpy.random as nr b = nr.bytes(8 * 1000000000) sc.parallelize(range(1000), 1000).map(lambda x: len(b)).count() run() ``` Before: ``` error: 'i' format requires -2147483648 <= number <= 2147483647 ``` After: ``` pickle.PicklingError: Object too large to serialize: 'i' format requires -2147483648 <= number <= 2147483647 ``` Example 2: ``` def run(): import numpy.random as nr b = sc.broadcast(nr.bytes(8 * 1000000000)) sc.parallelize(range(1000), 1000).map(lambda x: len(b.value)).count() run() ``` Before: ``` SystemError: error return without exception set ``` After: ``` cPickle.PicklingError: Could not serialize broadcast: SystemError: error return without exception set ``` ## How was this patch tested? Manually tried out these cases cc davies Author: Eric Liang Closes #15026 from ericl/spark-17472. --- python/pyspark/broadcast.py | 11 ++++++++++- python/pyspark/cloudpickle.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index a0b819220e6d3..74dee1420754a 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -20,6 +20,8 @@ import gc from tempfile import NamedTemporaryFile +from pyspark.cloudpickle import print_exec + if sys.version < '3': import cPickle as pickle else: @@ -75,7 +77,14 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None): self._path = path def dump(self, value, f): - pickle.dump(value, f, 2) + try: + pickle.dump(value, f, 2) + except pickle.PickleError: + raise + except Exception as e: + msg = "Could not serialize broadcast: " + e.__class__.__name__ + ": " + e.message + print_exec(sys.stderr) + raise pickle.PicklingError(msg) f.close() return f.name diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 822ae46e45111..da2b2f3757967 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -109,6 +109,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) + except pickle.PickleError: + raise + except Exception as e: + if "'i' format requires" in e.message: + msg = "Object too large to serialize: " + e.message + else: + msg = "Could not serialize object: " + e.__class__.__name__ + ": " + e.message + print_exec(sys.stderr) + raise pickle.PicklingError(msg) + def save_memoryview(self, obj): """Fallback to save_string""" From bb322943623d14b85283705e74d913e31230387f Mon Sep 17 00:00:00 2001 From: Xing SHI Date: Wed, 14 Sep 2016 13:46:46 -0700 Subject: [PATCH 017/306] [SPARK-17465][SPARK CORE] Inappropriate memory management in `org.apache.spark.storage.MemoryStore` may lead to memory leak The expression like `if (memoryMap(taskAttemptId) == 0) memoryMap.remove(taskAttemptId)` in method `releaseUnrollMemoryForThisTask` and `releasePendingUnrollMemoryForThisTask` should be called after release memory operation, whatever `memoryToRelease` is > 0 or not. If the memory of a task has been set to 0 when calling a `releaseUnrollMemoryForThisTask` or a `releasePendingUnrollMemoryForThisTask` method, the key in the memory map corresponding to that task will never be removed from the hash map. See the details in [SPARK-17465](https://issues.apache.org/jira/browse/SPARK-17465). Author: Xing SHI Closes #15022 from saturday-shi/SPARK-17465. 
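For illustration only (not part of the patch), here is a minimal, self-contained Scala sketch of the bookkeeping pattern described above; the map, task ids and sizes are made up. It shows why the zero-entry cleanup has to run outside the `memoryToRelease > 0` guard: with the guarded shape, a task whose tracked memory is already 0 keeps its map entry forever.

```scala
import scala.collection.mutable

object ReleaseOrderingDemo {
  val memoryMap = mutable.HashMap[Long, Long]()

  // Buggy shape: the zero-entry cleanup only runs when something was actually released.
  def releaseGuarded(taskId: Long, memory: Long): Unit = {
    val toRelease = math.min(memory, memoryMap.getOrElse(taskId, 0L))
    if (toRelease > 0) {
      memoryMap(taskId) -= toRelease
      if (memoryMap(taskId) == 0) memoryMap.remove(taskId)
    }
  }

  // Fixed shape: always drop a zero-valued entry, whether or not this call released anything.
  def releaseAlwaysCleaned(taskId: Long, memory: Long): Unit = {
    if (memoryMap.contains(taskId)) {
      val toRelease = math.min(memory, memoryMap(taskId))
      if (toRelease > 0) memoryMap(taskId) -= toRelease
      if (memoryMap(taskId) == 0) memoryMap.remove(taskId)
    }
  }

  def main(args: Array[String]): Unit = {
    memoryMap(1L) = 0L                   // task 1 already holds 0 bytes
    releaseGuarded(1L, 100L)
    println(memoryMap.contains(1L))      // true  -> the entry leaks
    releaseAlwaysCleaned(1L, 100L)
    println(memoryMap.contains(1L))      // false -> the entry is removed
  }
}
```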
--- .../scala/org/apache/spark/storage/memory/MemoryStore.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index baa3fde2d05f1..ec1b0f7149271 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -593,11 +593,11 @@ private[spark] class MemoryStore( val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { unrollMemoryMap(taskAttemptId) -= memoryToRelease - if (unrollMemoryMap(taskAttemptId) == 0) { - unrollMemoryMap.remove(taskAttemptId) - } memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode) } + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } } } } From 6a6adb1673775df63a62270879eac70f5f8d7d75 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 15 Sep 2016 14:43:10 +0800 Subject: [PATCH 018/306] [SPARK-17440][SPARK-17441] Fixed Multiple Bugs in ALTER TABLE ### What changes were proposed in this pull request? For the following `ALTER TABLE` DDL, we should issue an exception when the target table is a `VIEW`: ```SQL ALTER TABLE viewName SET LOCATION '/path/to/your/lovely/heart' ALTER TABLE viewName SET SERDE 'whatever' ALTER TABLE viewName SET SERDEPROPERTIES ('x' = 'y') ALTER TABLE viewName PARTITION (a=1, b=2) SET SERDEPROPERTIES ('x' = 'y') ALTER TABLE viewName ADD IF NOT EXISTS PARTITION (a='4', b='8') ALTER TABLE viewName DROP IF EXISTS PARTITION (a='2') ALTER TABLE viewName RECOVER PARTITIONS ALTER TABLE viewName PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p') ``` In addition, `ALTER TABLE RENAME PARTITION` is unable to handle data source tables, just like the other `ALTER PARTITION` commands. We should issue an exception instead. ### How was this patch tested? Added a few test cases. Author: gatorsmile Closes #15004 from gatorsmile/altertable. 
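As a hedged illustration of the expected behavior after this change (not taken from the patch; the table name, view name and path are made up, and a Hive-enabled `SparkSession` named `spark` is assumed so that a permanent view can be created), any of the ALTER TABLE commands listed above should now fail analysis when the target is a view:

```scala
import org.apache.spark.sql.AnalysisException

spark.sql("CREATE TABLE tbl (a INT, b INT) USING parquet")
spark.sql("CREATE VIEW viewName AS SELECT a, b FROM tbl")

try {
  // After this patch, altering a view with ALTER TABLE fails during analysis
  // instead of silently touching catalog metadata.
  spark.sql("ALTER TABLE viewName SET LOCATION '/tmp/some/path'")
} catch {
  case e: AnalysisException =>
    // Expected message (from DDLUtils.verifyAlterTableType):
    // "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead"
    println(e.getMessage)
}
```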
--- .../spark/sql/execution/command/ddl.scala | 45 +++++++++---- .../spark/sql/execution/command/tables.scala | 4 +- .../sql/execution/command/DDLSuite.scala | 63 +++++++++++++---- .../sql/hive/execution/HiveDDLSuite.scala | 67 ++++++++++--------- 4 files changed, 120 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index dcda2f8d1c52a..c0ccdca98e05b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -230,8 +230,8 @@ case class AlterTableSetPropertiesCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView) // This overrides old properties val newTable = table.copy(properties = table.properties ++ properties) catalog.alterTable(newTable) @@ -258,8 +258,8 @@ case class AlterTableUnsetPropertiesCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView) if (!ifExists) { propKeys.foreach { k => if (!table.properties.contains(k)) { @@ -299,6 +299,7 @@ case class AlterTableSerDePropertiesCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView = false) // For datasource tables, disallow setting serde or specifying partition if (partSpec.isDefined && DDLUtils.isDatasourceTable(table)) { throw new AnalysisException("Operation not allowed: ALTER TABLE SET " + @@ -348,6 +349,7 @@ case class AlterTableAddPartitionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView = false) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API") @@ -377,7 +379,14 @@ case class AlterTableRenamePartitionCommand( extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.catalog.renamePartitions( + val catalog = sparkSession.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API") + } + DDLUtils.verifyAlterTableType(catalog, table, isView = false) + catalog.renamePartitions( tableName, Seq(oldPartition), Seq(newPartition)) Seq.empty[Row] } @@ -408,6 +417,7 @@ case class AlterTableDropPartitionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView = false) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") @@ 
-469,6 +479,7 @@ case class AlterTableRecoverPartitionsCommand( s"Operation not allowed: $cmd on temporary tables: $tableName") } val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView = false) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( s"Operation not allowed: $cmd on datasource tables: $tableName") @@ -644,6 +655,7 @@ case class AlterTableSetLocationCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => // Partition spec is specified, so we set the location only for this partition @@ -682,19 +694,26 @@ object DDLUtils { /** * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, * issue an exception [[AnalysisException]]. + * + * Note: temporary views can be altered by both ALTER VIEW and ALTER TABLE commands, + * since temporary views can be also created by CREATE TEMPORARY TABLE. In the future, + * when we decided to drop the support, we should disallow users to alter temporary views + * by ALTER TABLE. */ def verifyAlterTableType( catalog: SessionCatalog, - tableIdentifier: TableIdentifier, + tableMetadata: CatalogTable, isView: Boolean): Unit = { - catalog.getTableMetadataOption(tableIdentifier).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead") - case _ => - }) + if (!catalog.isTemporaryTable(tableMetadata.identifier)) { + tableMetadata.tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead") + case _ => + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9fbcd48b4a911..60e6b5db62a31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -158,7 +158,8 @@ case class AlterTableRenameCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - DDLUtils.verifyAlterTableType(catalog, oldName, isView) + val table = catalog.getTableMetadata(oldName) + DDLUtils.verifyAlterTableType(catalog, table, isView) // If this is a temp view, just rename the view. // Otherwise, if this is a real table, we also need to uncache and invalidate the table. 
val isTemporary = catalog.isTemporaryTable(oldName) @@ -177,7 +178,6 @@ case class AlterTableRenameCommand( } } // For datasource tables, we also need to update the "path" serde property - val table = catalog.getTableMetadata(oldName) if (DDLUtils.isDatasourceTable(table) && table.tableType == CatalogTableType.MANAGED) { val newPath = catalog.defaultTablePath(newTblName) val newTable = table.withNewStorage( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 95672e01f5546..4a171808c05ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -696,6 +696,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(spark.table("teachers").collect().toSeq == df.collect().toSeq) } + test("rename temporary table") { + withTempView("tab1", "tab2") { + spark.range(10).createOrReplaceTempView("tab1") + sql("ALTER TABLE tab1 RENAME TO tab2") + checkAnswer(spark.table("tab2"), spark.range(10).toDF()) + intercept[NoSuchTableException] { spark.table("tab1") } + sql("ALTER VIEW tab2 RENAME TO tab1") + checkAnswer(spark.table("tab1"), spark.range(10).toDF()) + intercept[NoSuchTableException] { spark.table("tab2") } + } + } + test("rename temporary table - destination table already exists") { withTempView("tab1", "tab2") { sql( @@ -880,25 +892,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("alter table: rename partition") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) - val part1 = Map("a" -> "1", "b" -> "q") - val part2 = Map("a" -> "2", "b" -> "c") - val part3 = Map("a" -> "3", "b" -> "p") - createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - createTablePartition(catalog, part1, tableIdent) - createTablePartition(catalog, part2, tableIdent) - createTablePartition(catalog, part3, tableIdent) - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3)) + createPartitionedTable(tableIdent, isDatasourceTable = false) sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") - sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='200', b='c')") + sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "200", "b" -> "c"), part3)) + Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) // rename without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "200", "b" -> "c"), part3)) + Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) // table to alter does not exist intercept[NoSuchTableException] { sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") @@ -909,6 +912,38 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + test("alter table: rename partition (datasource table)") { + createPartitionedTable(TableIdentifier("tab1", Some("dbx")), 
isDatasourceTable = true) + val e = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") + }.getMessage + assert(e.contains( + "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API")) + // table to alter does not exist + intercept[NoSuchTableException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") + } + } + + private def createPartitionedTable( + tableIdent: TableIdentifier, + isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val part1 = Map("a" -> "1", "b" -> "q") + val part2 = Map("a" -> "2", "b" -> "c") + val part3 = Map("a" -> "3", "b" -> "p") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + } + test("show tables") { withTempView("show1a", "show2b") { sql( @@ -1255,7 +1290,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // table to alter does not exist intercept[AnalysisException] { - sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") + sql("ALTER TABLE does_not_exist PARTITION (a=1, b=2) SET SERDEPROPERTIES ('x' = 'y')") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3cba5b2a097f1..aa35a335facbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -305,6 +305,16 @@ class HiveDDLSuite } } + private def assertErrorForAlterTableOnView(sqlText: String): Unit = { + val message = intercept[AnalysisException](sql(sqlText)).getMessage + assert(message.contains("Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + } + + private def assertErrorForAlterViewOnTable(sqlText: String): Unit = { + val message = intercept[AnalysisException](sql(sqlText)).getMessage + assert(message.contains("Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + } + test("alter views and alter table - misuse") { val tabName = "tab1" withTable(tabName) { @@ -317,45 +327,42 @@ class HiveDDLSuite assert(catalog.tableExists(TableIdentifier(tabName))) assert(catalog.tableExists(TableIdentifier(oldViewName))) + assert(!catalog.tableExists(TableIdentifier(newViewName))) - var message = intercept[AnalysisException] { - sql(s"ALTER VIEW $tabName RENAME TO $newViewName") - }.getMessage - assert(message.contains( - "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName RENAME TO $newViewName") - message = intercept[AnalysisException] { - sql(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") - }.getMessage - assert(message.contains( - "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName RENAME TO $newViewName") - message = intercept[AnalysisException] { - sql(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") - }.getMessage - assert(message.contains( - "Cannot alter a table with ALTER VIEW. 
Please use ALTER TABLE instead")) + assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") - message = intercept[AnalysisException] { - sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName") - }.getMessage - assert(message.contains( - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") - message = intercept[AnalysisException] { - sql(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") - }.getMessage - assert(message.contains( - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") - message = intercept[AnalysisException] { - sql(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") - }.getMessage - assert(message.contains( - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") + + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") + + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET SERDE 'whatever'") + + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET SERDEPROPERTIES ('x' = 'y')") + + assertErrorForAlterTableOnView( + s"ALTER TABLE $oldViewName PARTITION (a=1, b=2) SET SERDEPROPERTIES ('x' = 'y')") + + assertErrorForAlterTableOnView( + s"ALTER TABLE $oldViewName ADD IF NOT EXISTS PARTITION (a='4', b='8')") + + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName DROP IF EXISTS PARTITION (a='2')") + + assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName RECOVER PARTITIONS") + + assertErrorForAlterTableOnView( + s"ALTER TABLE $oldViewName PARTITION (a='1') RENAME TO PARTITION (a='100')") assert(catalog.tableExists(TableIdentifier(tabName))) assert(catalog.tableExists(TableIdentifier(oldViewName))) + assert(!catalog.tableExists(TableIdentifier(newViewName))) } } } From d15b4f90e64f7ec5cf14c7c57d2cb4234c3ce677 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 15 Sep 2016 09:30:15 +0100 Subject: [PATCH 019/306] [SPARK-17507][ML][MLLIB] check weight vector size in ANN ## What changes were proposed in this pull request? as the TODO described, check weight vector size and if wrong throw exception. ## How was this patch tested? existing tests. Author: WeichenXu Closes #15060 from WeichenXu123/check_input_weight_size_of_ann. 
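The added check is a plain precondition on the flattened weight vector; a small illustrative sketch with made-up layer sizes:

```scala
import org.apache.spark.ml.linalg.Vectors

// Made-up layer sizes standing in for topology.layers.map(_.weightSize).
val layerWeightSizes = Seq(12, 6, 3)
val expectedWeightSize = layerWeightSizes.sum

// Any caller-supplied weight vector must flatten to exactly that many values.
val weights = Vectors.dense(Array.fill(expectedWeightSize)(0.0))
require(weights.size == expectedWeightSize,
  s"Expected weight vector of size $expectedWeightSize but got size ${weights.size}.")
```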
--- .../src/main/scala/org/apache/spark/ml/ann/Layer.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 88909a9fb953f..e7e0dae0b5a01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -545,7 +545,9 @@ private[ann] object FeedForwardModel { * @return model */ def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { - // TODO: check that weights size is equal to sum of layers sizes + val expectedWeightSize = topology.layers.map(_.weightSize).sum + require(weights.size == expectedWeightSize, + s"Expected weight vector of size ${expectedWeightSize} but got size ${weights.size}.") new FeedForwardModel(weights, topology) } @@ -559,11 +561,7 @@ private[ann] object FeedForwardModel { def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { val layers = topology.layers val layerModels = new Array[LayerModel](layers.length) - var totalSize = 0 - for (i <- 0 until topology.layers.length) { - totalSize += topology.layers(i).weightSize - } - val weights = BDV.zeros[Double](totalSize) + val weights = BDV.zeros[Double](topology.layers.map(_.weightSize).sum) var offset = 0 val random = new XORShiftRandom(seed) for (i <- 0 until layers.length) { From f893e262500e2f183de88e984300dd5b085e1f71 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Thu, 15 Sep 2016 09:37:12 +0100 Subject: [PATCH 020/306] [SPARK-17524][TESTS] Use specified spark.buffer.pageSize ## What changes were proposed in this pull request? This PR has the appendRowUntilExceedingPageSize test in RowBasedKeyValueBatchSuite use whatever spark.buffer.pageSize value a user has specified to prevent a test failure for anyone testing Apache Spark on a box with a reduced page size. The test is currently hardcoded to use the default page size which is 64 MB so this minor PR is a test improvement ## How was this patch tested? Existing unit tests with 1 MB page size and with 64 MB (the default) page size Author: Adam Roberts Closes #15079 from a-roberts/patch-5. 
--- .../catalyst/expressions/RowBasedKeyValueBatchSuite.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index 0dd129cea7b3f..fb3dbe8ed1996 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -338,15 +338,17 @@ public void appendRowUntilExceedingCapacity() throws Exception { @Test public void appendRowUntilExceedingPageSize() throws Exception { + // Use default size or spark.buffer.pageSize if specified + int pageSizeToUse = (int) memoryManager.pageSizeBytes(); RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, 64 * 1024 * 1024); //enough capacity + valueSchema, taskMemoryManager, pageSizeToUse); //enough capacity try { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8; int totalSize = 4; int numRows = 0; - while (totalSize + recordLength < 64 * 1024 * 1024) { // default page size + while (totalSize + recordLength < pageSizeToUse) { appendRow(batch, key, value); totalSize += recordLength; numRows++; From 647ee05e5815bde361662a9286ac602c44b4d4e6 Mon Sep 17 00:00:00 2001 From: codlife <1004910847@qq.com> Date: Thu, 15 Sep 2016 09:38:13 +0100 Subject: [PATCH 021/306] [SPARK-17521] Error when I use sparkContext.makeRDD(Seq()) ## What changes were proposed in this pull request? when i use sc.makeRDD below ``` val data3 = sc.makeRDD(Seq()) println(data3.partitions.length) ``` I got an error: Exception in thread "main" java.lang.IllegalArgumentException: Positive number of slices required We can fix this bug just modify the last line ,do a check of seq.size ``` def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), math.max(seq.size, defaultParallelism), indexToPrefs) } ``` ## How was this patch tested? manual tests (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: codlife <1004910847@qq.com> Author: codlife Closes #15077 from codlife/master. 
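As a quick illustration of the fixed behaviour (assuming an already-running SparkContext named `sc`), the location-aware overload now tolerates an empty sequence:

```scala
// `sc` is an existing SparkContext; this is the (T, Seq[String]) overload that was fixed.
val empty = sc.makeRDD(Seq.empty[(Int, Seq[String])])
println(empty.partitions.length)  // 1, instead of "Positive number of slices required"
println(empty.count())            // 0
```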
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e32e4aa5b8312..35b6334832393 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -795,7 +795,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) + new ParallelCollectionRDD[T](this, seq.map(_._1), math.max(seq.size, 1), indexToPrefs) } /** From ad79fc0a8407a950a03869f2f8cdc3ed0bf13875 Mon Sep 17 00:00:00 2001 From: cenyuhai Date: Thu, 15 Sep 2016 09:58:53 +0100 Subject: [PATCH 022/306] [SPARK-17406][WEB UI] limit timeline executor events ## What changes were proposed in this pull request? The job page will be too slow to open when there are thousands of executor events(added or removed). I found that in ExecutorsTab file, executorIdToData will not remove elements, it will increase all the time.Before this pr, it looks like [timeline1.png](https://issues.apache.org/jira/secure/attachment/12827112/timeline1.png). After this pr, it looks like [timeline2.png](https://issues.apache.org/jira/secure/attachment/12827113/timeline2.png)(we can set how many executor events will be displayed) Author: cenyuhai Closes #14969 from cenyuhai/SPARK-17406. --- .../apache/spark/ui/exec/ExecutorsPage.scala | 41 +++---- .../apache/spark/ui/exec/ExecutorsTab.scala | 112 +++++++++++------- .../apache/spark/ui/jobs/AllJobsPage.scala | 66 +++++------ .../apache/spark/ui/jobs/ExecutorTable.scala | 3 +- .../org/apache/spark/ui/jobs/JobPage.scala | 67 ++++++----- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/ui/jobs/UIData.scala | 5 - project/MimaExcludes.scala | 12 ++ 8 files changed, 162 insertions(+), 148 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 982e8915a8ded..7953d77fd7ece 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -17,14 +17,12 @@ package org.apache.spark.ui.exec -import java.net.URLEncoder import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.status.api.v1.ExecutorSummary -import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} -import org.apache.spark.util.Utils +import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive private[ui] case class ExecutorSummaryInfo( @@ -83,18 +81,7 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed - val totalCores = listener.executorToTotalCores.getOrElse(execId, 0) - val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) - val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) - val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) - val totalTasks = activeTasks + failedTasks + completedTasks - val totalDuration = 
listener.executorToDuration.getOrElse(execId, 0L) - val totalGCTime = listener.executorToJvmGCTime.getOrElse(execId, 0L) - val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) - val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) - val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) - val executorLogs = listener.executorToLogUrls.getOrElse(execId, Map.empty) + val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) new ExecutorSummary( execId, @@ -103,19 +90,19 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, - totalCores, - maxTasks, - activeTasks, - failedTasks, - completedTasks, - totalTasks, - totalDuration, - totalGCTime, - totalInputBytes, - totalShuffleRead, - totalShuffleWrite, + taskSummary.totalCores, + taskSummary.tasksMax, + taskSummary.tasksActive, + taskSummary.tasksFailed, + taskSummary.tasksComplete, + taskSummary.tasksActive + taskSummary.tasksFailed + taskSummary.tasksComplete, + taskSummary.duration, + taskSummary.jvmGCTime, + taskSummary.inputBytes, + taskSummary.shuffleRead, + taskSummary.shuffleWrite, maxMem, - executorLogs + taskSummary.executorLogs ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 676f4457510c2..678571fd4f5ac 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -17,14 +17,13 @@ package org.apache.spark.ui.exec -import scala.collection.mutable.HashMap +import scala.collection.mutable.{LinkedHashMap, ListBuffer} import org.apache.spark.{ExceptionFailure, Resubmitted, SparkConf, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} import org.apache.spark.ui.{SparkUI, SparkUITab} -import org.apache.spark.ui.jobs.UIData.ExecutorUIData private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener @@ -38,6 +37,25 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec } } +private[ui] case class ExecutorTaskSummary( + var executorId: String, + var totalCores: Int = 0, + var tasksMax: Int = 0, + var tasksActive: Int = 0, + var tasksFailed: Int = 0, + var tasksComplete: Int = 0, + var duration: Long = 0L, + var jvmGCTime: Long = 0L, + var inputBytes: Long = 0L, + var inputRecords: Long = 0L, + var outputBytes: Long = 0L, + var outputRecords: Long = 0L, + var shuffleRead: Long = 0L, + var shuffleWrite: Long = 0L, + var executorLogs: Map[String, String] = Map.empty, + var isAlive: Boolean = true +) + /** * :: DeveloperApi :: * A SparkListener that prepares information to be displayed on the ExecutorsTab @@ -45,21 +63,11 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - val executorToTotalCores = HashMap[String, Int]() - val executorToTasksMax = HashMap[String, Int]() - val executorToTasksActive = HashMap[String, Int]() - val executorToTasksComplete = HashMap[String, Int]() - val executorToTasksFailed = HashMap[String, Int]() - val executorToDuration = HashMap[String, Long]() - val executorToJvmGCTime = HashMap[String, Long]() - val executorToInputBytes = 
HashMap[String, Long]() - val executorToInputRecords = HashMap[String, Long]() - val executorToOutputBytes = HashMap[String, Long]() - val executorToOutputRecords = HashMap[String, Long]() - val executorToShuffleRead = HashMap[String, Long]() - val executorToShuffleWrite = HashMap[String, Long]() - val executorToLogUrls = HashMap[String, Map[String, String]]() - val executorIdToData = HashMap[String, ExecutorUIData]() + var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + var executorEvents = new ListBuffer[SparkListenerEvent]() + + private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) + private val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList @@ -67,18 +75,29 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId - executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap - executorToTotalCores(eid) = executorAdded.executorInfo.totalCores - executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1) - executorIdToData(eid) = new ExecutorUIData(executorAdded.time) + val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) + taskSummary.executorLogs = executorAdded.executorInfo.logUrlMap + taskSummary.totalCores = executorAdded.executorInfo.totalCores + taskSummary.tasksMax = taskSummary.totalCores / conf.getInt("spark.task.cpus", 1) + executorEvents += executorAdded + if (executorEvents.size > maxTimelineExecutors) { + executorEvents.remove(0) + } + + val deadExecutors = executorToTaskSummary.filter(e => !e._2.isAlive) + if (deadExecutors.size > retainedDeadExecutors) { + val head = deadExecutors.head + executorToTaskSummary.remove(head._1) + } } override def onExecutorRemoved( executorRemoved: SparkListenerExecutorRemoved): Unit = synchronized { - val eid = executorRemoved.executorId - val uiData = executorIdToData(eid) - uiData.finishTime = Some(executorRemoved.time) - uiData.finishReason = Some(executorRemoved.reason) + executorEvents += executorRemoved + if (executorEvents.size > maxTimelineExecutors) { + executorEvents.remove(0) + } + executorToTaskSummary.get(executorRemoved.executorId).foreach(e => e.isAlive = false) } override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { @@ -87,19 +106,25 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER || s.blockManagerId.executorId == SparkContext.DRIVER_IDENTIFIER } - storageStatus.foreach { s => executorToLogUrls(s.blockManagerId.executorId) = logs.toMap } + storageStatus.foreach { s => + val eid = s.blockManagerId.executorId + val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) + taskSummary.executorLogs = logs.toMap + } } } override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 + val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) + taskSummary.tasksActive += 1 } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = 
info.executorId + val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) taskEnd.reason match { case Resubmitted => // Note: For resubmitted tasks, we continue to use the metrics that belong to the @@ -108,31 +133,26 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // metrics added by each attempt, but this is much more complicated. return case e: ExceptionFailure => - executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 + taskSummary.tasksFailed += 1 case _ => - executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 + taskSummary.tasksComplete += 1 } - - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 - executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration + if (taskSummary.tasksActive >= 1) { + taskSummary.tasksActive -= 1 + } + taskSummary.duration += info.duration // Update shuffle read/write val metrics = taskEnd.taskMetrics if (metrics != null) { - executorToInputBytes(eid) = - executorToInputBytes.getOrElse(eid, 0L) + metrics.inputMetrics.bytesRead - executorToInputRecords(eid) = - executorToInputRecords.getOrElse(eid, 0L) + metrics.inputMetrics.recordsRead - executorToOutputBytes(eid) = - executorToOutputBytes.getOrElse(eid, 0L) + metrics.outputMetrics.bytesWritten - executorToOutputRecords(eid) = - executorToOutputRecords.getOrElse(eid, 0L) + metrics.outputMetrics.recordsWritten - - executorToShuffleRead(eid) = - executorToShuffleRead.getOrElse(eid, 0L) + metrics.shuffleReadMetrics.remoteBytesRead - executorToShuffleWrite(eid) = - executorToShuffleWrite.getOrElse(eid, 0L) + metrics.shuffleWriteMetrics.bytesWritten - executorToJvmGCTime(eid) = executorToJvmGCTime.getOrElse(eid, 0L) + metrics.jvmGCTime + taskSummary.inputBytes += metrics.inputMetrics.bytesRead + taskSummary.inputRecords += metrics.inputMetrics.recordsRead + taskSummary.outputBytes += metrics.outputMetrics.bytesWritten + taskSummary.outputRecords += metrics.outputMetrics.recordsWritten + + taskSummary.shuffleRead += metrics.shuffleReadMetrics.remoteBytesRead + taskSummary.shuffleWrite += metrics.shuffleWriteMetrics.bytesWritten + taskSummary.jvmGCTime += metrics.jvmGCTime } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e5363ce8ca9dc..c04964ec66479 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -28,9 +28,9 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.scheduler._ import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData, StageUIData} +import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData} import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ @@ -123,55 +123,55 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } } - private def makeExecutorEvent(executorUIDatas: HashMap[String, ExecutorUIData]): Seq[String] = { + private def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): + Seq[String] = { val events = ListBuffer[String]() executorUIDatas.foreach { - case (executorId, event) => + case a: SparkListenerExecutorAdded => val addedEvent = s""" |{ | 'className': 'executor added', | 'group': 
'executors', - | 'start': new Date(${event.startTime}), + | 'start': new Date(${a.time}), | 'content': '
    Executor ${executorId} added
    ' + | 'data-title="Executor ${a.executorId}
    ' + + | 'Added at ${UIUtils.formatDate(new Date(a.time))}"' + + | 'data-html="true">Executor ${a.executorId} added
    ' |} """.stripMargin events += addedEvent + case e: SparkListenerExecutorRemoved => + val removedEvent = + s""" + |{ + | 'className': 'executor removed', + | 'group': 'executors', + | 'start': new Date(${e.time}), + | 'content': '
    Reason: ${e.reason.replace("\n", " ")}""" + } else { + "" + } + }"' + + | 'data-html="true">Executor ${e.executorId} removed
    ' + |} + """.stripMargin + events += removedEvent - if (event.finishTime.isDefined) { - val removedEvent = - s""" - |{ - | 'className': 'executor removed', - | 'group': 'executors', - | 'start': new Date(${event.finishTime.get}), - | 'content': '
    Reason: ${event.finishReason.get.replace("\n", " ")}""" - } else { - "" - } - }"' + - | 'data-html="true">Executor ${executorId} removed
    ' - |} - """.stripMargin - events += removedEvent - } } events.toSeq } private def makeTimeline( jobs: Seq[JobUIData], - executors: HashMap[String, ExecutorUIData], + executors: Seq[SparkListenerEvent], startTime: Long): Seq[Node] = { val jobEventJsonAsStrSeq = makeJobEvent(jobs) @@ -353,7 +353,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { var content = summary val executorListener = parent.executorListener content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, - executorListener.executorIdToData, startTime) + executorListener.executorEvents, startTime) if (shouldShowActiveJobs) { content ++=

    Active Jobs ({activeJobs.size})

    ++ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 133c3b1b9aca8..9fb3f35fd9685 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -118,7 +118,8 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
    {k}
    { - val logs = parent.executorsListener.executorToLogUrls.getOrElse(k, Map.empty) + val logs = parent.executorsListener.executorToTaskSummary.get(k) + .map(_.executorLogs).getOrElse(Map.empty) logs.map { case (logName, logUrl) => } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ec42d68d3dcc..2f7f8976a8899 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -20,15 +20,14 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.collection.mutable.{Buffer, HashMap, ListBuffer} +import scala.collection.mutable.{Buffer, ListBuffer} import scala.xml.{Node, NodeSeq, Unparsed, Utility} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.scheduler._ import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} -import org.apache.spark.ui.jobs.UIData.ExecutorUIData /** Page showing statistics and stage list for a given job */ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { @@ -93,55 +92,55 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } } - def makeExecutorEvent(executorUIDatas: HashMap[String, ExecutorUIData]): Seq[String] = { + def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): Seq[String] = { val events = ListBuffer[String]() executorUIDatas.foreach { - case (executorId, event) => + case a: SparkListenerExecutorAdded => val addedEvent = s""" |{ | 'className': 'executor added', | 'group': 'executors', - | 'start': new Date(${event.startTime}), + | 'start': new Date(${a.time}), | 'content': '
    Executor ${executorId} added
    ' + | 'data-title="Executor ${a.executorId}
    ' + + | 'Added at ${UIUtils.formatDate(new Date(a.time))}"' + + | 'data-html="true">Executor ${a.executorId} added
    ' |} """.stripMargin events += addedEvent - if (event.finishTime.isDefined) { - val removedEvent = - s""" - |{ - | 'className': 'executor removed', - | 'group': 'executors', - | 'start': new Date(${event.finishTime.get}), - | 'content': '
    Reason: ${event.finishReason.get.replace("\n", " ")}""" - } else { - "" - } - }"' + - | 'data-html="true">Executor ${executorId} removed
    ' - |} - """.stripMargin - events += removedEvent - } + case e: SparkListenerExecutorRemoved => + val removedEvent = + s""" + |{ + | 'className': 'executor removed', + | 'group': 'executors', + | 'start': new Date(${e.time}), + | 'content': '
    Reason: ${e.reason.replace("\n", " ")}""" + } else { + "" + } + }"' + + | 'data-html="true">Executor ${e.executorId} removed
    ' + |} + """.stripMargin + events += removedEvent + } events.toSeq } private def makeTimeline( stages: Seq[StageInfo], - executors: HashMap[String, ExecutorUIData], + executors: Seq[SparkListenerEvent], appStartTime: Long): Seq[Node] = { val stageEventJsonAsStrSeq = makeStageEvent(stages) @@ -319,7 +318,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val operationGraphListener = parent.operationGraphListener content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, - executorListener.executorIdToData, appStartTime) + executorListener.executorEvents, appStartTime) content ++= UIUtils.showDagVizForJob( jobId, operationGraphListener.getOperationGraphForJob(jobId)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index de787f257737d..c322ae0972ad7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1017,8 +1017,8 @@ private[ui] class TaskDataSource( None } - val logs = executorsListener.executorToLogUrls.getOrElse(info.executorId, Map.empty) - + val logs = executorsListener.executorToTaskSummary.get(info.executorId) + .map(_.executorLogs).getOrElse(Map.empty) new TaskTableRowData( info.index, info.taskId, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 74bca9931acf7..c729f03b3c383 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -177,11 +177,6 @@ private[spark] object UIData { } } - class ExecutorUIData( - val startTime: Long, - var finishTime: Option[Long] = None, - var finishReason: Option[String] = None) - case class TaskMetricsUIData( executorDeserializeTime: Long, executorRunTime: Long, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fbd78aeb20dd6..37fff2efa4eae 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -426,6 +426,18 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), From 71a65825c5d5d0886ac3e11f9945cfcb39573ac3 Mon Sep 17 00:00:00 2001 From: John Muller Date: Thu, 15 Sep 2016 10:00:28 +0100 Subject: [PATCH 023/306] [SPARK-17536][SQL] Minor performance improvement to JDBC batch inserts ## What changes were proposed in this pull request? Optimize a while loop during batch inserts ## How was this patch tested? Unit tests were done, specifically "mvn test" for sql Author: John Muller Closes #15098 from blue666man/SPARK-17536. --- .../apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 132472ad0ce87..b09fd511a9074 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -590,12 +590,12 @@ object JdbcUtils extends Logging { val stmt = insertStatement(conn, table, rddSchema, dialect) val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType) .map(makeSetter(conn, dialect, _)).toArray + val numFields = rddSchema.fields.length try { var rowCount = 0 while (iterator.hasNext) { val row = iterator.next() - val numFields = rddSchema.fields.length var i = 0 while (i < numFields) { if (row.isNullAt(i)) { From 2ad276954858b0a7b3f442b9e440c72cbb1610e2 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 15 Sep 2016 13:54:41 +0100 Subject: [PATCH 024/306] [SPARK-17406][BUILD][HOTFIX] MiMa excludes fix ## What changes were proposed in this pull request? Following https://github.com/apache/spark/pull/14969 for some reason the MiMa excludes weren't complete, but still passed the PR builder. This adds 3 more excludes from https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.2/1749/consoleFull It also moves the excludes to their own Seq in the build, as they probably should have been. Even though this is merged to 2.1.x only / master, I left the exclude in for 2.0.x in case we back port. It's a private API so is always a false positive. ## How was this patch tested? Jenkins build Author: Sean Owen Closes #15110 from srowen/SPARK-17406.2. 
--- project/MimaExcludes.scala | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 37fff2efa4eae..1bdcf9a623dc9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -426,18 +426,6 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), @@ -807,6 +795,23 @@ object MimaExcludes { // SPARK-17096: Improve exception string reported through the StreamingQueryListener ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this") + ) ++ Seq( + // SPARK-17406 limit timeline executor events + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTotalCores"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime") ) } From b479278142728eb003b9ee466fab0e8d6ec4b13d Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 15 Sep 2016 10:23:41 -0700 Subject: [PATCH 025/306] [SPARK-17451][CORE] CoarseGrainedExecutorBackend should inform driver before self-kill ## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-17451 `CoarseGrainedExecutorBackend` in some failure cases exits the JVM. While this does not have any issue, from the driver UI there is no specific reason captured for this. In this PR, I am adding functionality to `exitExecutor` to notify driver that the executor is exiting. ## How was this patch tested? Ran the change over a test env and took down shuffle service before the executor could register to it. In the driver logs, where the job failure reason is mentioned (ie. `Job aborted due to stage ...` it gives the correct reason: Before: `ExecutorLostFailure (executor ZZZZZZZZZ exited caused by one of the running tasks) Reason: Remote RPC client disassociated. Likely due to containers exceeding thresholds, or network issues. Check driver logs for WARN messages.` After: `ExecutorLostFailure (executor ZZZZZZZZZ exited caused by one of the running tasks) Reason: Unable to create executor due to java.util.concurrent.TimeoutException: Timeout waiting for task.` Author: Tejas Patil Closes #15013 from tejasapatil/SPARK-17451_inform_driver. 
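A simplified sketch of the notify-then-exit pattern (the DriverRef and RemoveExecutor types below are stand-ins for illustration, not the real Spark RPC classes, and the `notifyDriver` escape hatch is omitted):

```scala
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

// Stand-ins for the real RPC endpoint reference and RemoveExecutor message.
case class RemoveExecutor(executorId: String, reason: String)
trait DriverRef { def ask(msg: Any): Future[Boolean] }

def exitExecutor(driver: Option[DriverRef], executorId: String, code: Int, reason: String): Unit = {
  System.err.println(s"Executor self-exiting due to : $reason")
  // Best-effort notification so the driver can record the reason; the exit happens regardless.
  driver.foreach { ref =>
    ref.ask(RemoveExecutor(executorId, reason)).onComplete {
      case Success(_) => // driver acknowledged the removal
      case Failure(e) => System.err.println(s"Unable to notify the driver due to ${e.getMessage}")
    }
  }
  System.exit(code)
}
```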
--- .../CoarseGrainedExecutorBackend.scala | 26 ++++++++++++++----- .../apache/spark/storage/BlockManager.scala | 3 +++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 391b97d73e026..7eec4ae64f296 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging import org.apache.spark.rpc._ -import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ThreadUtils, Utils} @@ -65,7 +65,7 @@ private[spark] class CoarseGrainedExecutorBackend( case Success(msg) => // Always receive `true`. Just ignore it case Failure(e) => - exitExecutor(1, s"Cannot register with driver: $driverUrl", e) + exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false) }(ThreadUtils.sameThread) } @@ -129,7 +129,8 @@ private[spark] class CoarseGrainedExecutorBackend( if (stopping.get()) { logInfo(s"Driver from $remoteAddress disconnected during shutdown") } else if (driver.exists(_.address == remoteAddress)) { - exitExecutor(1, s"Driver $remoteAddress disassociated! Shutting down.") + exitExecutor(1, s"Driver $remoteAddress disassociated! Shutting down.", null, + notifyDriver = false) } else { logWarning(s"An unknown ($remoteAddress) driver disconnected.") } @@ -148,12 +149,25 @@ private[spark] class CoarseGrainedExecutorBackend( * executor exits differently. For e.g. when an executor goes down, * back-end may not want to take the parent process down. 
*/ - protected def exitExecutor(code: Int, reason: String, throwable: Throwable = null) = { + protected def exitExecutor(code: Int, + reason: String, + throwable: Throwable = null, + notifyDriver: Boolean = true) = { + val message = "Executor self-exiting due to : " + reason if (throwable != null) { - logError(reason, throwable) + logError(message, throwable) } else { - logError(reason) + logError(message) } + + if (notifyDriver && driver.nonEmpty) { + driver.get.ask[Boolean]( + RemoveExecutor(executorId, new ExecutorLossReason(reason)) + ).onFailure { case e => + logWarning(s"Unable to notify the driver due to " + e.getMessage, e) + }(ThreadUtils.sameThread) + } + System.exit(code) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a724fdf009789..c172ac2cdc0e3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -199,6 +199,9 @@ private[spark] class BlockManager( logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) + case NonFatal(e) => + throw new SparkException("Unable to register with external shuffle server due to : " + + e.getMessage, e) } } } From 0ad8eeb4d365c2fff5715ec22fbcf4c69c3340fd Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Thu, 15 Sep 2016 10:40:10 -0700 Subject: [PATCH 026/306] [SPARK-17379][BUILD] Upgrade netty-all to 4.0.41 final for bug fixes ## What changes were proposed in this pull request? Upgrade netty-all to latest in the 4.0.x line which is 4.0.41, mentions several bug fixes and performance improvements we may find useful, see netty.io/news/2016/08/29/4-0-41-Final-4-1-5-Final.html. Initially tried to use 4.1.5 but noticed it's not backwards compatible. ## How was this patch tested? Existing unit tests against branch-1.6 and branch-2.0 using IBM Java 8 on Intel, Power and Z architectures Author: Adam Roberts Closes #14961 from a-roberts/netty. --- .../java/org/apache/spark/network/util/TransportConf.java | 5 +++++ dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 7 files changed, 11 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 0efc400aa388c..7d5baa9a9c8f8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,6 +23,11 @@ * A central location that tracks all the settings we expose to users. 
*/ public class TransportConf { + + static { + // Set this due to Netty PR #5661 for Netty 4.0.37+ to work + System.setProperty("io.netty.maxDirectMemory", "0"); + } private final String SPARK_NETWORK_IO_MODE_KEY; private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 81adde6a13a14..a7259e25bfec6 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -124,7 +124,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar netty-3.8.0.Final.jar -netty-all-4.0.29.Final.jar +netty-all-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 75ab6286dec3c..6986ab572b947 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -131,7 +131,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.29.Final.jar +netty-all-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 897d802a9d6a1..75cccb352b9cf 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -131,7 +131,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.29.Final.jar +netty-all-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f95ddb1c3065d..ef7b8a7d8da26 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -139,7 +139,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.29.Final.jar +netty-all-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 8df02c032bf21..d464c97ed1d67 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -140,7 +140,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.29.Final.jar +netty-all-4.0.41.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/pom.xml b/pom.xml index 3b3ad39b47571..ef83c184d0237 100644 --- a/pom.xml +++ b/pom.xml @@ -551,7 +551,7 @@ io.netty netty-all - 4.0.29.Final + 4.0.41.Final io.netty From 5b8f7377d54f83b93ef2bfc2a01ca65fae6d3032 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 15 Sep 2016 11:22:58 -0700 Subject: [PATCH 027/306] [SPARK-17547] Ensure temp shuffle data file is cleaned up after error SPARK-8029 (#9610) modified shuffle writers to first stage their data to a temporary file in the same directory as the final destination file and then to atomically rename this temporary file at the end of the write job. However, this change introduced the potential for the temporary output file to be leaked if an exception occurs during the write because the shuffle writers' existing error cleanup code doesn't handle deletion of the temp file. This patch avoids this potential cause of disk-space leaks by adding `finally` blocks to ensure that temp files are always deleted if they haven't been renamed. Author: Josh Rosen Closes #15104 from JoshRosen/cleanup-tmp-data-file-in-shuffle-writer. 
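As an illustration of the cleanup pattern this patch applies, here is a minimal sketch with made-up helper names (`writeAndCommit`, `write` and `commit` are placeholders, not the actual Spark shuffle classes): stage the output in a temporary file next to the destination, and delete that file in a `finally` block unless the commit has already renamed it away.

```
import java.io.File

// Hypothetical helper, not Spark code: `write` fills the temp file, `commit`
// atomically renames it to `output` (the part the shuffle block resolver does).
def writeAndCommit(output: File)(write: File => Unit)(commit: File => Unit): Unit = {
  val tmp = File.createTempFile(output.getName, ".tmp", output.getParentFile)
  try {
    write(tmp)
    commit(tmp)
  } finally {
    // On success the rename has already removed `tmp`, so this is a no-op;
    // if an exception was thrown before the commit, the disk space is reclaimed here.
    if (tmp.exists() && !tmp.delete()) {
      System.err.println(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}
```

The writers touched by this patch all follow this shape: the `finally` block costs nothing on the success path and only matters when the write job fails part-way through.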
--- .../sort/BypassMergeSortShuffleWriter.java | 10 ++- .../shuffle/sort/UnsafeShuffleWriter.java | 18 +++-- .../shuffle/IndexShuffleBlockResolver.scala | 80 ++++++++++--------- .../shuffle/sort/SortShuffleWriter.scala | 14 +++- 4 files changed, 73 insertions(+), 49 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0fcc56d50ae6a..4a15559e55cbd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -160,8 +160,14 @@ public void write(Iterator> records) throws IOException { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + try { + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } + } mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 63d376b44fb11..f235c434be7b1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -210,15 +210,21 @@ void closeAndWriteOutput() throws IOException { final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills, tmp); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && ! spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); + try { + partitionLengths = mergeSpills(spills, tmp); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! 
spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } } } + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 94d8c0d0fd3e4..8d6396bededa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -139,48 +139,54 @@ private[spark] class IndexShuffleBlockResolver( dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length + try { + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) - // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure - // the following check and rename are atomic. - synchronized { - val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { - // Another attempt for the same task has already written our map outputs successfully, - // so just use the existing partition lengths and delete our temporary map outputs. - System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) - if (dataTmp != null && dataTmp.exists()) { - dataTmp.delete() - } - indexTmp.delete() - } else { - // This is the first successful attempt in writing the map outputs for this task, - // so override any existing index and data files with the ones we wrote. - if (indexFile.exists()) { - indexFile.delete() - } - if (dataFile.exists()) { - dataFile.delete() - } - if (!indexTmp.renameTo(indexFile)) { - throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) - } - if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { - throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. 
+ System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } } } + } finally { + if (indexTmp.exists() && !indexTmp.delete()) { + logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index cc01e6aa7ea91..636b88e792bf3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,10 +67,16 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) val tmp = Utils.tempFileWith(output) - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + try { + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + } finally { + if (tmp.exists() && !tmp.delete()) { + logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") + } + } } /** Close this writer, passing along whether the map completed */ From d403562eb4b5b1d804909861d3e8b75d8f6323b9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 15 Sep 2016 20:24:15 +0200 Subject: [PATCH 028/306] [SPARK-17114][SQL] Fix aggregates grouped by literals with empty input ## What changes were proposed in this pull request? This PR fixes an issue with aggregates that have an empty input, and use a literals as their grouping keys. These aggregates are currently interpreted as aggregates **without** grouping keys, this triggers the ungrouped code path (which aways returns a single row). This PR fixes the `RemoveLiteralFromGroupExpressions` optimizer rule, which changes the semantics of the Aggregate by eliminating all literal grouping keys. ## How was this patch tested? Added tests to `SQLQueryTestSuite`. Author: Herman van Hovell Closes #15101 from hvanhovell/SPARK-17114-3. 
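To make the semantics concrete, the following spark-shell sketch (the view name is made up for the example, and a SparkSession named `spark` is assumed) shows the difference the rule now has to preserve: an aggregate grouped by a literal over empty input produces no rows, whereas an ungrouped aggregate always produces exactly one row.

```
spark.range(0).createOrReplaceTempView("empty_view")

// Grouping by a literal over empty input: there are no groups, so no rows come back.
spark.sql("SELECT 'foo', count(*) FROM empty_view GROUP BY 1").collect()
// => Array()

// An ungrouped aggregate, by contrast, always returns exactly one row.
spark.sql("SELECT 'foo', count(*) FROM empty_view").collect()
// => Array([foo,0])
```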
--- .../sql/catalyst/optimizer/Optimizer.scala | 11 +++- .../optimizer/AggregateOptimizeSuite.scala | 10 +++- .../resources/sql-tests/inputs/group-by.sql | 17 +++++++ .../sql-tests/results/group-by.sql.out | 51 +++++++++++++++++++ 4 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/group-by.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/group-by.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d2f0c97989213..0df16b7a56c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1098,9 +1098,16 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { */ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.nonEmpty => val newGrouping = grouping.filter(!_.foldable) - a.copy(groupingExpressions = newGrouping) + if (newGrouping.nonEmpty) { + a.copy(groupingExpressions = newGrouping) + } else { + // All grouping expressions are literals. We should not drop them all, because this can + // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We + // instead replace this by single, easy to hash/sort, literal expression. + a.copy(groupingExpressions = Seq(Literal(0, IntegerType))) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 4c26c184b7b5b..aecf59aee6a9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { - val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false) + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) @@ -49,6 +49,14 @@ class AggregateOptimizeSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("do not remove all grouping expressions if they are all literals") { + val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) + + comparePlans(optimized, correctAnswer) + } + test("Remove aliased literals") { val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql new file mode 100644 index 0000000000000..6741703d9d82c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -0,0 +1,17 @@ +-- Temporary data. 
+create temporary view myview as values 128, 256 as v(int_col); + +-- group by should produce all input rows, +select int_col, count(*) from myview group by int_col; + +-- group by should produce a single row. +select 'foo', count(*) from myview group by 1; + +-- group-by should not produce any rows (whole stage code generation). +select 'foo' from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (hash aggregate). +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (sort aggregate). +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out new file mode 100644 index 0000000000000..9127bd4dd4c6f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -0,0 +1,51 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +create temporary view myview as values 128, 256 as v(int_col) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select int_col, count(*) from myview group by int_col +-- !query 1 schema +struct +-- !query 1 output +128 1 +256 1 + + +-- !query 2 +select 'foo', count(*) from myview group by 1 +-- !query 2 schema +struct +-- !query 2 output +foo 2 + + +-- !query 3 +select 'foo' from myview where int_col == 0 group by 1 +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +-- !query 5 schema +struct> +-- !query 5 output + From fe767395ff46ee6236cf53aece85fcd61c0b49d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B2=91=E7=8E=89=E6=B5=B7?= <261810726@qq.com> Date: Thu, 15 Sep 2016 20:45:00 +0200 Subject: [PATCH 029/306] [SPARK-17429][SQL] use ImplicitCastInputTypes with function Length MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? select length(11); select length(2.0); these sql will return errors, but hive is ok. this PR will support casting input types implicitly for function length the correct result is: select length(11) return 2 select length(2.0) return 3 Author: 岑玉海 <261810726@qq.com> Author: cenyuhai Closes #15014 from cenyuhai/SPARK-17429. 
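For reference, the new behaviour can be checked from a spark-shell session (assuming a SparkSession named `spark`); before this change the same calls failed during analysis because the integral and decimal arguments were not accepted.

```
spark.sql("SELECT length(11), length(2.0)").collect()
// => Array([2,3])   ("11" has two characters, "2.0" has three)
```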
--- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../org/apache/spark/sql/StringFunctionsSuite.scala | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index a8c23a8b0c536..1bcbb6cfc9246 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1057,7 +1057,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @ExpressionDescription( usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", extended = "> SELECT _FUNC_('Spark SQL');\n 9") -case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 1cc77464b93fc..bcc2351049953 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -330,7 +330,8 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string / binary length function") { - val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015)) + .toDF("a", "b", "c", "d", "e") checkAnswer( df.select(length($"a"), length($"b")), Row(3, 4)) @@ -339,9 +340,10 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("length(a)", "length(b)"), Row(3, 4)) - intercept[AnalysisException] { - df.selectExpr("length(c)") // int type of the argument is unacceptable - } + checkAnswer( + df.selectExpr("length(c)", "length(d)", "length(e)"), + Row(3, 3, 5) + ) } test("initcap function") { From a6b8182006d0c3dda67c06861067ca78383ecf1b Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Thu, 15 Sep 2016 20:53:48 +0200 Subject: [PATCH 030/306] [SPARK-17364][SQL] Antlr lexer wrongly treats full qualified identifier as a decimal number token when parsing SQL string ## What changes were proposed in this pull request? The Antlr lexer we use to tokenize a SQL string may wrongly tokenize a fully qualified identifier as a decimal number token. For example, table identifier `default.123_table` is wrongly tokenized as ``` default // Matches lexer rule IDENTIFIER .123 // Matches lexer rule DECIMAL_VALUE _TABLE // Matches lexer rule IDENTIFIER ``` The correct tokenization for `default.123_table` should be: ``` default // Matches lexer rule IDENTIFIER, . // Matches a single dot 123_TABLE // Matches lexer rule IDENTIFIER ``` This PR fix the Antlr grammar so that it can tokenize fully qualified identifier correctly: 1. Fully qualified table name can be parsed correctly. For example, `select * from database.123_suffix`. 2. Fully qualified column name can be parsed correctly, for example `select a.123_suffix from a`. 
### Before change #### Case 1: Failed to parse fully qualified column name ``` scala> spark.sql("select a.123_column from a").show org.apache.spark.sql.catalyst.parser.ParseException: extraneous input '.123' expecting {, ... , IDENTIFIER, BACKQUOTED_IDENTIFIER}(line 1, pos 8) == SQL == select a.123_column from a --------^^^ ``` #### Case 2: Failed to parse fully qualified table name ``` scala> spark.sql("select * from default.123_table") org.apache.spark.sql.catalyst.parser.ParseException: extraneous input '.123' expecting {, ... IDENTIFIER, BACKQUOTED_IDENTIFIER}(line 1, pos 21) == SQL == select * from default.123_table ---------------------^^^ ``` ### After Change #### Case 1: fully qualified column name, no ParseException thrown ``` scala> spark.sql("select a.123_column from a").show ``` #### Case 2: fully qualified table name, no ParseException thrown ``` scala> spark.sql("select * from default.123_table") ``` ## How was this patch tested? Unit test. Author: Sean Zhong Closes #15006 from clockfly/SPARK-17364. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 44 +++++++++++++++---- .../parser/ExpressionParserSuite.scala | 15 ++++++- .../parser/TableIdentifierParserSuite.scala | 13 ++++++ 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b475abdce2da9..7023c0c8c493f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -16,6 +16,30 @@ grammar SqlBase; +@members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } +} + tokens { DELIMITER } @@ -920,23 +944,22 @@ INTEGER_VALUE ; DECIMAL_VALUE - : DIGIT+ '.' DIGIT* - | '.' DIGIT+ + : DECIMAL_DIGITS {isValidDecimal()}? ; SCIENTIFIC_DECIMAL_VALUE - : DIGIT+ ('.' DIGIT*)? EXPONENT - | '.' DIGIT+ EXPONENT + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? ; DOUBLE_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? ; BIGDECIMAL_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'BD' + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? ; IDENTIFIER @@ -947,6 +970,11 @@ BACKQUOTED_IDENTIFIER : '`' ( ~'`' | '``' )* '`' ; +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + fragment EXPONENT : 'E' [+-]? 
DIGIT+ ; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 4e399eef1fed8..f319215f05681 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -518,4 +518,17 @@ class ExpressionParserSuite extends PlanTest { assertEqual("current_date", CurrentDate()) assertEqual("current_timestamp", CurrentTimestamp()) } + + test("SPARK-17364, fully qualified column name which starts with number") { + assertEqual("123_", UnresolvedAttribute("123_")) + assertEqual("1a.123_", UnresolvedAttribute("1a.123_")) + // ".123" should not be treated as token of type DECIMAL_VALUE + assertEqual("a.123A", UnresolvedAttribute("a.123A")) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assertEqual("a.123E3_column", UnresolvedAttribute("a.123E3_column")) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assertEqual("a.123D_column", UnresolvedAttribute("a.123D_column")) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assertEqual("a.123BD_column", UnresolvedAttribute("a.123BD_column")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index dadb8a8def43b..793be8953d07a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -91,4 +91,17 @@ class TableIdentifierParserSuite extends SparkFunSuite { assert(TableIdentifier(nonReserved) === parseTableIdentifier(nonReserved)) } } + + test("SPARK-17364 table identifier - contains number") { + assert(parseTableIdentifier("123_") == TableIdentifier("123_")) + assert(parseTableIdentifier("1a.123_") == TableIdentifier("123_", Some("1a"))) + // ".123" should not be treated as token of type DECIMAL_VALUE + assert(parseTableIdentifier("a.123A") == TableIdentifier("123A", Some("a"))) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assert(parseTableIdentifier("a.123E3_LIST") == TableIdentifier("123E3_LIST", Some("a"))) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assert(parseTableIdentifier("a.123D_LIST") == TableIdentifier("123D_LIST", Some("a"))) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assert(parseTableIdentifier("a.123BD_LIST") == TableIdentifier("123BD_LIST", Some("a"))) + } } From 1202075c95eabba0ffebc170077df798f271a139 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 15 Sep 2016 11:54:17 -0700 Subject: [PATCH 031/306] [SPARK-17484] Prevent invalid block locations from being reported after put() exceptions ## What changes were proposed in this pull request? 
If a BlockManager `put()` call failed after the BlockManagerMaster was notified of a block's availability then incomplete cleanup logic in a `finally` block would never send a second block status method to inform the master of the block's unavailability. This, in turn, leads to fetch failures and used to be capable of causing complete job failures before #15037 was fixed. This patch addresses this issue via multiple small changes: - The `finally` block now calls `removeBlockInternal` when cleaning up from a failed `put()`; in addition to removing the `BlockInfo` entry (which was _all_ that the old cleanup logic did), this code (redundantly) tries to remove the block from the memory and disk stores (as an added layer of defense against bugs lower down in the stack) and optionally notifies the master of block removal (which now happens during exception-triggered cleanup). - When a BlockManager receives a request for a block that it does not have it will now notify the master to update its block locations. This ensures that bad metadata pointing to non-existent blocks will eventually be fixed. Note that I could have implemented this logic in the block manager client (rather than in the remote server), but that would introduce the problem of distinguishing between transient and permanent failures; on the server, however, we know definitively that the block isn't present. - Catch `NonFatal` instead of `Exception` to avoid swallowing `InterruptedException`s thrown from synchronous block replication calls. This patch depends upon the refactorings in #15036, so that other patch will also have to be backported when backporting this fix. For more background on this issue, including example logs from a real production failure, see [SPARK-17484](https://issues.apache.org/jira/browse/SPARK-17484). ## How was this patch tested? Two new regression tests in BlockManagerSuite. Author: Josh Rosen Closes #15085 from JoshRosen/SPARK-17484. --- .../apache/spark/storage/BlockManager.scala | 37 +++++++++++++++---- .../spark/storage/BlockManagerSuite.scala | 34 +++++++++++++++++ 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c172ac2cdc0e3..aa29acfd70461 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -283,7 +283,12 @@ private[spark] class BlockManager( } else { getLocalBytes(blockId) match { case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) - case None => throw new BlockNotFoundException(blockId.toString) + case None => + // If this block manager receives a request for a block that it doesn't have then it's + // likely that the master has outdated block statuses for this block. Therefore, we send + // an RPC so that this block is marked as being unavailable from this block manager. 
+ reportBlockStatus(blockId, BlockStatus.empty) + throw new BlockNotFoundException(blockId.toString) } } } @@ -859,22 +864,38 @@ private[spark] class BlockManager( } val startTimeMs = System.currentTimeMillis - var blockWasSuccessfullyStored: Boolean = false + var exceptionWasThrown: Boolean = true val result: Option[T] = try { val res = putBody(putBlockInfo) - blockWasSuccessfullyStored = res.isEmpty - res - } finally { - if (blockWasSuccessfullyStored) { + exceptionWasThrown = false + if (res.isEmpty) { + // the block was successfully stored if (keepReadLock) { blockInfoManager.downgradeLock(blockId) } else { blockInfoManager.unlock(blockId) } } else { - blockInfoManager.removeBlock(blockId) + removeBlockInternal(blockId, tellMaster = false) logWarning(s"Putting block $blockId failed") } + res + } finally { + // This cleanup is performed in a finally block rather than a `catch` to avoid having to + // catch and properly re-throw InterruptedException. + if (exceptionWasThrown) { + logWarning(s"Putting block $blockId failed due to an exception") + // If an exception was thrown then it's possible that the code in `putBody` has already + // notified the master about the availability of this block, so we need to send an update + // to remove this block location. + removeBlockInternal(blockId, tellMaster = tellMaster) + // The `putBody` code may have also added a new block status to TaskMetrics, so we need + // to cancel that out by overwriting it with an empty block status. We only do this if + // the finally block was entered via an exception because doing this unconditionally would + // cause us to send empty block statuses for every block that failed to be cached due to + // a memory shortage (which is an expected failure, unlike an uncaught exception). + addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty) + } } if (level.replication > 1) { logDebug("Putting block %s with replication took %s" @@ -1173,7 +1194,7 @@ private[spark] class BlockManager( done = true // specified number of peers have been replicated to } } catch { - case e: Exception => + case NonFatal(e) => logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) failures += 1 replicationFailed = true diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index fdf28b7dcbcf4..6d53d2e5f0ca6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -861,6 +861,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memoryManager.setMemoryStore(store.memoryStore) + store.initialize("app-id") // The put should fail since a1 is not serializable. 
class UnserializableClass @@ -1206,6 +1207,39 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE verify(mockBlockManagerMaster, times(2)).getLocations("item") } + test("SPARK-17484: block status is properly updated following an exception in put()") { + val mockBlockTransferService = new MockBlockTransferService(maxFailures = 10) { + override def uploadBlock( + hostname: String, + port: Int, execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = { + throw new InterruptedException("Intentional interrupt") + } + } + store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) + store2 = makeBlockManager(8000, "executor2", transferService = Option(mockBlockTransferService)) + intercept[InterruptedException] { + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY_2, tellMaster = true) + } + assert(store.getLocalBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + } + + test("SPARK-17484: master block locations are updated following an invalid remote block fetch") { + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(master.getLocations("item").nonEmpty) + store.removeBlock("item", tellMaster = false) + assert(master.getLocations("item").nonEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 From b72486f82dd9920135442191be5d384028e7fb41 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 15 Sep 2016 21:45:29 +0200 Subject: [PATCH 032/306] [SPARK-17458][SQL] Alias specified for aggregates in a pivot are not honored ## What changes were proposed in this pull request? This change preserves aliases that are given for pivot aggregations ## How was this patch tested? New unit test Author: Andrew Ray Closes #15111 from aray/SPARK-17458. 
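To illustrate, a small spark-shell sketch (the DataFrame and the alias `total` are made up for the example, mirroring the new test): when more than one aggregate is supplied, an alias given to an aggregate is now kept in the generated pivot column names instead of being replaced by the aggregate's SQL text.

```
import org.apache.spark.sql.functions.{avg, sum}
import spark.implicits._   // assumes a SparkSession named `spark`

val sales = Seq((2012, "dotNET", 10000), (2012, "Java", 20000), (2013, "dotNET", 5000))
  .toDF("year", "course", "earnings")

// With the fix, the alias is preserved, e.g. "dotNET_total" rather than
// "dotNET_sum(`earnings`)"; unaliased aggregates keep their SQL text as before.
sales.groupBy($"year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum($"earnings").as("total"), avg($"earnings"))
  .columns
// => Array(year, dotNET_total, dotNET_avg(`earnings`), Java_total, Java_avg(`earnings`))
```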
--- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++++- .../org/apache/spark/sql/DataFramePivotSuite.scala | 11 +++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 92bf8e0536fc4..5210f42c557b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -373,7 +373,15 @@ class Analyzer( case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { - if (singleAgg) value.toString else value + "_" + aggregate.sql + if (singleAgg) { + value.toString + } else { + val suffix = aggregate match { + case n: NamedExpression => n.name + case _ => aggregate.sql + } + value + "_" + suffix + } } if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { // Since evaluating |pivotValues| if statements for each input row can get slow this is an diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index d5cb5e15688e8..1bbe1354d55f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -197,4 +197,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row(2013, Seq(48000.0, 7.0), Seq(30000.0, 7.0)) :: Nil ) } + + test("pivot preserves aliases if given") { + assertResult( + Array("year", "dotNET_foo", "dotNET_avg(`earnings`)", "Java_foo", "Java_avg(`earnings`)") + )( + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings").as("foo"), avg($"earnings")).columns + ) + } + } From b2e27262440015f57bcfa888921c9cc017800910 Mon Sep 17 00:00:00 2001 From: Jagadeesan Date: Fri, 16 Sep 2016 10:18:45 +0100 Subject: [PATCH 033/306] =?UTF-8?q?[SPARK-17543]=20Missing=20log4j=20confi?= =?UTF-8?q?g=20file=20for=20tests=20in=20common/network-=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The Maven module `common/network-shuffle` does not have a log4j configuration file for its test cases. So, added `log4j.properties` in the directory `src/test/resources`. …shuffle] Author: Jagadeesan Closes #15108 from jagadeesanas2/SPARK-17543. --- .../src/test/resources/log4j.properties | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 common/network-shuffle/src/test/resources/log4j.properties diff --git a/common/network-shuffle/src/test/resources/log4j.properties b/common/network-shuffle/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e73978908b683 --- /dev/null +++ b/common/network-shuffle/src/test/resources/log4j.properties @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n From fc1efb720c9c0033077c3c20ee144d0f757e6bcd Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Fri, 16 Sep 2016 10:20:50 +0100 Subject: [PATCH 034/306] [SPARK-17534][TESTS] Increase timeouts for DirectKafkaStreamSuite tests **## What changes were proposed in this pull request?** There are two tests in this suite that are particularly flaky on the following hardware: 2x Intel(R) Xeon(R) CPU E5-2697 v2 2.70GHz and 16 GB of RAM, 1 TB HDD This simple PR increases the timeout times and batch duration so they can reliably pass **## How was this patch tested?** Existing unit tests with the two core box where I was seeing the failures often Author: Adam Roberts Closes #15094 from a-roberts/patch-6. --- .../spark/streaming/kafka010/DirectKafkaStreamSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index b1d90b8a82d59..e04f35eceb1b4 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -108,7 +108,7 @@ class DirectKafkaStreamSuite val expectedTotal = (data.values.sum * topics.size) - 2 val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") - ssc = new StreamingContext(sparkConf, Milliseconds(200)) + ssc = new StreamingContext(sparkConf, Milliseconds(1000)) val stream = withClue("Error creating direct stream") { KafkaUtils.createDirectStream[String, String]( ssc, @@ -150,7 +150,7 @@ class DirectKafkaStreamSuite allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*)) } ssc.start() - eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + eventually(timeout(100000.milliseconds), interval(1000.milliseconds)) { assert(allReceived.size === expectedTotal, "didn't get expected number of messages, messages:\n" + allReceived.asScala.mkString("\n")) @@ -172,7 +172,7 @@ class DirectKafkaStreamSuite val expectedTotal = (data.values.sum * 2) - 3 val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") - ssc = new StreamingContext(sparkConf, Milliseconds(200)) + ssc = new StreamingContext(sparkConf, Milliseconds(1000)) val stream = withClue("Error creating direct stream") { KafkaUtils.createDirectStream[String, String]( ssc, @@ -214,7 +214,7 @@ class DirectKafkaStreamSuite allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*)) } ssc.start() - eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + 
eventually(timeout(100000.milliseconds), interval(1000.milliseconds)) { assert(allReceived.size === expectedTotal, "didn't get expected number of messages, messages:\n" + allReceived.asScala.mkString("\n")) From a425a37a5d894e0d7462c8faa81a913495189ece Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Fri, 16 Sep 2016 19:37:30 +0800 Subject: [PATCH 035/306] [SPARK-17426][SQL] Refactor `TreeNode.toJSON` to avoid OOM when converting unknown fields to JSON ## What changes were proposed in this pull request? This PR is a follow up of SPARK-17356. Current implementation of `TreeNode.toJSON` recursively converts all fields of TreeNode to JSON, even if the field is of type `Seq` or type Map. This may trigger out of memory exception in cases like: 1. the Seq or Map can be very big. Converting them to JSON may take huge memory, which may trigger out of memory error. 2. Some user space input may also be propagated to the Plan. The user space input can be of arbitrary type, and may also be self-referencing. Trying to print user space input to JSON may trigger out of memory error or stack overflow error. For a code example, please check the Jira description of SPARK-17426. In this PR, we refactor the `TreeNode.toJSON` so that we only convert a field to JSON string if the field is a safe type. ## How was this patch tested? Unit test. Author: Sean Zhong Closes #14990 from clockfly/json_oom2. --- .../spark/sql/catalyst/trees/TreeNode.scala | 218 +++---------- .../sql/catalyst/trees/TreeNodeSuite.scala | 294 +++++++++++++++++- .../org/apache/spark/sql/QueryTest.scala | 136 -------- 3 files changed, 333 insertions(+), 315 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 893af5146c5b3..83cb375525832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // this child in all children. 
case (name, value: TreeNode[_]) if containsChild(value) => name -> JInt(children.indexOf(value)) - case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) => + case (name, value: Seq[BaseType]) if value.forall(containsChild) => name -> JArray( value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList ) @@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming // it to JSON may trigger OutOfMemoryError. case m: Metadata => Metadata.empty.jsonValue + case clazz: Class[_] => JString(clazz.getName) case s: StorageLevel => ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) case n: TreeNode[_] => n.jsonValue case o: Option[_] => o.map(parseToJson) - case t: Seq[_] => JArray(t.map(parseToJson).toList) - case m: Map[_, _] => - val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) } - JObject(fields) - case r: RDD[_] => JNothing + // Recursive scan Seq[TreeNode], Seq[Partitioning], Seq[DataType] + case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) || + t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => + JArray(t.map(parseToJson).toList) + case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => + JString(Utils.truncatedString(t, "[", ", ", "]")) + case t: Seq[_] => JNull + case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. // TODO: currently if the class name ends with "$", we think it's a scala object, there is // probably a better way to check it. case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName - // returns null if the product type doesn't have a primary constructor, e.g. 
HiveFunctionWrapper - case p: Product => try { - val fieldNames = getConstructorParameterNames(p.getClass) - val fieldValues = p.productIterator.toSeq - assert(fieldNames.length == fieldValues.length) - ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { - case (name, value) => name -> parseToJson(value) - }.toList - } catch { - case _: RuntimeException => null - } - case _ => JNull - } -} - -object TreeNode { - def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = { - val jsonAST = parse(json) - assert(jsonAST.isInstanceOf[JArray]) - reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType] - } - - private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = { - assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject])) - val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*) - - def parseNextNode(): TreeNode[_] = { - val nextNode = jsonNodes.pop() - - val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s) - if (cls == classOf[Literal]) { - Literal.fromJSON(nextNode) - } else if (cls.getName.endsWith("$")) { - cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]] - } else { - val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt - - val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode()) - val fields = getConstructorParameters(cls) - - val parameters: Array[AnyRef] = fields.map { - case (fieldName, fieldType) => - parseFromJson(nextNode \ fieldName, fieldType, children, sc) - }.toArray - - val maybeCtor = cls.getConstructors.find { p => - val expectedTypes = p.getParameterTypes - expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall { - case (cls, tpe) => cls == getClassFromType(tpe) - } - } - if (maybeCtor.isEmpty) { - sys.error(s"No valid constructor for ${cls.getName}") - } else { - try { - maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]] - } catch { - case e: java.lang.IllegalArgumentException => - throw new RuntimeException( - s""" - |Failed to construct tree node: ${cls.getName} - |ctor: ${maybeCtor.get} - |types: ${parameters.map(_.getClass).mkString(", ")} - |args: ${parameters.mkString(", ")} - """.stripMargin, e) - } - } - } - } - - parseNextNode() - } - - import universe._ - - private def parseFromJson( - value: JValue, - expectedType: Type, - children: Seq[TreeNode[_]], - sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized { - if (value == JNull) return null - - expectedType match { - case t if t <:< definitions.BooleanTpe => - value.asInstanceOf[JBool].value: java.lang.Boolean - case t if t <:< definitions.ByteTpe => - value.asInstanceOf[JInt].num.toByte: java.lang.Byte - case t if t <:< definitions.ShortTpe => - value.asInstanceOf[JInt].num.toShort: java.lang.Short - case t if t <:< definitions.IntTpe => - value.asInstanceOf[JInt].num.toInt: java.lang.Integer - case t if t <:< definitions.LongTpe => - value.asInstanceOf[JInt].num.toLong: java.lang.Long - case t if t <:< definitions.FloatTpe => - value.asInstanceOf[JDouble].num.toFloat: java.lang.Float - case t if t <:< definitions.DoubleTpe => - value.asInstanceOf[JDouble].num: java.lang.Double - - case t if t <:< localTypeOf[java.lang.Boolean] => - value.asInstanceOf[JBool].value: java.lang.Boolean - case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num - case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s - case t if t <:< localTypeOf[UUID] => 
UUID.fromString(value.asInstanceOf[JString].s) - case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value) - case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject]) - case t if t <:< localTypeOf[StorageLevel] => - val JBool(useDisk) = value \ "useDisk" - val JBool(useMemory) = value \ "useMemory" - val JBool(useOffHeap) = value \ "useOffHeap" - val JBool(deserialized) = value \ "deserialized" - val JInt(replication) = value \ "replication" - StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt) - case t if t <:< localTypeOf[TreeNode[_]] => value match { - case JInt(i) => children(i.toInt) - case arr: JArray => reconstruct(arr, sc) - case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.") + case p: Product if shouldConvertToJson(p) => + try { + val fieldNames = getConstructorParameterNames(p.getClass) + val fieldValues = p.productIterator.toSeq + assert(fieldNames.length == fieldValues.length) + ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { + case (name, value) => name -> parseToJson(value) + }.toList + } catch { + case _: RuntimeException => null } - case t if t <:< localTypeOf[Option[_]] => - if (value == JNothing) { - None - } else { - val TypeRef(_, _, Seq(optType)) = t - Option(parseFromJson(value, optType, children, sc)) - } - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val JArray(elements) = value - elements.map(parseFromJson(_, elementType, children, sc)).toSeq - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val JObject(fields) = value - fields.map { - case (name, value) => name -> parseFromJson(value, valueType, children, sc) - }.toMap - case t if t <:< localTypeOf[RDD[_]] => - new EmptyRDD[Any](sc) - case _ if isScalaObject(value) => - val JString(clsName) = value \ "object" - val cls = Utils.classForName(clsName) - cls.getField("MODULE$").get(cls) - case t if t <:< localTypeOf[Product] => - val fields = getConstructorParameters(t) - val clsName = getClassNameFromType(t) - parseToProduct(clsName, fields, value, children, sc) - // There maybe some cases that the parameter type signature is not Product but the value is, - // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here. 
- case _ if isScalaProduct(value) => - val JString(clsName) = value \ "product-class" - val fields = getConstructorParameters(Utils.classForName(clsName)) - parseToProduct(clsName, fields, value, children, sc) - case _ => sys.error(s"Do not support type $expectedType with json $value.") - } - } - - private def parseToProduct( - clsName: String, - fields: Seq[(String, Type)], - value: JValue, - children: Seq[TreeNode[_]], - sc: SparkContext): AnyRef = { - val parameters: Array[AnyRef] = fields.map { - case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc) - }.toArray - val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size) - ctor.newInstance(parameters: _*).asInstanceOf[AnyRef] - } - - private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match { - case JString(str) if str.endsWith("$") => true - case _ => false + case _ => JNull } - private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match { - case _: JString => true + private def shouldConvertToJson(product: Product): Boolean = product match { + case exprId: ExprId => true + case field: StructField => true + case id: TableIdentifier => true + case join: JoinType => true + case id: FunctionIdentifier => true + case spec: BucketSpec => true + case catalog: CatalogTable => true + case boundary: FrameBoundary => true + case frame: WindowFrame => true + case partition: Partitioning => true + case resource: FunctionResource => true + case broadcast: BroadcastMode => true + case table: CatalogTableType => true + case storage: CatalogStorageFormat => true case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6246380dbeb9b..cb0426c7a98a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -17,13 +17,29 @@ package org.apache.spark.sql.catalyst.trees +import java.math.BigInteger +import java.util.UUID + import scala.collection.mutable.ArrayBuffer +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{IntegerType, NullType, StringType} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} +import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { 
override def children: Seq[Expression] = optKey.toSeq @@ -45,6 +61,20 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Expression with override lazy val resolved = true } +case class JsonTestTreeNode(arg: Any) extends LeafNode { + override def output: Seq[Attribute] = Seq.empty[Attribute] +} + +case class NameValue(name: String, value: Any) + +case object DummyObject + +case class SelfReferenceUDF( + var config: Map[String, Any] = Map.empty[String, Any]) extends Function1[String, Boolean] { + config += "self" -> this + def apply(key: String): Boolean = config.contains(key) +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -261,4 +291,264 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === expected) } } + + test("toJSON") { + def assertJSON(input: Any, json: JValue): Unit = { + val expected = + s""" + |[{ + | "class": "${classOf[JsonTestTreeNode].getName}", + | "num-children": 0, + | "arg": ${compact(render(json))} + |}] + """.stripMargin + compareJSON(JsonTestTreeNode(input).toJSON, expected) + } + + // Converts simple types to JSON + assertJSON(true, true) + assertJSON(33.toByte, 33) + assertJSON(44, 44) + assertJSON(55L, 55L) + assertJSON(3.0, 3.0) + assertJSON(4.0D, 4.0D) + assertJSON(BigInt(BigInteger.valueOf(88L)), 88L) + assertJSON(null, JNull) + assertJSON("text", "text") + assertJSON(Some("text"), "text") + compareJSON(JsonTestTreeNode(None).toJSON, + s"""[ + | { + | "class": "${classOf[JsonTestTreeNode].getName}", + | "num-children": 0 + | } + |] + """.stripMargin) + + val uuid = UUID.randomUUID() + assertJSON(uuid, uuid.toString) + + // Converts Spark Sql DataType to JSON + assertJSON(IntegerType, "integer") + assertJSON(Metadata.empty, JObject(Nil)) + assertJSON( + StorageLevel.NONE, + JObject( + "useDisk" -> false, + "useMemory" -> false, + "useOffHeap" -> false, + "deserialized" -> false, + "replication" -> 1) + ) + + // Converts TreeNode argument to JSON + assertJSON( + Literal(333), + List( + JObject( + "class" -> classOf[Literal].getName, + "num-children" -> 0, + "value" -> "333", + "dataType" -> "integer"))) + + // Converts Seq[String] to JSON + assertJSON(Seq("1", "2", "3"), "[1, 2, 3]") + + // Converts Seq[DataType] to JSON + assertJSON(Seq(IntegerType, DoubleType, FloatType), List("integer", "double", "float")) + + // Converts Seq[Partitioning] to JSON + assertJSON( + Seq(SinglePartition, RoundRobinPartitioning(numPartitions = 3)), + List( + JObject("object" -> JString(SinglePartition.getClass.getName)), + JObject( + "product-class" -> classOf[RoundRobinPartitioning].getName, + "numPartitions" -> 3))) + + // Converts case object to JSON + assertJSON(DummyObject, JObject("object" -> JString(DummyObject.getClass.getName))) + + // Converts ExprId to JSON + assertJSON( + ExprId(0, uuid), + JObject( + "product-class" -> classOf[ExprId].getName, + "id" -> 0, + "jvmId" -> uuid.toString)) + + // Converts StructField to JSON + assertJSON( + StructField("field", IntegerType), + JObject( + "product-class" -> classOf[StructField].getName, + "name" -> "field", + "dataType" -> "integer", + "nullable" -> true, + "metadata" -> JObject(Nil))) + + // Converts TableIdentifier to JSON + assertJSON( + TableIdentifier("table"), + JObject( + "product-class" -> classOf[TableIdentifier].getName, + "table" -> "table")) + + // Converts JoinType to JSON + assertJSON( + NaturalJoin(LeftOuter), + JObject( + "product-class" -> classOf[NaturalJoin].getName, + "tpe" -> 
JObject("object" -> JString(LeftOuter.getClass.getName)))) + + // Converts FunctionIdentifier to JSON + assertJSON( + FunctionIdentifier("function", None), + JObject( + "product-class" -> JString(classOf[FunctionIdentifier].getName), + "funcName" -> "function")) + + // Converts BucketSpec to JSON + assertJSON( + BucketSpec(1, Seq("bucket"), Seq("sort")), + JObject( + "product-class" -> classOf[BucketSpec].getName, + "numBuckets" -> 1, + "bucketColumnNames" -> "[bucket]", + "sortColumnNames" -> "[sort]")) + + // Converts FrameBoundary to JSON + assertJSON( + ValueFollowing(3), + JObject( + "product-class" -> classOf[ValueFollowing].getName, + "value" -> 3)) + + // Converts WindowFrame to JSON + assertJSON( + SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow), + JObject( + "product-class" -> classOf[SpecifiedWindowFrame].getName, + "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), + "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)), + "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName)))) + + // Converts Partitioning to JSON + assertJSON( + RoundRobinPartitioning(numPartitions = 3), + JObject( + "product-class" -> classOf[RoundRobinPartitioning].getName, + "numPartitions" -> 3)) + + // Converts FunctionResource to JSON + assertJSON( + FunctionResource(JarResource, "file:///"), + JObject( + "product-class" -> JString(classOf[FunctionResource].getName), + "resourceType" -> JObject("object" -> JString(JarResource.getClass.getName)), + "uri" -> "file:///")) + + // Converts BroadcastMode to JSON + assertJSON( + IdentityBroadcastMode, + JObject("object" -> JString(IdentityBroadcastMode.getClass.getName))) + + // Converts CatalogTable to JSON + assertJSON( + CatalogTable( + TableIdentifier("table"), + CatalogTableType.MANAGED, + CatalogStorageFormat.empty, + StructType(StructField("a", IntegerType, true) :: Nil), + createTime = 0L), + + JObject( + "product-class" -> classOf[CatalogTable].getName, + "identifier" -> JObject( + "product-class" -> classOf[TableIdentifier].getName, + "table" -> "table" + ), + "tableType" -> JObject( + "product-class" -> classOf[CatalogTableType].getName, + "name" -> "MANAGED" + ), + "storage" -> JObject( + "product-class" -> classOf[CatalogStorageFormat].getName, + "compressed" -> false, + "properties" -> JNull + ), + "schema" -> JObject( + "type" -> "struct", + "fields" -> List( + JObject( + "name" -> "a", + "type" -> "integer", + "nullable" -> true, + "metadata" -> JObject(Nil)))), + "partitionColumnNames" -> List.empty[String], + "owner" -> "", + "createTime" -> 0, + "lastAccessTime" -> -1, + "properties" -> JNull, + "unsupportedFeatures" -> List.empty[String])) + + // For unknown case class, returns JNull. 
+ val bigValue = new Array[Int](10000) + assertJSON(NameValue("name", bigValue), JNull) + + // Converts Seq[TreeNode] to JSON recursively + assertJSON( + Seq(Literal(1), Literal(2)), + List( + List( + JObject( + "class" -> JString(classOf[Literal].getName), + "num-children" -> 0, + "value" -> "1", + "dataType" -> "integer")), + List( + JObject( + "class" -> JString(classOf[Literal].getName), + "num-children" -> 0, + "value" -> "2", + "dataType" -> "integer")))) + + // Other Seq is converted to JNull, to reduce the risk of out of memory + assertJSON(Seq(1, 2, 3), JNull) + + // All Map type is converted to JNull, to reduce the risk of out of memory + assertJSON(Map("key" -> "value"), JNull) + + // Unknown type is converted to JNull, to reduce the risk of out of memory + assertJSON(new Object {}, JNull) + + // Convert all TreeNode children to JSON + assertJSON( + Union(Seq(JsonTestTreeNode("0"), JsonTestTreeNode("1"))), + List( + JObject( + "class" -> classOf[Union].getName, + "num-children" -> 2, + "children" -> List(0, 1)), + JObject( + "class" -> classOf[JsonTestTreeNode].getName, + "num-children" -> 0, + "arg" -> "0"), + JObject( + "class" -> classOf[JsonTestTreeNode].getName, + "num-children" -> 0, + "arg" -> "1"))) + } + + test("toJSON should not throws java.lang.StackOverflowError") { + val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr)) + // Should not throw java.lang.StackOverflowError + udf.toJSON + } + + private def compareJSON(leftJson: String, rightJson: String): Unit = { + val left = JsonMethods.parse(leftJson) + val right = JsonMethods.parse(rightJson) + assert(left == right) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d361f61764d1f..34fa626e00e31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -120,7 +120,6 @@ abstract class QueryTest extends PlanTest { throw ae } } - checkJsonFormat(analyzedDS) assertEmptyMissingInput(analyzedDS) try ds.collect() catch { @@ -168,8 +167,6 @@ abstract class QueryTest extends PlanTest { } } - checkJsonFormat(analyzedDF) - assertEmptyMissingInput(analyzedDF) QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { @@ -228,139 +225,6 @@ abstract class QueryTest extends PlanTest { planWithCaching) } - private def checkJsonFormat(ds: Dataset[_]): Unit = { - // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that - // RDD and Data resolution does not break. - val logicalPlan = ds.queryExecution.analyzed - - // bypass some cases that we can't handle currently. - logicalPlan.transform { - case _: ObjectConsumer => return - case _: ObjectProducer => return - case _: AppendColumns => return - case _: TypedFilter => return - case _: LogicalRelation => return - case p if p.getClass.getSimpleName == "MetastoreRelation" => return - case _: MemoryPlan => return - case p: InMemoryRelation => - p.child.transform { - case _: ObjectConsumerExec => return - case _: ObjectProducerExec => return - } - p - }.transformAllExpressions { - case _: ImperativeAggregate => return - case _: TypedAggregateExpression => return - case Literal(_, _: ObjectType) => return - case _: UserDefinedGenerator => return - } - - // bypass hive tests before we fix all corner cases in hive module. 
- if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return - - val jsonString = try { - logicalPlan.toJSON - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to parse logical plan to JSON: - |${logicalPlan.treeString} - """.stripMargin, e) - } - - // scala function is not serializable to JSON, use null to replace them so that we can compare - // the plans later. - val normalized1 = logicalPlan.transformAllExpressions { - case udf: ScalaUDF => udf.copy(function = null) - case gen: UserDefinedGenerator => gen.copy(function = null) - // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove - // the Metadata from the normalized plan so that we can compare this plan with the - // JSON-deserialzed plan. - case a @ Alias(child, name) if a.explicitMetadata.isDefined => - Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated) - case a: AttributeReference if a.metadata != Metadata.empty => - AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier, - a.isGenerated) - } - - // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains - // these non-serializable stuff, and use these original ones to replace the null-placeholders - // in the logical plans parsed from JSON. - val logicalRDDs = new ArrayDeque[LogicalRDD]() - val localRelations = new ArrayDeque[LocalRelation]() - val inMemoryRelations = new ArrayDeque[InMemoryRelation]() - def collectData: (LogicalPlan => Unit) = { - case l: LogicalRDD => - logicalRDDs.offer(l) - case l: LocalRelation => - localRelations.offer(l) - case i: InMemoryRelation => - inMemoryRelations.offer(i) - case p => - p.expressions.foreach { - _.foreach { - case s: SubqueryExpression => - s.plan.foreach(collectData) - case _ => - } - } - } - logicalPlan.foreach(collectData) - - - val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) - } catch { - case NonFatal(e) => - fail( - s""" - |Failed to rebuild the logical plan from JSON: - |${logicalPlan.treeString} - | - |${logicalPlan.prettyJson} - """.stripMargin, e) - } - - def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { - case l: LogicalRDD => - val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(spark) - case l: LocalRelation => - val origin = localRelations.pop() - l.copy(data = origin.data) - case l: InMemoryRelation => - val origin = inMemoryRelations.pop() - InMemoryRelation( - l.output, - l.useCompression, - l.batchSize, - l.storageLevel, - origin.child, - l.tableName)( - origin.cachedColumnBuffers, - origin.batchStats) - case p => - p.transformExpressions { - case s: SubqueryExpression => - s.withNewPlan(s.plan.transformDown(renormalize)) - } - } - val normalized2 = jsonBackPlan.transformDown(renormalize) - - assert(logicalRDDs.isEmpty) - assert(localRelations.isEmpty) - assert(inMemoryRelations.isEmpty) - - if (normalized1 != normalized2) { - fail( - s""" - |== FAIL: the logical plan parsed from json does not match the original one === - |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - /** * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. */ From dca771bec6edb1cd8fc75861d364e0ba9dccf7c3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 16 Sep 2016 11:24:26 -0700 Subject: [PATCH 036/306] [SPARK-17558] Bump Hadoop 2.7 version from 2.7.2 to 2.7.3 ## What changes were proposed in this pull request? 
This patch bumps the Hadoop version in hadoop-2.7 profile from 2.7.2 to 2.7.3, which was recently released and contained a number of bug fixes. ## How was this patch tested? The change should be covered by existing tests. Author: Reynold Xin Closes #15115 from rxin/SPARK-17558. --- dev/deps/spark-deps-hadoop-2.7 | 30 +++++++++++++++--------------- pom.xml | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index d464c97ed1d67..63566125373dd 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -59,21 +59,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.2.jar -hadoop-auth-2.7.2.jar -hadoop-client-2.7.2.jar -hadoop-common-2.7.2.jar -hadoop-hdfs-2.7.2.jar -hadoop-mapreduce-client-app-2.7.2.jar -hadoop-mapreduce-client-common-2.7.2.jar -hadoop-mapreduce-client-core-2.7.2.jar -hadoop-mapreduce-client-jobclient-2.7.2.jar -hadoop-mapreduce-client-shuffle-2.7.2.jar -hadoop-yarn-api-2.7.2.jar -hadoop-yarn-client-2.7.2.jar -hadoop-yarn-common-2.7.2.jar -hadoop-yarn-server-common-2.7.2.jar -hadoop-yarn-server-web-proxy-2.7.2.jar +hadoop-annotations-2.7.3.jar +hadoop-auth-2.7.3.jar +hadoop-client-2.7.3.jar +hadoop-common-2.7.3.jar +hadoop-hdfs-2.7.3.jar +hadoop-mapreduce-client-app-2.7.3.jar +hadoop-mapreduce-client-common-2.7.3.jar +hadoop-mapreduce-client-core-2.7.3.jar +hadoop-mapreduce-client-jobclient-2.7.3.jar +hadoop-mapreduce-client-shuffle-2.7.3.jar +hadoop-yarn-api-2.7.3.jar +hadoop-yarn-client-2.7.3.jar +hadoop-yarn-common-2.7.3.jar +hadoop-yarn-server-common-2.7.3.jar +hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar diff --git a/pom.xml b/pom.xml index ef83c184d0237..b5141736011b9 100644 --- a/pom.xml +++ b/pom.xml @@ -2524,7 +2524,7 @@ hadoop-2.7 - 2.7.2 + 2.7.3 0.9.3 3.4.6 2.6.0 From b9323fc9381a09af510f542fd5c86473e029caf6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 16 Sep 2016 13:43:05 -0700 Subject: [PATCH 037/306] [SPARK-17561][DOCS] DataFrameWriter documentation formatting problems ## What changes were proposed in this pull request? Fix `
<ul> / <li>` problems in SQL scaladoc. ## How was this patch tested? Scaladoc build and manual verification of generated HTML. Author: Sean Owen Closes #15117 from srowen/SPARK-17561. --- .../apache/spark/sql/DataFrameReader.scala | 32 ++++++++-------- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++++ .../sql/streaming/DataStreamReader.scala | 38 ++++++++++++------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 93bf74d06b71d..d29d90ce40453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -269,14 +269,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all * character using backslash quoting mechanism
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • - `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • - `FAILFAST` : throws an exception when it meets corrupted records.
      • - *
      + * during parsing. + *
        + *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
      • + *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • + *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + *
      + * *
    • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • @@ -395,13 +396,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows * Spark will log for each partition. Malformed records beyond this number will be ignored.
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • - `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • - `FAILFAST` : throws an exception when it meets corrupted records.
      • - *
      + * during parsing. + *
        + *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When + * a schema is set by user, it sets `null` for extra fields.
      • + *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • + *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + *
      + * *
    * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c05c7a6551600..e137f076a0cab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -397,7 +397,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * your external database systems. * * You can set the following JDBC-specific option(s) for storing JDBC: + *
      *
    • `truncate` (default `false`): use `TRUNCATE TABLE` instead of `DROP TABLE`.
    • + *
    * * In case of failures, users should turn off `truncate` option to use `DROP TABLE` again. Also, * due to the different behavior of `TRUNCATE TABLE` among DBMS, it's not always safe to use this. @@ -486,6 +488,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following JSON-specific option(s) for writing JSON files: + *
      *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • @@ -495,6 +498,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 1.4.0 */ @@ -510,10 +514,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following Parquet-specific option(s) for writing Parquet files: + *
      *
    • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(none, `snappy`, `gzip`, and `lzo`). This will override * `spark.sql.parquet.compression.codec`.
    • + *
    * * @since 1.4.0 */ @@ -529,9 +535,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following ORC-specific option(s) for writing ORC files: + *
      *
    • `compression` (default `snappy`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`). * This will override `orc.compress`.
    • + *
    * * @since 1.5.0 * @note Currently, this method can only be used after enabling Hive support @@ -553,9 +561,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following option(s) for writing text files: + *
      *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • + *
    * * @since 1.6.0 */ @@ -571,6 +581,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following CSV-specific option(s) for writing CSV files: + *
      *
    • `sep` (default `,`): sets the single character as a separator for each * field and value.
    • *
    • `quote` (default `"`): sets the single character used for escaping quoted values where @@ -593,6 +604,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 3ad1125229c97..c25f71af7362a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -161,6 +161,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * schema in advance, use the version that specifies the schema to avoid the extra scan. * * You can set the following JSON-specific options to deal with non-standard JSON files: + *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • *
    • `primitivesAsString` (default `false`): infers all primitive values as a string type
    • @@ -175,14 +176,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all * character using backslash quoting mechanism
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the - * malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • - *
      + * during parsing. + *
        + *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
      • + *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • + *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + *
      + * *
    • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • @@ -192,6 +194,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 2.0.0 */ @@ -207,6 +210,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * specify the schema explicitly using [[schema]]. * * You can set the following CSV-specific options to deal with CSV files: + *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • *
    • `sep` (default `,`): sets the single character as a separator for each @@ -245,12 +249,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed * for any given value being read.
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + * during parsing. + *
          + *
        • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When + * a schema is set by user, it sets `null` for extra fields.
        • + *
        • `DROPMALFORMED` : ignores the whole corrupted records.
        • + *
        • `FAILFAST` : throws an exception when it meets corrupted records.
        • + *
        + * *
      * * @since 2.0.0 @@ -263,12 +269,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a Parquet file stream, returning the result as a [[DataFrame]]. * * You can set the following Parquet-specific option(s) for reading Parquet files: + *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • *
      • `mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets * whether we should merge schemas collected from all * Parquet part-files. This will override * `spark.sql.parquet.mergeSchema`.
      • + *
      * * @since 2.0.0 */ @@ -292,8 +300,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * }}} * * You can set the following text-specific options to deal with text files: + *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • + *
      * * @since 2.0.0 */ From 39e2bad6a866d27c3ca594d15e574a1da3ee84cc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 16 Sep 2016 14:02:56 -0700 Subject: [PATCH 038/306] [SPARK-17549][SQL] Only collect table size stat in driver for cached relation. The existing code caches all stats for all columns for each partition in the driver; for a large relation, this causes extreme memory usage, which leads to gc hell and application failures. It seems that only the size in bytes of the data is actually used in the driver, so instead just colllect that. In executors, the full stats are still kept, but that's not a big problem; we expect the data to be distributed and thus not really incur in too much memory pressure in each individual executor. There are also potential improvements on the executor side, since the data being stored currently is very wasteful (e.g. storing boxed types vs. primitive types for stats). But that's a separate issue. On a mildly related change, I'm also adding code to catch exceptions in the code generator since Janino was breaking with the test data I tried this patch on. Tested with unit tests and by doing a count a very wide table (20k columns) with many partitions. Author: Marcelo Vanzin Closes #15112 from vanzin/SPARK-17549. --- .../expressions/codegen/CodeGenerator.scala | 18 +++++++++----- .../execution/columnar/InMemoryRelation.scala | 24 +++++-------------- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +++++++++++ 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f982c222af5f0..33b9b804fc601 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -23,6 +23,7 @@ import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} @@ -910,14 +911,19 @@ object CodeGenerator extends Logging { codeAttrField.setAccessible(true) classes.foreach { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) - val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) + try { + val cf = new ClassFile(new ByteArrayInputStream(classBytes)) + cf.methodInfos.asScala.foreach { method => + method.getAttributes().foreach { a => + if (a.getClass.getName == codeAttr.getName) { + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( + codeAttrField.get(a).asInstanceOf[Array[Byte]].length) + } } } + } catch { + case NonFatal(e) => + logWarning("Error calculating stats of compiled class.", e) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 479934a7afc75..56bd5c1891e8d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.columnar -import scala.collection.JavaConverters._ - import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils @@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.CollectionAccumulator +import org.apache.spark.util.LongAccumulator object InMemoryRelation { @@ -63,8 +61,7 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: CollectionAccumulator[InternalRow] = - child.sqlContext.sparkContext.collectionAccumulator[InternalRow]) + val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) @@ -74,21 +71,12 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override lazy val statistics: Statistics = { - if (batchStats.value.isEmpty) { + if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) } else { - // Underlying columnar RDD has been materialized, required information has also been - // collected via the `batchStats` accumulator. - val sizeOfRow: Expression = - BindReferences.bindReference( - output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), - partitionStatistics.schema) - - val sizeInBytes = - batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum - Statistics(sizeInBytes = sizeInBytes) + Statistics(sizeInBytes = batchStats.value.longValue) } } @@ -139,10 +127,10 @@ case class InMemoryRelation( rowCount += 1 } + batchStats.add(totalSize) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) .flatMap(_.values)) - - batchStats.add(stats) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 937839644ad5f..0daa29b666f62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -232,4 +232,18 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val columnTypes2 = List.fill(length2)(IntegerType) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) } + + test("SPARK-17549: cached table size should be correctly calculated") { + val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) + + // Materialize the data. 
+ val expectedAnswer = data.collect() + checkAnswer(cached, expectedAnswer) + + // Check that the right size was calculated. + assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + } + } From 69cb0496974737347e2650cda436b39bbd51e581 Mon Sep 17 00:00:00 2001 From: Daniel Darabos Date: Sat, 17 Sep 2016 12:28:42 +0100 Subject: [PATCH 039/306] Correct fetchsize property name in docs ## What changes were proposed in this pull request? Replace `fetchSize` with `fetchsize` in the docs. ## How was this patch tested? I manually tested `fetchSize` and `fetchsize`. The latter has an effect. See also [`JdbcUtils.scala#L38`](https://github.com/apache/spark/blob/v2.0.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L38) for the definition of the property. Author: Daniel Darabos Closes #14975 from darabos/patch-3. --- docs/sql-programming-guide.md | 2 +- .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 28cc88c322b7e..4ac5fae566abe 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1053,7 +1053,7 @@ the Data Sources API. The following options are supported: - fetchSize + fetchsize The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2d8ee338a9804..10f15ca280689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -289,7 +289,7 @@ class JDBCSuite extends SparkFunSuite assert(names(2).equals("mary")) } - test("SELECT first field when fetchSize is two") { + test("SELECT first field when fetchsize is two") { val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) @@ -305,7 +305,7 @@ class JDBCSuite extends SparkFunSuite assert(ids(2) === 3) } - test("SELECT second field when fetchSize is two") { + test("SELECT second field when fetchsize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) @@ -352,7 +352,7 @@ class JDBCSuite extends SparkFunSuite urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect().length === 3) } - test("Basic API with illegal FetchSize") { + test("Basic API with illegal fetchsize") { val properties = new Properties() properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1") val e = intercept[SparkException] { From f15d41be3ce7569736ccbf2ffe1bec265865f55d Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Sat, 17 Sep 2016 12:30:25 +0100 Subject: [PATCH 040/306] [SPARK-17567][DOCS] Use valid url to Spark RDD paper https://issues.apache.org/jira/browse/SPARK-17567 ## What changes were proposed in this pull request? Documentation (http://spark.apache.org/docs/latest/api/scala/#org.apache.spark.rdd.RDD) contains broken link to Spark paper (http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf). I found it elsewhere (https://www.usenix.org/system/files/conference/nsdi12/nsdi12-final138.pdf) and I hope it is the same one. 
It should be uploaded to and linked from some Apache controlled storage, so it won't break again. ## How was this patch tested? Tested manually on local laptop. Author: Xin Ren Closes #15121 from keypointt/SPARK-17567. --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 10b5f8291a03a..6dc334ceb52ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,7 +70,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ abstract class RDD[T: ClassTag]( From 25cbbe6ca334140204e7035ab8b9d304da9b8a8a Mon Sep 17 00:00:00 2001 From: William Benton Date: Sat, 17 Sep 2016 12:49:58 +0100 Subject: [PATCH 041/306] [SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector ## What changes were proposed in this pull request? This pull request changes the behavior of `Word2VecModel.findSynonyms` so that it will not spuriously reject the best match when invoked with a vector that does not correspond to a word in the model's vocabulary. Instead of blindly discarding the best match, the changed implementation discards a match that corresponds to the query word (in cases where `findSynonyms` is invoked with a word) or that has an identical angle to the query vector. ## How was this patch tested? I added a test to `Word2VecSuite` to ensure that the word with the most similar vector from a supplied vector would not be spuriously rejected. Author: William Benton Closes #15105 from willb/fix/findSynonyms. --- .../apache/spark/ml/feature/Word2Vec.scala | 20 +++++----- .../api/python/Word2VecModelWrapper.scala | 22 +++++++++-- .../apache/spark/mllib/feature/Word2Vec.scala | 37 ++++++++++++++----- .../spark/mllib/feature/Word2VecSuite.scala | 16 ++++++++ python/pyspark/mllib/feature.py | 12 ++++-- 5 files changed, 83 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index c2b434c3d5cb1..14c05123c62ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,24 +221,26 @@ class Word2VecModel private[ml] ( } /** - * Find "num" number of words closest in similarity to the given word. - * Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word. + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. Returns a dataframe with the words and the + * cosine similarities between the synonyms and the given word. 
*/ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { - findSynonyms(wordVectors.transform(word), num) + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words closest to similarity to the given vector representation - * of the word. Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word vector. + * Find "num" number of words whose vector representation most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. Returns a dataframe with the words and the cosine + * similarities between the synonyms and the given word vector. */ @Since("2.0.0") - def findSynonyms(word: Vector, num: Int): DataFrame = { + def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 4b4ed2291d139..5cbfbff3e4a62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) { rdd.rdd.map(model.transform) } + /** + * Finds synonyms of a word; do not include the word itself in results. + * @param word a word + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) + prepareResult(model.findSynonyms(word, num)) } + /** + * Finds words similar to the the vector representation of a word without + * filtering results. + * @param vector a vector + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) + prepareResult(model.findSynonyms(vector, num)) + } + + private def prepareResult(result: Array[(String, Double)]) = { val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + def getVectors: JMap[String, JList[Float]] = { model.getVectors.map { case (k, v) => (k, v.toList.asJava) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 908198740b501..42ca9665e5843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -518,7 +518,7 @@ class Word2VecModel private[spark] ( } /** - * Find synonyms of a word + * Find synonyms of a word; do not include the word itself in results. 
* @param word a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) @@ -526,17 +526,34 @@ class Word2VecModel private[spark] ( @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector, num) + findSynonyms(vector, num, Some(word)) } /** - * Find synonyms of the vector representation of a word + * Find synonyms of the vector representation of a word, possibly + * including any words in the model vocabulary whose vector respresentation + * is the supplied vector. * @param vector vector representation of a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + findSynonyms(vector, num, None) + } + + /** + * Find synonyms of the vector representation of a word, rejecting + * words identical to the value of wordOpt, if one is supplied. + * @param vector vector representation of a word + * @param num number of synonyms to find + * @param wordOpt optionally, a word to reject from the results list + * @return array of (word, cosineSimilarity) + */ + private def findSynonyms( + vector: Vector, + num: Int, + wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) @@ -563,12 +580,14 @@ class Word2VecModel private[spark] ( ind += 1 } - wordList.zip(cosVec) - .toSeq - .sortBy(-_._2) - .take(num + 1) - .tail - .toArray + val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + + val filtered = wordOpt match { + case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case None => scored + } + + filtered.take(num).toArray } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 22de4c4ac40e6..f4fa216b8eba0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms(1)._1 == "japan") } + test("findSynonyms doesn't reject similar word vectors when called with a vector") { + val num = 2 + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num) + assert(syms.length == num) + assert(syms(0)._1 == "china") + assert(syms(1)._1 == "taiwan") + } + test("model load / save") { val word2VecMap = Map( diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index b32d0c70ec6a7..5d99644fca254 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -544,8 +544,7 @@ def load(cls, sc, path): @ignore_unicode_prefix class Word2Vec(object): - """ - Word2Vec creates vector representation of words in a text corpus. + """Word2Vec creates vector representation of words in a text corpus. 
The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary. The vector representation can be used as features in @@ -567,13 +566,19 @@ class Word2Vec(object): >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) + Querying for synonyms of a word will not return that word: + >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] [u'b', u'c'] + + But querying for synonyms of a vector may return the word whose + representation is that vector: + >>> vec = model.transform("a") >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] - [u'b', u'c'] + [u'a', u'b'] >>> import os, tempfile >>> path = tempfile.mkdtemp() @@ -591,6 +596,7 @@ class Word2Vec(object): ... pass .. versionadded:: 1.2.0 + """ def __init__(self): """ From 9dbd4b864efacd09a8353d00c998be87f9eeacb2 Mon Sep 17 00:00:00 2001 From: David Navas Date: Sat, 17 Sep 2016 16:22:23 +0100 Subject: [PATCH 042/306] [SPARK-17529][CORE] Implement BitSet.clearUntil and use it during merge joins ## What changes were proposed in this pull request? Add a clearUntil() method on BitSet (adapted from the pre-existing setUntil() method). Use this method to clear the subset of the BitSet which needs to be used during merge joins. ## How was this patch tested? dev/run-tests, as well as performance tests on skewed data as described in jira. I expect there to be a small local performance hit using BitSet.clearUntil rather than BitSet.clear for normally shaped (unskewed) joins (additional read on the last long). This is expected to be de-minimis and was not specifically tested. Author: David Navas Closes #15084 from davidnavas/bitSet. --- .../apache/spark/util/collection/BitSet.scala | 28 ++++++++++------ .../spark/util/collection/BitSetSuite.scala | 32 +++++++++++++++++++ .../execution/joins/SortMergeJoinExec.scala | 4 +-- 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 7ab67fc3a2de9..e63e0e3e1f68f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.collection +import java.util.Arrays + /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. @@ -35,21 +37,14 @@ class BitSet(numBits: Int) extends Serializable { /** * Clear all set bits. 
*/ - def clear(): Unit = { - var i = 0 - while (i < numWords) { - words(i) = 0L - i += 1 - } - } + def clear(): Unit = Arrays.fill(words, 0) /** * Set all the bits up to a given index */ - def setUntil(bitIndex: Int) { + def setUntil(bitIndex: Int): Unit = { val wordIndex = bitIndex >> 6 // divide by 64 - var i = 0 - while(i < wordIndex) { words(i) = -1; i += 1 } + Arrays.fill(words, 0, wordIndex, -1) if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) @@ -57,6 +52,19 @@ class BitSet(numBits: Int) extends Serializable { } } + /** + * Clear all the bits up to a given index + */ + def clearUntil(bitIndex: Int): Unit = { + val wordIndex = bitIndex >> 6 // divide by 64 + Arrays.fill(words, 0, wordIndex, 0) + if(wordIndex < words.length) { + // Clear the remaining bits + val mask = -1L << (bitIndex & 0x3f) + words(wordIndex) &= mask + } + } + /** * Compute the bit-wise AND of the two sets returning the * result. diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index 69dbfa9cd7141..0169c9926e68f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -152,4 +152,36 @@ class BitSetSuite extends SparkFunSuite { assert(bitsetDiff.nextSetBit(85) === 85) assert(bitsetDiff.nextSetBit(86) === -1) } + + test( "[gs]etUntil" ) { + val bitSet = new BitSet(100) + + bitSet.setUntil(bitSet.capacity) + + (0 until bitSet.capacity).foreach { i => + assert(bitSet.get(i)) + } + + bitSet.clearUntil(bitSet.capacity) + + (0 until bitSet.capacity).foreach { i => + assert(!bitSet.get(i)) + } + + val setUntil = bitSet.capacity / 2 + bitSet.setUntil(setUntil) + + val clearUntil = setUntil / 2 + bitSet.clearUntil(clearUntil) + + (0 until clearUntil).foreach { i => + assert(!bitSet.get(i)) + } + (clearUntil until setUntil).foreach { i => + assert(bitSet.get(i)) + } + (setUntil until bitSet.capacity).foreach { i => + assert(!bitSet.get(i)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b46af2a99a1e0..81b3e1d224ab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -954,12 +954,12 @@ private class SortMergeFullOuterJoinScanner( } if (leftMatches.size <= leftMatched.capacity) { - leftMatched.clear() + leftMatched.clearUntil(leftMatches.size) } else { leftMatched = new BitSet(leftMatches.size) } if (rightMatches.size <= rightMatched.capacity) { - rightMatched.clear() + rightMatched.clearUntil(rightMatches.size) } else { rightMatched = new BitSet(rightMatches.size) } From bbe0b1d623741decce98827130cc67eb1fff1240 Mon Sep 17 00:00:00 2001 From: sandy Date: Sat, 17 Sep 2016 16:25:03 +0100 Subject: [PATCH 043/306] [SPARK-17575][DOCS] Remove extra table tags in configuration document ## What changes were proposed in this pull request? Remove extra table tags in configurations document. ## How was this patch tested? Run all test cases and generate document. 
Before, with the extra tags, it looks like below: ![config-wrong1](https://cloud.githubusercontent.com/assets/8075390/18608239/c602bb60-7d01-11e6-875e-f38558997dd3.png) ![config-wrong2](https://cloud.githubusercontent.com/assets/8075390/18608241/cf3b672c-7d01-11e6-935e-1e73f9e6e578.png) After removing the tags, it looks like below: ![config](https://cloud.githubusercontent.com/assets/8075390/18608245/e156eb8e-7d01-11e6-98aa-3be68d4d1961.png) ![config2](https://cloud.githubusercontent.com/assets/8075390/18608247/e84eecd4-7d01-11e6-9738-a3f7ff8fe834.png) Author: sandy Closes #15130 from phalodi/SPARK-17575. --- docs/configuration.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8aea74505e28b..b50565367a98b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -123,6 +123,7 @@ of the most common options to set are: Number of cores to use for the driver process, only in cluster mode. + spark.driver.maxResultSize 1g @@ -217,7 +218,7 @@ Apart from these, the following properties are also available, and may be useful
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-class-path command line option or in - your default properties file. + your default properties file. @@ -244,7 +245,7 @@ Apart from these, the following properties are also available, and may be useful
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-library-path command line option or in - your default properties file. + your default properties file. From 86c2d393a56bf1e5114bc5a781253c0460efb8af Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 17 Sep 2016 16:52:30 +0100 Subject: [PATCH 044/306] [SPARK-17480][SQL][FOLLOWUP] Fix more instances which calls List.length/size which is O(n) ## What changes were proposed in this pull request? This PR fixes all the instances which was fixed in the previous PR. To make sure, I manually debugged and also checked the Scala source. `length` in [LinearSeqOptimized.scala#L49-L57](https://github.com/scala/scala/blob/2.11.x/src/library/scala/collection/LinearSeqOptimized.scala#L49-L57) is O(n). Also, `size` calls `length` via [SeqLike.scala#L106](https://github.com/scala/scala/blob/2.11.x/src/library/scala/collection/SeqLike.scala#L106). For debugging, I have created these as below: ```scala ArrayBuffer(1, 2, 3) Array(1, 2, 3) List(1, 2, 3) Seq(1, 2, 3) ``` and then called `size` and `length` for each to debug. ## How was this patch tested? I ran the bash as below on Mac ```bash find . -name *.scala -type f -exec grep -il "while (.*\\.length)" {} \; | grep "src/main" find . -name *.scala -type f -exec grep -il "while (.*\\.size)" {} \; | grep "src/main" ``` and then checked each. Author: hyukjinkwon Closes #15093 from HyukjinKwon/SPARK-17480-followup. --- .../sql/catalyst/analysis/Analyzer.scala | 28 ++++++------------- .../expressions/conditionalExpressions.scala | 3 +- .../sql/catalyst/expressions/ordering.scala | 3 +- .../sql/catalyst/util/QuantileSummaries.scala | 10 +++---- .../datasources/jdbc/JdbcUtils.scala | 2 +- .../spark/sql/hive/HiveInspectors.scala | 6 ++-- .../apache/spark/sql/hive/TableReader.scala | 3 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 3 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 6 ++-- 9 files changed, 31 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5210f42c557b6..cc62d5e7c8826 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1663,27 +1663,17 @@ class Analyzer( } }.toSeq - // Third, for every Window Spec, we add a Window operator and set currentChild as the - // child of it. - var currentChild = child - var i = 0 - while (i < groupedWindowExpressions.size) { - val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) - // Set currentChild to the newly created Window operator. - currentChild = - Window( - windowExpressions, - partitionSpec, - orderSpec, - currentChild) - - // Move to next Window Spec. - i += 1 - } + // Third, we aggregate them by adding each Window operator for each Window Spec and then + // setting this to the child of the next Window operator. + val windowOps = + groupedWindowExpressions.foldLeft(child) { + case (last, ((partitionSpec, orderSpec), windowExpressions)) => + Window(windowExpressions, partitionSpec, orderSpec, last) + } - // Finally, we create a Project to output currentChild's output + // Finally, we create a Project to output windowOps's output // newExpressionsWithWindowFunctions. 
- Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) } // end of addWindow // We have to use transformDown at here to make sure the rule of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1dd70bcfcfe87..71d4e9a3c9471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -125,7 +125,8 @@ abstract class CaseWhenBase( override def eval(input: InternalRow): Any = { var i = 0 - while (i < branches.size) { + val size = branches.size + while (i < size) { if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) { return branches(i)._2.eval(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 79d2052c38a27..e24a3de3cfdbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -31,7 +31,8 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 - while (i < ordering.size) { + val size = ordering.size + while (i < size) { val order = ordering(i) val left = order.child.eval(a) val right = order.child.eval(b) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index fd62bd511fac0..27928c493d5ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -91,10 +91,10 @@ class QuantileSummaries( var sampleIdx = 0 // The index of the sample currently being inserted. var opsIdx: Int = 0 - while(opsIdx < sorted.length) { + while (opsIdx < sorted.length) { val currentSample = sorted(opsIdx) // Add all the samples before the next observation. 
- while(sampleIdx < sampled.size && sampled(sampleIdx).value <= currentSample) { + while (sampleIdx < sampled.length && sampled(sampleIdx).value <= currentSample) { newSamples += sampled(sampleIdx) sampleIdx += 1 } @@ -102,7 +102,7 @@ class QuantileSummaries( // If it is the first one to insert, of if it is the last one currentCount += 1 val delta = - if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) { + if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { math.floor(2 * relativeError * currentCount).toInt @@ -114,7 +114,7 @@ class QuantileSummaries( } // Add all the remaining existing samples - while(sampleIdx < sampled.size) { + while (sampleIdx < sampled.length) { newSamples += sampled(sampleIdx) sampleIdx += 1 } @@ -195,7 +195,7 @@ class QuantileSummaries( // Minimum rank at current sample var minRank = 0 var i = 1 - while (i < sampled.size - 1) { + while (i < sampled.length - 1) { val curSample = sampled(i) minRank += curSample.g val maxRank = minRank + curSample.delta diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b09fd511a9074..3db1d1f109fb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -369,7 +369,7 @@ object JdbcUtils extends Logging { val bytes = rs.getBytes(pos + 1) var ans = 0L var j = 0 - while (j < bytes.size) { + while (j < bytes.length) { ans = 256 * ans + (255 & bytes(j)) j = j + 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4e74452f6cd12..e4b963efeaf18 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -703,7 +703,8 @@ private[hive] trait HiveInspectors { // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { // 2. 
set the property for the pojo val tpe = structType(i).dataType x.setStructFieldData( @@ -720,7 +721,8 @@ private[hive] trait HiveInspectors { val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { val tpe = structType(i).dataType result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b4808fdbed9c9..ec7e53efc87f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -427,7 +427,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { iterator.map { value => val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 - while (i < fieldRefs.length) { + val length = fieldRefs.length + while (i < length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 9347aeb8e09a8..962dd5a52ebc0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -153,7 +153,8 @@ private[hive] case class HiveGenericUDF( returnInspector // Make sure initialized. var i = 0 - while (i < children.length) { + val length = children.length + while (i < length) { val idx = i deferredObjects(i).asInstanceOf[DeferredObjectAdapter] .set(() => children(idx).eval(input)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 286197b50e229..03b508e11aa76 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -190,7 +190,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) row: InternalRow): Unit = { val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { oi.setStructFieldData( struct, @@ -289,7 +290,8 @@ private[orc] object OrcRelation extends HiveInspectors { iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 - while (i < fieldRefs.length) { + val length = fieldRefs.length + while (i < length) { val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) From 8faa5217b44e8d52eab7eb2d53d0652abaaf43cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 17 Sep 2016 11:46:15 -0700 Subject: [PATCH 045/306] [SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes() ## What changes were proposed in this pull request? `MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not properly close the serialization stream before attempting to deserialize the already-serialized values, which may cause values buffered in Kryo's internal buffers to not be read. 
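As a minimal, self-contained sketch of the underlying pitfall (using a plain JDK `ObjectOutputStream` rather than Spark's Kryo-backed serialization stream; the object count is arbitrary and chosen only for illustration), bytes held in a serialization stream's internal buffer only reach the underlying sink once the stream is flushed or closed:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object SerializationStreamPitfall {
  def main(args: Array[String]): Unit = {
    val underlying = new ByteArrayOutputStream()
    val serStream = new ObjectOutputStream(underlying) // buffers writes internally
    (1 to 1000).foreach(i => serStream.writeObject(Integer.valueOf(i)))

    // Some serialized bytes may still sit in the stream's internal buffer here,
    // so reading `underlying` at this point can miss the tail of the data.
    val sizeBeforeClose = underlying.size()
    serStream.close() // flushes the remaining buffered bytes to `underlying`
    val sizeAfterClose = underlying.size()

    println(s"before close: $sizeBeforeClose bytes, after close: $sizeAfterClose bytes")
  }
}
```

Until the stream is closed (or at least flushed), a reader of the underlying buffer silently loses whatever is still buffered.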
This is the root cause behind a user-reported "wrong answer" bug in PySpark caching reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Due to Spark 2.0's automatic use of KryoSerializer for "safe" types (such as byte arrays, primitives, etc.) this misuse of serializers manifested itself as silent data corruption rather than a StreamCorrupted error (which you might get from JavaSerializer). The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize written values. In addition, this patch adds several additional assertions / precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`. ## How was this patch tested? The original bug was masked by an invalid assert in the memory store test cases: the old assert compared two results record-by-record with `zip` but didn't first check that the lengths of the two collections were equal, causing missing records to go unnoticed. The updated test case reproduced this bug. In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component. Author: Josh Rosen Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix. --- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/storage/memory/MemoryStore.scala | 89 ++++++-- .../spark/util/ByteBufferOutputStream.scala | 27 ++- .../io/ChunkedByteBufferOutputStream.scala | 12 +- .../spark/storage/MemoryStoreSuite.scala | 34 ++- .../PartiallySerializedBlockSuite.scala | 215 ++++++++++++++++++ .../PartiallyUnrolledIteratorSuite.scala | 2 +- .../ChunkedByteBufferOutputStreamSuite.scala | 8 + 8 files changed, 344 insertions(+), 44 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 35c4dafe9c19c..1ed36bf0692f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -230,6 +230,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) + out.close() out.toByteBuffer } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index ec1b0f7149271..205d469f48144 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.unsafe.Platform -import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} +import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -277,6 +277,7 @@ private[spark] class MemoryStore( "released too much unroll memory") Left(new PartiallyUnrolledIterator( this, + MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, unrolled = arrayValues.toIterator, rest = Iterator.empty)) @@ -285,7 +286,11 @@ private[spark] class MemoryStore( // We ran out of space while unrolling 
the values for this block logUnrollFailureMessage(blockId, vector.estimateSize()) Left(new PartiallyUnrolledIterator( - this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values)) + this, + MemoryMode.ON_HEAP, + unrollMemoryUsedByThisBlock, + unrolled = vector.iterator, + rest = values)) } } @@ -394,7 +399,7 @@ private[spark] class MemoryStore( redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos.toChunkedByteBuffer, + bbos, values, classTag)) } @@ -655,6 +660,7 @@ private[spark] class MemoryStore( * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * * @param memoryStore the memoryStore, used for freeing memory. + * @param memoryMode the memory mode (on- or off-heap). * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param unrolled an iterator for the partially-unrolled values. * @param rest the rest of the original iterator passed to @@ -662,13 +668,14 @@ private[spark] class MemoryStore( */ private[storage] class PartiallyUnrolledIterator[T]( memoryStore: MemoryStore, + memoryMode: MemoryMode, unrollMemory: Long, private[this] var unrolled: Iterator[T], rest: Iterator[T]) extends Iterator[T] { private def releaseUnrollMemory(): Unit = { - memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) // SPARK-17503: Garbage collects the unrolling memory before the life end of // PartiallyUnrolledIterator. unrolled = null @@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T]( /** * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink. */ -private class RedirectableOutputStream extends OutputStream { +private[storage] class RedirectableOutputStream extends OutputStream { private[this] var os: OutputStream = _ def setOutputStream(s: OutputStream): Unit = { os = s } override def write(b: Int): Unit = os.write(b) @@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream { * @param redirectableOutputStream an OutputStream which can be redirected to a different sink. * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param memoryMode whether the unroll memory is on- or off-heap - * @param unrolled a byte buffer containing the partially-serialized values. + * @param bbos byte buffer output stream containing the partially-serialized values. + * [[redirectableOutputStream]] initially points to this output stream. * @param rest the rest of the original iterator passed to * [[MemoryStore.putIteratorAsValues()]]. * @param classTag the [[ClassTag]] for the block. 
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T]( memoryStore: MemoryStore, serializerManager: SerializerManager, blockId: BlockId, - serializationStream: SerializationStream, - redirectableOutputStream: RedirectableOutputStream, - unrollMemory: Long, + private val serializationStream: SerializationStream, + private val redirectableOutputStream: RedirectableOutputStream, + val unrollMemory: Long, memoryMode: MemoryMode, - unrolled: ChunkedByteBuffer, + bbos: ChunkedByteBufferOutputStream, rest: Iterator[T], classTag: ClassTag[T]) { + private lazy val unrolledBuffer: ChunkedByteBuffer = { + bbos.close() + bbos.toChunkedByteBuffer + } + // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. @@ -751,7 +764,23 @@ private[storage] class PartiallySerializedBlock[T]( taskContext.addTaskCompletionListener { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. - unrolled.dispose() + unrolledBuffer.dispose() + } + } + + // Exposed for testing + private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer + + private[this] var discarded = false + private[this] var consumed = false + + private def verifyNotConsumedAndNotDiscarded(): Unit = { + if (consumed) { + throw new IllegalStateException( + "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.") + } + if (discarded) { + throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock") } } @@ -759,15 +788,18 @@ private[storage] class PartiallySerializedBlock[T]( * Called to dispose of this block and free its memory. */ def discard(): Unit = { - try { - // We want to close the output stream in order to free any resources associated with the - // serializer itself (such as Kryo's internal buffers). close() might cause data to be - // written, so redirect the output stream to discard that data. - redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) - serializationStream.close() - } finally { - unrolled.dispose() - memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + if (!discarded) { + try { + // We want to close the output stream in order to free any resources associated with the + // serializer itself (such as Kryo's internal buffers). close() might cause data to be + // written, so redirect the output stream to discard that data. + redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) + serializationStream.close() + } finally { + discarded = true + unrolledBuffer.dispose() + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + } } } @@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T]( * and then serializing the values from the original input iterator. 
*/ def finishWritingToStream(os: OutputStream): Unit = { + verifyNotConsumedAndNotDiscarded() + consumed = true // `unrolled`'s underlying buffers will be freed once this input stream is fully read: - ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os) memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) redirectableOutputStream.setOutputStream(os) while (rest.hasNext) { @@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T]( * `close()` on it to free its resources. */ def valuesIterator: PartiallyUnrolledIterator[T] = { + verifyNotConsumedAndNotDiscarded() + consumed = true + // Close the serialization stream so that the serializer's internal buffers are freed and any + // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream. + serializationStream.close() // `unrolled`'s underlying buffers will be freed once this input stream is fully read: val unrolledIter = serializerManager.dataDeserializeStream( - blockId, unrolled.toInputStream(dispose = true))(classTag) + blockId, unrolledBuffer.toInputStream(dispose = true))(classTag) + // The unroll memory will be freed once `unrolledIter` is fully consumed in + // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any + // extra unroll memory will automatically be freed by a `finally` block in `Task`. new PartiallyUnrolledIterator( memoryStore, + memoryMode, unrollMemory, - unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()), + unrolled = unrolledIter, rest = rest) } } diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 09e7579ae9606..9077b86f9ba1d 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp def getCount(): Int = count + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b, off, len) + } + + override def reset(): Unit = { + require(!closed, "cannot reset a closed ByteBufferOutputStream") + super.reset() + } + + override def close(): Unit = { + if (!closed) { + super.close() + closed = true + } + } + def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) + require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed") + ByteBuffer.wrap(buf, 0, count) } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala index 67b50d1e70437..a625b3289538a 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream( */ private[this] var position = chunkSize private[this] var _size = 0 + private[this] var closed: Boolean = false def size: Long = _size + override def close(): Unit = { + if (!closed) { + 
super.close() + closed = true + } + } + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") allocateNewChunkIfNeeded() chunks(lastChunkIndex).put(b.toByte) position += 1 @@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream( } override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") var written = 0 while (written < len) { allocateNewChunkIfNeeded() @@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream( @inline private def allocateNewChunkIfNeeded(): Unit = { - require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") if (position == chunkSize) { chunks += allocator(chunkSize) lastChunkIndex += 1 @@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream( } def toChunkedByteBuffer: ChunkedByteBuffer = { + require(closed, "cannot call toChunkedByteBuffer() unless close() has been called") require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") toChunkedByteBufferWasCalled = true if (lastChunkIndex == -1) { diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index c11de826677e0..9929ea033a99f 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -79,6 +79,13 @@ class MemoryStoreSuite (memoryStore, blockInfoManager) } + private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = { + assert(actual.length === expected.length, s"wrong number of values returned in $hint") + expected.iterator.zip(actual.iterator).foreach { case (e, a) => + assert(e === a, s"$hint did not return original values!") + } + } + test("reserve/release unroll memory") { val (memoryStore, _) = makeMemoryStore(12000) assert(memoryStore.currentUnrollMemory === 0) @@ -137,9 +144,7 @@ class MemoryStoreSuite var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any) assert(putResult.isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -152,9 +157,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(memoryStore.contains("someBlock2")) assert(!memoryStore.contains("someBlock1")) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -167,9 +170,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(!memoryStore.contains("someBlock2")) assert(putResult.isLeft) - bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => - assert(e === a, "putIterator() did not return original values!") - } + assertSameContents(bigList, putResult.left.get.toSeq, 
"putIterator") // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -316,9 +317,8 @@ class MemoryStoreSuite assert(res.isLeft) assert(memoryStore.currentUnrollMemoryForThisTask > 0) val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization - valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) => - assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!") - } + assertSameContents( + bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()") // The unroll memory was freed once the iterator was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -340,12 +340,10 @@ class MemoryStoreSuite res.left.get.finishWritingToStream(bos) // The unroll memory was freed once the block was fully written. assert(memoryStore.currentUnrollMemoryForThisTask === 0) - val deserializationStream = serializerManager.dataDeserializeStream[Any]( - "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any) - deserializationStream.zip(bigList.iterator).foreach { case (e, a) => - assert(e === a, - "PartiallySerializedBlock.finishWritingtoStream() did not write original values!") - } + val deserializedValues = serializerManager.dataDeserializeStream[Any]( + "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq + assertSameContents( + bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()") } test("multiple unrolls by the same thread") { diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala new file mode 100644 index 0000000000000..ec4f2637fadd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import org.mockito.Mockito +import org.mockito.Mockito.atLeastOnce +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} + +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.memory.MemoryMode +import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager} +import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +class PartiallySerializedBlockSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester { + + private val blockId = new TestBlockId("test") + private val conf = new SparkConf() + private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS) + private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) + + private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream) + private val getRedirectableOutputStream = + PrivateMethod[RedirectableOutputStream]('redirectableOutputStream) + + override protected def beforeEach(): Unit = { + super.beforeEach() + Mockito.reset(memoryStore) + } + + private def partiallyUnroll[T: ClassTag]( + iter: Iterator[T], + numItemsToBuffer: Int): PartiallySerializedBlock[T] = { + + val bbos: ChunkedByteBufferOutputStream = { + val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate)) + Mockito.doAnswer(new Answer[ChunkedByteBuffer] { + override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = { + Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer]) + } + }).when(spy).toChunkedByteBuffer + spy + } + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream) + redirectableOutputStream.setOutputStream(bbos) + val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream)) + + (1 to numItemsToBuffer).foreach { _ => + assert(iter.hasNext) + serializationStream.writeObject[T](iter.next()) + } + + val unrollMemory = bbos.size + new PartiallySerializedBlock[T]( + memoryStore, + serializerManager, + blockId, + serializationStream = serializationStream, + redirectableOutputStream, + unrollMemory = unrollMemory, + memoryMode = MemoryMode.ON_HEAP, + bbos, + rest = iter, + classTag = implicitly[ClassTag[T]]) + } + + test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(null) + } + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("discard() can be called more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + partiallySerializedBlock.discard() + } + + test("cannot call valuesIterator() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + 
intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("cannot call finishWritingToStream() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call finishWritingToStream() after valuesIterator()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call valuesIterator() after finishWritingToStream()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("buffers are deallocated in a TaskCompletionListener") { + try { + TaskContext.setTaskContext(TaskContext.empty()) + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() + Mockito.verifyNoMoreInteractions(memoryStore) + } finally { + TaskContext.unset() + } + } + + private def testUnroll[T: ClassTag]( + testCaseName: String, + items: Seq[T], + numItemsToBuffer: Int): Unit = { + + test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + partiallySerializedBlock.discard() + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + } + + test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val bbos = Mockito.spy(new ByteBufferOutputStream()) + partiallySerializedBlock.finishWritingToStream(bbos) + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verify(bbos).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val deserialized = + serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq + assert(deserialized === items) + } + + test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val valuesIterator = partiallySerializedBlock.valuesIterator + 
Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + + val deserializedItems = valuesIterator.toArray.toSeq + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + assert(deserializedItems === items) + } + } + + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000) + testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0) +} + +private case class MyCaseClass(str: String) diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala index 02c2331dc3946..4253cc8ca4cd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala @@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar { val rest = (unrollSize until restSize + unrollSize).iterator val memoryStore = mock[MemoryStore] - val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest) + val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest) // Firstly iterate over unrolling memory iterator (0 until unrollSize).foreach { value => diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala index 226622075a6cc..86961745673c6 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + o.close() assert(o.toChunkedByteBuffer.size === 0) } test("write a single byte") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) o.write(10) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte)) @@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](9)) o.write(99) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte) @@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](10)) o.write(99) + o.close() val 
arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 2) assert(arrays(1).length === 1) @@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length === 10) @@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length === 10) From 3a3c9ffbd282244407e9437c2b02ae7e062dd183 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 18 Sep 2016 15:37:15 +0800 Subject: [PATCH 046/306] [SPARK-17518][SQL] Block Users to Specify the Internal Data Source Provider Hive ### What changes were proposed in this pull request? In Spark 2.1, we introduced a new internal provider `hive` for telling Hive serde tables from data source tables. This PR blocks users from specifying this provider in `DataFrameWriter` and SQL APIs. ### How was this patch tested? Added a test case Author: gatorsmile Closes #15073 from gatorsmile/formatHive. 
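To make the user-facing effect concrete, here is a rough, simplified sketch of the kind of guard this patch adds (in Spark itself the check lives in `DataFrameWriter.saveAsTable`, `SparkSqlAstBuilder` and `CatalogImpl` and throws `AnalysisException`; the helper name and exception type below are placeholders for illustration only):

```scala
object HiveProviderGuard {
  // Simplified stand-in for the real check; Spark raises AnalysisException instead.
  def validateProvider(source: String, api: String): Unit = {
    // Case-insensitive so that "HiVe", "HIVE", etc. are rejected as well.
    if (source.toLowerCase == "hive") {
      throw new IllegalArgumentException(s"Cannot create hive serde table with $api")
    }
  }

  def main(args: Array[String]): Unit = {
    validateProvider("parquet", "saveAsTable API") // fine: a regular data source
    try {
      validateProvider("HiVe", "createExternalTable API") // now rejected
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
```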
--- .../apache/spark/sql/DataFrameWriter.scala | 3 ++ .../spark/sql/execution/SparkSqlParser.scala | 5 +- .../spark/sql/internal/CatalogImpl.scala | 2 +- .../spark/sql/internal/CatalogSuite.scala | 7 +++ .../spark/sql/hive/HiveStrategies.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 51 +++++++++++++++++++ 6 files changed, 67 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e137f076a0cab..64d3422cb4b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -357,6 +357,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def saveAsTable(tableIdent: TableIdentifier): Unit = { + if (source.toLowerCase == "hive") { + throw new AnalysisException("Cannot create hive serde table with saveAsTable API") + } val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 7ba1a9ff223de..5359cedc80974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.parser._ @@ -316,6 +316,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } val options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText + if (provider.toLowerCase == "hive") { + throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING") + } val schema = Option(ctx.colTypeList()).map(createStructType) val partitionColumnNames = Option(ctx.partitionColumnNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 1f87f0e73a3ba..78ad710a6262e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -258,7 +258,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { source: String, schema: StructType, options: Map[String, String]): DataFrame = { - if (source == "hive") { + if (source.toLowerCase == "hive") { throw new AnalysisException("Cannot create hive serde table with createExternalTable API.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index b221eed7b2426..549fd63f7462e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -322,6 +322,13 @@ class CatalogSuite assert(e2.message == "Cannot create a file-based external data source table without path") } + test("createExternalTable should 
fail if provider is hive") { + val e = intercept[AnalysisException] { + spark.catalog.createExternalTable("tbl", "HiVe", Map.empty[String, String]) + } + assert(e.message.contains("Cannot create hive serde table with createExternalTable API")) + } + // TODO: add tests for the rest of them } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index fb11c849edd94..9d2930948d6ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -61,7 +61,7 @@ private[hive] trait HiveStrategies { // `ErrorIfExists` mode, and `DataFrameWriter.saveAsTable` doesn't support hive serde // tables yet. if (mode == SaveMode.Append || mode == SaveMode.Overwrite) { - throw new AnalysisException("" + + throw new AnalysisException( "CTAS for hive serde tables does not support append or overwrite semantics.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 3466733d7fdcd..0f331bae930f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.client.HiveClient @@ -1151,6 +1152,56 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("save API - format hive") { + withTempDir { dir => + val path = dir.getCanonicalPath + val e = intercept[ClassNotFoundException] { + spark.range(10).write.format("hive").mode(SaveMode.Ignore).save(path) + }.getMessage + assert(e.contains("Failed to find data source: hive")) + } + } + + test("saveAsTable API - format hive") { + val tableName = "tab1" + withTable(tableName) { + val e = intercept[AnalysisException] { + spark.range(10).write.format("hive").mode(SaveMode.Overwrite).saveAsTable(tableName) + }.getMessage + assert(e.contains("Cannot create hive serde table with saveAsTable API")) + } + } + + test("create a data source table using hive") { + val tableName = "tab1" + withTable (tableName) { + val e = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE $tableName + |(col1 int) + |USING hive + """.stripMargin) + }.getMessage + assert(e.contains("Cannot create hive serde table with CREATE TABLE USING")) + } + } + + test("create a temp view using hive") { + val tableName = "tab1" + withTable (tableName) { + val e = intercept[ClassNotFoundException] { + sql( + s""" + |CREATE TEMPORARY VIEW $tableName + |(col1 int) + |USING hive + """.stripMargin) + }.getMessage + assert(e.contains("Failed to find data source: hive")) + } + } + test("saveAsTable - source and target are the same table") { val tableName = "tab1" withTable(tableName) { From 3fe630d314cf50d69868b7707ac8d8d2027080b8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 18 Sep 2016 21:15:35 +0800 Subject: [PATCH 047/306] [SPARK-17541][SQL] fix some DDL 
bugs about table management when same-name temp view exists ## What changes were proposed in this pull request? In `SessionCatalog`, we have several operations(`tableExists`, `dropTable`, `loopupRelation`, etc) that handle both temp views and metastore tables/views. This brings some bugs to DDL commands that want to handle temp view only or metastore table/view only. These bugs are: 1. `CREATE TABLE USING` will fail if a same-name temp view exists 2. `Catalog.dropTempView`will un-cache and drop metastore table if a same-name table exists 3. `saveAsTable` will fail or have unexpected behaviour if a same-name temp view exists. These bug fixes are pulled out from https://github.com/apache/spark/pull/14962 and targets both master and 2.0 branch ## How was this patch tested? new regression tests Author: Wenchen Fan Closes #15099 from cloud-fan/fix-view. --- .../sql/catalyst/catalog/SessionCatalog.scala | 32 +++++--- .../catalog/SessionCatalogSuite.scala | 24 +++--- .../apache/spark/sql/DataFrameWriter.scala | 9 ++- .../command/createDataSourceTables.scala | 22 ++++-- .../spark/sql/internal/CatalogImpl.scala | 8 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +++ .../spark/sql/internal/CatalogSuite.scala | 11 +++ .../sql/test/DataFrameReaderWriterSuite.scala | 76 +++++++++++++++++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 13 ++-- .../sql/sources/HadoopFsRelationTest.scala | 10 +-- 10 files changed, 170 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9fb5db573b70f..574c3d7eeeec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -325,9 +325,9 @@ class SessionCatalog( new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString } - // ------------------------------------------------------------- - // | Methods that interact with temporary and metastore tables | - // ------------------------------------------------------------- + // ---------------------------------------------- + // | Methods that interact with temp views only | + // ---------------------------------------------- /** * Create a temporary table. @@ -343,6 +343,24 @@ class SessionCatalog( tempTables.put(table, tableDefinition) } + /** + * Return a temporary view exactly as it was stored. + */ + def getTempView(name: String): Option[LogicalPlan] = synchronized { + tempTables.get(formatTableName(name)) + } + + /** + * Drop a temporary view. + */ + def dropTempView(name: String): Unit = synchronized { + tempTables.remove(formatTableName(name)) + } + + // ------------------------------------------------------------- + // | Methods that interact with temporary and metastore tables | + // ------------------------------------------------------------- + /** * Rename a table. * @@ -492,14 +510,6 @@ class SessionCatalog( tempTables.clear() } - /** - * Return a temporary table exactly as it was stored. - * For testing only. 
- */ - private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized { - tempTables.get(formatTableName(name)) - } - // ---------------------------------------------------------------------------- // Partitions // ---------------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 012df629bbdef..84b77ad250b5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -201,16 +201,16 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempTable("tbl1") == Option(tempTable1)) - assert(catalog.getTempTable("tbl2") == Option(tempTable2)) - assert(catalog.getTempTable("tbl3").isEmpty) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) // Temporary table already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } // Temporary table already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempTable("tbl1") == Option(tempTable2)) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { @@ -251,11 +251,11 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable = Range(1, 10, 2, 10) sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be dropped first sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.getTempTable("tbl1") == None) + assert(sessionCatalog.getTempView("tbl1") == None) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If temp table does not exist, the table in the current database should be dropped sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) @@ -265,7 +265,7 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) } @@ -303,17 +303,17 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable = Range(1, 10, 2, 10) sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", 
"tbl2")) // If database is not specified, temp table should be renamed first sessionCatalog.renameTable(TableIdentifier("tbl1"), "tbl3") - assert(sessionCatalog.getTempTable("tbl1").isEmpty) - assert(sessionCatalog.getTempTable("tbl3") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl1").isEmpty) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is specified, temp tables are never renamed sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), "tbl4") - assert(sessionCatalog.getTempTable("tbl3") == Option(tempTable)) - assert(sessionCatalog.getTempTable("tbl4").isEmpty) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl4").isEmpty) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 64d3422cb4b54..9e343b5d24986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -361,7 +361,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { throw new AnalysisException("Cannot create hive serde table with saveAsTable API") } - val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) + val sessionState = df.sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = tableIdent.copy(database = Some(db)) + // Pass a table identifier with database part, so that `tableExists` won't check temp views + // unexpectedly. + val tableExists = sessionState.catalog.tableExists(tableIdentWithDB) (tableExists, mode) match { case (true, SaveMode.Ignore) => @@ -387,7 +392,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { bucketSpec = getBucketSpec ) val cmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) - df.sparkSession.sessionState.executePlan(cmd).toRdd + sessionState.executePlan(cmd).toRdd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index b1830e6cf3ea8..d8e20b09c1add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -47,11 +47,15 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo assert(table.provider.isDefined) val sessionState = sparkSession.sessionState - if (sessionState.catalog.tableExists(table.identifier)) { + val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = table.identifier.copy(database = Some(db)) + // Pass a table identifier with database part, so that `tableExists` won't check temp views + // unexpectedly. 
+ if (sessionState.catalog.tableExists(tableIdentWithDB)) { if (ignoreIfExists) { return Seq.empty[Row] } else { - throw new AnalysisException(s"Table ${table.identifier.unquotedString} already exists.") + throw new AnalysisException(s"Table ${tableIdentWithDB.unquotedString} already exists.") } } @@ -128,9 +132,11 @@ case class CreateDataSourceTableAsSelectCommand( assert(table.provider.isDefined) assert(table.schema.isEmpty) - val tableName = table.identifier.unquotedString val provider = table.provider.get val sessionState = sparkSession.sessionState + val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = table.identifier.copy(database = Some(db)) + val tableName = tableIdentWithDB.unquotedString val optionsWithPath = if (table.tableType == CatalogTableType.MANAGED) { table.storage.properties + ("path" -> sessionState.catalog.defaultTablePath(table.identifier)) @@ -140,7 +146,9 @@ case class CreateDataSourceTableAsSelectCommand( var createMetastoreTable = false var existingSchema = Option.empty[StructType] - if (sparkSession.sessionState.catalog.tableExists(table.identifier)) { + // Pass a table identifier with database part, so that `tableExists` won't check temp views + // unexpectedly. + if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -165,7 +173,7 @@ case class CreateDataSourceTableAsSelectCommand( // inserting into (i.e. using the same compression). EliminateSubqueryAliases( - sessionState.catalog.lookupRelation(table.identifier)) match { + sessionState.catalog.lookupRelation(tableIdentWithDB)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => // check if the file formats match l.relation match { @@ -188,7 +196,7 @@ case class CreateDataSourceTableAsSelectCommand( throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - sparkSession.sql(s"DROP TABLE IF EXISTS $tableName") + sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false) // Need to create the table again. createMetastoreTable = true } @@ -230,7 +238,7 @@ case class CreateDataSourceTableAsSelectCommand( } // Refresh the cache of the table in the catalog. 
- sessionState.catalog.refreshTable(table.identifier) + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 78ad710a6262e..3fa62985624f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.types.StructType @@ -284,8 +284,10 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def dropTempView(viewName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(viewName)) - sessionCatalog.dropTable(TableIdentifier(viewName), ignoreIfNotExists = true, purge = false) + sparkSession.sessionState.catalog.getTempView(viewName).foreach { tempView => + sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sessionCatalog.dropTempView(viewName) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3cc3b319f5a57..0ee8c959eeb4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2667,4 +2667,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { }.limit(1).queryExecution.toRdd.count() assert(numRecordsRead.value === 10) } + + test("CREATE TABLE USING should not fail if a same-name temp view exists") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + sql("CREATE TABLE same_name(i int) USING json") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + assert(spark.table("default.same_name").collect().isEmpty) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 549fd63f7462e..3dc67ffafb048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -329,6 +329,17 @@ class CatalogSuite assert(e.message.contains("Cannot create hive serde table with createExternalTable API")) } + test("dropTempView should not un-cache and drop metastore table if a same-name table exists") { + withTable("same_name") { + spark.range(10).write.saveAsTable("same_name") + sql("CACHE TABLE same_name") + assert(spark.catalog.isCached("default.same_name")) + spark.catalog.dropTempView("same_name") + assert(spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + assert(spark.catalog.isCached("default.same_name")) + } + } + // TODO: add tests for the rest of them } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 63b0e4588e4a6..7368dad62859b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.Utils @@ -464,4 +465,79 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be checkAnswer(df, spark.createDataset(expectedResult).toDF()) assert(df.schema === expectedSchema) } + + test("saveAsTable with mode Append should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Append should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode ErrorIfExists should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.ErrorIfExists).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not drop the temp view if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + assert(spark.sessionState.catalog.getTempView("same_name").isDefined) + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode Ignore should create the table if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Ignore).saveAsTable("same_name") + 
assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 0f331bae930f4..7143adf02b0e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -339,7 +339,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv }.getMessage assert( - message.contains("Table ctasJsonTable already exists."), + message.contains("Table default.ctasJsonTable already exists."), "We should complain that ctasJsonTable already exists") // The following statement should be fine if it has IF NOT EXISTS. @@ -515,7 +515,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert( intercept[AnalysisException] { sparkSession.catalog.createExternalTable("createdJsonTable", jsonFilePath.toString) - }.getMessage.contains("Table createdJsonTable already exists."), + }.getMessage.contains("Table default.createdJsonTable already exists."), "We should complain that createdJsonTable already exists") } @@ -907,7 +907,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val e = intercept[AnalysisException] { createDF(10, 19).write.mode(SaveMode.Append).format("orc").saveAsTable("appendOrcToParquet") } - assert(e.getMessage.contains("The file format of the existing table appendOrcToParquet " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendOrcToParquet " + "is `org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat`. " + "It doesn't match the specified format `orc`")) } @@ -918,7 +919,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv createDF(10, 19).write.mode(SaveMode.Append).format("parquet") .saveAsTable("appendParquetToJson") } - assert(e.getMessage.contains("The file format of the existing table appendParquetToJson " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendParquetToJson " + "is `org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " + "It doesn't match the specified format `parquet`")) } @@ -929,7 +931,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv createDF(10, 19).write.mode(SaveMode.Append).format("text") .saveAsTable("appendTextToJson") } - assert(e.getMessage.contains("The file format of the existing table appendTextToJson is " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendTextToJson is " + "`org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. 
" + "It doesn't match the specified format `text`")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 27bb9676e9abf..22f13a494cd4c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -337,9 +337,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - - withTempView("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") } @@ -347,9 +346,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - non-partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - - withTempView("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } From 5d3f4615f8d0a19b97cde5ae603f74aef2cc2fd2 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sun, 18 Sep 2016 16:04:37 +0100 Subject: [PATCH 048/306] [SPARK-17506][SQL] Improve the check double values equality rule. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In `ExpressionEvalHelper`, we check the equality between two double values by comparing whether the expected value is within the range [target - tolerance, target + tolerance], but this can cause a negative false when the compared numerics are very large. Before: ``` val1 = 1.6358558070241E306 val2 = 1.6358558070240974E306 ExpressionEvalHelper.compareResults(val1, val2) false ``` In fact, `val1` and `val2` are but with different precisions, we should tolerant this case by comparing with percentage range, eg.,expected is within range [target - target * tolerance_percentage, target + target * tolerance_percentage]. After: ``` val1 = 1.6358558070241E306 val2 = 1.6358558070240974E306 ExpressionEvalHelper.compareResults(val1, val2) true ``` ## How was this patch tested? Exsiting testcases. Author: jiangxingbo Closes #15059 from jiangxb1987/deq. 
--- .../ArithmeticExpressionSuite.scala | 8 ++--- .../expressions/ExpressionEvalHelper.scala | 29 +++++++++++++++++-- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 687387507e214..5c9824289b3cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -170,11 +170,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) - // TODO: the following lines would fail the test due to inconsistency result of interpret - // and codegen for remainder between giant values, seems like a numeric stability issue - // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - // checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) - // } + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) + } } test("Abs") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 668543a28bd30..f0c149c02b9aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite @@ -289,13 +290,37 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double @unchecked]) => - expected.asInstanceOf[Spread[Double]].isWithin(result) case (result: Double, expected: Double) if result.isNaN && expected.isNaN => true + case (result: Double, expected: Double) => + relativeErrorComparison(result, expected) case (result: Float, expected: Float) if result.isNaN && expected.isNaN => true case _ => result == expected } } + + /** + * Private helper function for comparing two values using relative tolerance. + * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, + * the relative tolerance is meaningless, so the exception will be raised to warn users. + * + * TODO: this duplicates functions in spark.ml.util.TestingUtils.relTol and + * spark.mllib.util.TestingUtils.relTol, they could be moved to common utils sub module for the + * whole spark project which does not depend on other modules. 
See more detail in discussion: + * https://github.com/apache/spark/pull/15059#issuecomment-246940444 + */ + private def relativeErrorComparison(x: Double, y: Double, eps: Double = 1E-8): Boolean = { + val absX = math.abs(x) + val absY = math.abs(y) + val diff = math.abs(x - y) + if (x == y) { + true + } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { + throw new TestFailedException( + s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) + } else { + diff < eps * math.min(absX, absY) + } + } } From 342c0e65bec4b9a715017089ab6ea127f3c46540 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 18 Sep 2016 16:22:31 +0100 Subject: [PATCH 049/306] [SPARK-17546][DEPLOY] start-* scripts should use hostname -f ## What changes were proposed in this pull request? Call `hostname -f` to get fully qualified host name ## How was this patch tested? Jenkins tests of course, but also verified output of command on OS X and Linux Author: Sean Owen Closes #15129 from srowen/SPARK-17546. --- sbin/start-master.sh | 2 +- sbin/start-mesos-dispatcher.sh | 2 +- sbin/start-slaves.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 981cb15bc0006..d970fcc45e2c1 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -48,7 +48,7 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST=`hostname` + SPARK_MASTER_HOST=`hostname -f` fi if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 06a966d1c20b4..ef65fb9539146 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -34,7 +34,7 @@ if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then fi if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then - SPARK_MESOS_DISPATCHER_HOST=`hostname` + SPARK_MESOS_DISPATCHER_HOST=`hostname -f` fi if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 0fa1605489704..7d8871251f81b 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -32,7 +32,7 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST="`hostname`" + SPARK_MASTER_HOST="`hostname -f`" fi # Launch the slaves From 7151011b38a841d9d4bc2e453b9a7cfe42f74f8f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Sep 2016 19:18:49 +0100 Subject: [PATCH 050/306] [SPARK-17586][BUILD] Do not call static member via instance reference ## What changes were proposed in this pull request? This PR fixes a warning message as below: ``` [WARNING] .../UnsafeInMemorySorter.java:284: warning: [static] static method should be qualified by type name, TaskMemoryManager, instead of by an expression [WARNING] currentPageNumber = memoryManager.decodePageNumber(recordPointer) ``` by referencing the static member via class not instance reference. ## How was this patch tested? Existing tests should cover this - Jenkins tests. Author: hyukjinkwon Closes #15141 from HyukjinKwon/SPARK-17586. 
--- .../spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index be382955c0d42..3b1ece4373f49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -281,7 +281,7 @@ public boolean hasNext() { public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); - currentPageNumber = memoryManager.decodePageNumber(recordPointer); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); From 1dbb725dbef30bf7633584ce8efdb573f2d92bca Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Sun, 18 Sep 2016 19:25:58 +0100 Subject: [PATCH 051/306] [SPARK-16462][SPARK-16460][SPARK-15144][SQL] Make CSV cast null values properly ## Problem CSV in Spark 2.0.0: - does not read null values back correctly for certain data types such as `Boolean`, `TimestampType`, `DateType` -- this is a regression comparing to 1.6; - does not read empty values (specified by `options.nullValue`) as `null`s for `StringType` -- this is compatible with 1.6 but leads to problems like SPARK-16903. ## What changes were proposed in this pull request? This patch makes changes to read all empty values back as `null`s. ## How was this patch tested? New test cases. Author: Liwei Lin Closes #14118 from lw-lin/csv-cast-null. --- python/pyspark/sql/readwriter.py | 3 +- python/pyspark/sql/streaming.py | 3 +- .../apache/spark/sql/DataFrameReader.scala | 3 +- .../datasources/csv/CSVInferSchema.scala | 108 ++++++++---------- .../sql/streaming/DataStreamReader.scala | 3 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- .../datasources/csv/CSVTypeCastSuite.scala | 54 +++++---- 7 files changed, 93 insertions(+), 83 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3d79e0ccccee4..a6860efa89b9e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -329,7 +329,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non being read should be skipped. If None is set, it uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses - the default value, empty string. + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 67375f6b5f942..01364517edd0e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -497,7 +497,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non being read should be skipped. If None is set, it uses the default value, ``false``. 
:param nullValue: sets the string representation of a null value. If None is set, it uses - the default value, empty string. + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d29d90ce40453..30f39c70fe0bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -376,7 +376,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * from values being read should be skipped. *
    • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing * whitespaces from values being read should be skipped.
• - `nullValue` (default empty string): sets the string representation of a null value.
• + `nullValue` (default empty string): sets the string representation of a null value. Since
• + 2.0.1, this applies to all supported types including the string type.
    • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
    • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
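As a usage sketch for the reader options documented above (the file path and column name are hypothetical; the null handling shown reflects the behaviour after this patch):

```
// Hypothetical batch read: with this patch, the configured nullValue token ("-")
// is read back as null for all supported types, including strings.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("csv-null-demo").master("local[*]").getOrCreate()

val df = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .option("nullValue", "-")      // token to interpret as null
  .csv("/tmp/people.csv")        // hypothetical input file

df.filter(df("name").isNull).show()  // rows whose "name" column held "-"
```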
    • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 1ca6eff1b8c2e..3ab775c909238 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -232,66 +232,58 @@ private[csv] object CSVTypeCast { nullable: Boolean = true, options: CSVOptions = CSVOptions()): Any = { - castType match { - case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte - case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort - case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt - case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong - case _: FloatType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Float.NaN - } else if (datum == options.negativeInf) { - Float.NegativeInfinity - } else if (datum == options.positiveInf) { - Float.PositiveInfinity - } else { - Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - } - case _: DoubleType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Double.NaN - } else if (datum == options.negativeInf) { - Double.NegativeInfinity - } else if (datum == options.positiveInf) { - Double.PositiveInfinity - } else { - Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - } - case _: BooleanType => datum.toBoolean - case dt: DecimalType => - if (datum == options.nullValue && nullable) { - null - } else { - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) - } - case _: TimestampType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(datum).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(datum).getTime * 1000L + if (nullable && datum == options.nullValue) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => + datum match { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case _ => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) } - case _: DateType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. 
- DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + case _: DoubleType => + datum match { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case _ => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) } - case _: StringType => UTF8String.fromString(datum) - case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, options) - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + case _: BooleanType => datum.toBoolean + case dt: DecimalType => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + case _: DateType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + case _: StringType => UTF8String.fromString(datum) + case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, options) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c25f71af7362a..9d174051bc923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -232,7 +232,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * from values being read should be skipped. *
    • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing * whitespaces from values being read should be skipped.
• - `nullValue` (default empty string): sets the string representation of a null value.
• + `nullValue` (default empty string): sets the string representation of a null value. Since
• + 2.0.1, this applies to all supported types including the string type.
    • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
    • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
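The streaming reader documented above follows the same semantics; a hedged sketch mirroring the batch example earlier (the schema, directory, and null token are made up):

```
// Hypothetical streaming read: streaming CSV sources require an explicit schema,
// and the nullValue option behaves the same way as in the batch reader.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val spark = SparkSession.builder().appName("csv-stream-null-demo").master("local[*]").getOrCreate()

val schema = new StructType()
  .add("name", StringType)
  .add("age", IntegerType)

val streamingDf = spark.readStream
  .schema(schema)
  .option("nullValue", "-")
  .csv("/tmp/csv-input-dir")     // hypothetical directory watched for new files
```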
    • diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 1930862118e9b..29aac9def6924 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -554,7 +554,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkValues = false) val results = cars.collect() - assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(0).toSeq === Array(2012, "Tesla", "S", null, null)) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 3ce643e667ce4..dae92f626c225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -68,16 +68,46 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Nullable types are handled") { - assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null) + assertNull( + CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", BooleanType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", TimestampType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DateType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", StringType, nullable = true, CSVOptions("nullValue", "-"))) } - test("String type should always return the same as the input") { + test("String type should also respect `nullValue`") { + assertNull( + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions())) assert( - CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == UTF8String.fromString("")) + assert( - CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions("nullValue", "null")) == + UTF8String.fromString("")) + assert( + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions("nullValue", "null")) == UTF8String.fromString("")) + + assertNull( + CSVTypeCast.castTo(null, StringType, nullable = true, CSVOptions("nullValue", "null"))) } test("Throws exception for empty string with non null type") { @@ -170,20 +200,4 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(doubleVal2 == Double.PositiveInfinity) } - 
test("Type-specific null values are used for casting") { - assertNull( - CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) - } } From 8f0c35a4d0dd458719627be5f524792bf244d70a Mon Sep 17 00:00:00 2001 From: petermaxlee Date: Sun, 18 Sep 2016 15:22:01 -0700 Subject: [PATCH 052/306] [SPARK-17571][SQL] AssertOnQuery.condition should always return Boolean value ## What changes were proposed in this pull request? AssertOnQuery has two apply constructor: one that accepts a closure that returns boolean, and another that accepts a closure that returns Unit. This is actually very confusing because developers could mistakenly think that AssertOnQuery always require a boolean return type and verifies the return result, when indeed the value of the last statement is ignored in one of the constructors. This pull request makes the two constructor consistent and always require boolean value. It will overall make the test suites more robust against developer errors. As an evidence for the confusing behavior, this change also identified a bug with an existing test case due to file system time granularity. This pull request fixes that test case as well. ## How was this patch tested? This is a test only change. Author: petermaxlee Closes #15127 from petermaxlee/SPARK-17571. --- .../apache/spark/sql/streaming/FileStreamSourceSuite.scala | 7 +++++-- .../scala/org/apache/spark/sql/streaming/StreamTest.scala | 4 ++-- .../spark/sql/streaming/StreamingQueryListenerSuite.scala | 3 +++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 886f7be59db93..a02a36c00499c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -354,7 +354,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("a", "b"), // SLeeps longer than 5ms (maxFileAge) - AssertOnQuery { _ => Thread.sleep(10); true }, + // Unfortunately since a lot of file system does not have modification time granularity + // finer grained than 1 sec, we need to use 1 sec here. 
+ AssertOnQuery { _ => Thread.sleep(1000); true }, AddTextFileData("c\nd", src, tmp), CheckAnswer("a", "b", "c", "d"), @@ -363,7 +365,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => e.source.asInstanceOf[FileStreamSource] }.head - source.seenFiles.size == 1 + assert(source.seenFiles.size == 1) + true } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af2b58116b2aa..6c5b170d9c7c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -188,8 +188,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { new AssertOnQuery(condition, message) } - def apply(message: String)(condition: StreamExecution => Unit): AssertOnQuery = { - new AssertOnQuery(s => { condition(s); true }, message) + def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = { + new AssertOnQuery(condition, message) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 77602e8167fa3..831543a47420a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -66,6 +66,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // No progress events or termination events assert(listener.progressStatuses.isEmpty) assert(listener.terminationStatus === null) + true }, AddDataMemory(input, Seq(1, 2, 3)), CheckAnswer(1, 2, 3), @@ -84,6 +85,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // No termination events assert(listener.terminationStatus === null) } + true }, StopStream, AssertOnQuery("Incorrect query status in onQueryTerminated") { query => @@ -97,6 +99,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(listener.terminationException === None) } listener.checkAsyncErrors() + true } ) } From d720a4019460b6c284d0473249303c349df60a1f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Sep 2016 09:38:25 +0100 Subject: [PATCH 053/306] [SPARK-17297][DOCS] Clarify window/slide duration as absolute time, not relative to a calendar ## What changes were proposed in this pull request? Clarify that slide and window duration are absolute, and not relative to a calendar. ## How was this patch tested? Doc build (no functional change) Author: Sean Owen Closes #15142 from srowen/SPARK-17297. --- R/pkg/R/functions.R | 8 ++++++-- .../scala/org/apache/spark/sql/functions.scala | 15 +++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ceedbe76711b1..4d94b4cd05d44 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2713,11 +2713,15 @@ setMethod("from_unixtime", signature(x = "Column"), #' @param x a time Column. Must be of TimestampType. #' @param windowDuration a string specifying the width of the window, e.g. '1 second', #' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', -#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. 
+#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that +#' the duration is a fixed length of time, and does not vary over time +#' according to a calendar. For example, '1 day' always means 86,400,000 +#' milliseconds, not a calendar day. #' @param slideDuration a string specifying the sliding interval of the window. Same format as #' \code{windowDuration}. A new window will be generated every #' \code{slideDuration}. Must be less than or equal to -#' the \code{windowDuration}. +#' the \code{windowDuration}. This duration is likewise absolute, and does not +#' vary according to a calendar. #' @param startTime the offset with respect to 1970-01-01 00:00:00 UTC with which to start #' window intervals. For example, in order to have hourly tumbling windows #' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 18e736ab69861..960c87f60e624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2606,12 +2606,15 @@ object functions { * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration - * identifiers. + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2660,11 +2663,15 @@ object functions { * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check - * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. 
* * @group datetime_funcs * @since 2.0.0 From cdea1d1343d02f0077e1f3c92ca46d04a3d30414 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Mon, 19 Sep 2016 09:56:16 -0700 Subject: [PATCH 054/306] [SPARK-17473][SQL] fixing docker integration tests error due to different versions of jars. ## What changes were proposed in this pull request? Docker tests are using older version of jersey jars (1.19), which was used in older releases of spark. In 2.0 releases Spark was upgraded to use 2.x verison of Jersey. After upgrade to new versions, docker tests are failing with AbstractMethodError. Now that spark is upgraded to 2.x jersey version, using of shaded docker jars may not be required any more. Removed the exclusions/overrides of jersey related classes from pom file, and changed the docker-client to use regular jar instead of shaded one. ## How was this patch tested? Tested using existing docker-integration-tests Author: sureshthalamati Closes #15114 from sureshthalamati/docker_testfix-spark-17473. --- external/docker-integration-tests/pom.xml | 68 ----------------------- pom.xml | 1 - 2 files changed, 69 deletions(-) diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 7417199e7693d..57d553b75b872 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -49,38 +49,7 @@ com.spotify docker-client - shaded test - - - - com.fasterxml.jackson.jaxrs - jackson-jaxrs-json-provider - - - com.fasterxml.jackson.datatype - jackson-datatype-guava - - - com.fasterxml.jackson.core - jackson-databind - - - org.glassfish.jersey.core - jersey-client - - - org.glassfish.jersey.connectors - jersey-apache-connector - - - org.glassfish.jersey.media - jersey-media-json-jackson - - org.apache.httpcomponents @@ -152,43 +121,6 @@ test - - - com.sun.jersey - jersey-server - 1.19 - test - - - com.sun.jersey - jersey-core - 1.19 - test - - - com.sun.jersey - jersey-servlet - 1.19 - test - - - com.sun.jersey - jersey-json - 1.19 - test - - - stax - stax-api - - - - - + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-shade-plugin From b4a4421b610e776e5280fd5e7453f937f806cbd1 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 21 Sep 2016 18:56:16 +0000 Subject: [PATCH 082/306] [SPARK-11918][ML] Better error from WLS for cases like singular input ## What changes were proposed in this pull request? Update error handling for Cholesky decomposition to provide a little more info when input is singular. ## How was this patch tested? New test case; jenkins tests. Author: Sean Owen Closes #15177 from srowen/SPARK-11918. 
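To make the failure mode concrete, here is a small hand-rolled sketch of why collinear feature columns defeat the Cholesky solve (the data is made up; the patch's own regression test lives in `WeightedLeastSquaresSuite`, shown in the diff below):

```
// Two collinear columns (x2 = 2 * x1) make the normal-equations matrix X^T X
// singular, so it is not positive definite and LAPACK's dppsv cannot factor it.
val rows = Seq(Array(1.0, 2.0), Array(2.0, 4.0), Array(3.0, 6.0))

// Build the 2x2 Gram matrix G = X^T X by hand.
val g = Array.ofDim[Double](2, 2)
for (r <- rows; i <- 0 until 2; j <- 0 until 2) {
  g(i)(j) += r(i) * r(j)
}

val det = g(0)(0) * g(1)(1) - g(0)(1) * g(1)(0)
println(s"det(X^T X) = $det")  // 0.0, i.e. singular; the patch now reports this clearly
```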
--- .../mllib/linalg/CholeskyDecomposition.scala | 19 ++++++++++++++---- .../ml/optim/WeightedLeastSquaresSuite.scala | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index e4494792bb390..08f8f19c1e77d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -36,8 +36,7 @@ private[spark] object CholeskyDecomposition { val k = bx.length val info = new intW(0) lapack.dppsv("U", k, 1, A, bx, k, info) - val code = info.`val` - assert(code == 0, s"lapack.dppsv returned $code.") + checkReturnValue(info, "dppsv") bx } @@ -52,8 +51,20 @@ private[spark] object CholeskyDecomposition { def inverse(UAi: Array[Double], k: Int): Array[Double] = { val info = new intW(0) lapack.dpptri("U", k, UAi, info) - val code = info.`val` - assert(code == 0, s"lapack.dpptri returned $code.") + checkReturnValue(info, "dpptri") UAi } + + private def checkReturnValue(info: intW, method: String): Unit = { + info.`val` match { + case code if code < 0 => + throw new IllegalStateException(s"LAPACK.$method returned $code; arg ${-code} is illegal") + case code if code > 0 => + throw new IllegalArgumentException( + s"LAPACK.$method returned $code because A is not positive definite. Is A derived from " + + "a singular matrix (e.g. collinear column values)?") + case _ => // do nothing + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index c8de796b2de87..2cb1af0dee0bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -60,6 +60,26 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext ), 2) } + test("two collinear features result in error with no regularization") { + val singularInstances = sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), + Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)), + Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)), + Instance(4.0, 1.0, Vectors.dense(4.0, 8.0)) + ), 2) + + intercept[IllegalArgumentException] { + new WeightedLeastSquares( + false, regParam = 0.0, standardizeFeatures = false, + standardizeLabel = false).fit(singularInstances) + } + + // Should not throw an exception + new WeightedLeastSquares( + false, regParam = 1.0, standardizeFeatures = false, + standardizeLabel = false).fit(singularInstances) + } + test("WLS against lm") { /* R code: From 2cd1bfa4f0c6625b0ab1dbeba2b9586b9a6a9f42 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 21 Sep 2016 14:42:41 -0700 Subject: [PATCH 083/306] [SPARK-4563][CORE] Allow driver to advertise a different network address. The goal of this feature is to allow the Spark driver to run in an isolated environment, such as a docker container, and be able to use the host's port forwarding mechanism to be able to accept connections from the outside world. The change is restricted to the driver: there is no support for achieving the same thing on executors (or the YARN AM for that matter). Those still need full access to the outside world so that, for example, connections can be made to an executor's block manager. 
The core of the change is simple: add a new configuration that specifies the address the driver should bind to, which can be different from the address it advertises to executors (spark.driver.host). Everything else is plumbing the new configuration where it's needed. To use the feature, the host starting the container needs to set up the driver's port range to fall into a range that is being forwarded; this required a driver-specific configuration for the block manager port, which falls back to the existing spark.blockManager.port when not set. This way, users can modify the driver settings without affecting the executors; it would theoretically be nice to also have different retry counts for driver and executors, but given that docker (at least) allows forwarding port ranges, we can probably live without that for now. Because of the nature of the feature it is hard to add unit tests; I just added a simple one to make sure the configuration works. This was tested with a docker image running spark-shell with the following command: docker blah blah blah \ -p 38000-38100:38000-38100 \ [image] \ spark-shell \ --num-executors 3 \ --conf spark.shuffle.service.enabled=false \ --conf spark.dynamicAllocation.enabled=false \ --conf spark.driver.host=[host's address] \ --conf spark.driver.port=38000 \ --conf spark.driver.blockManager.port=38020 \ --conf spark.ui.port=38040 Running on YARN; verified the driver works, executors start up and listen on ephemeral ports (instead of using the driver's config), and that caching and shuffling (without the shuffle service) works. Clicked through the UI to make sure all pages (including executor thread dumps) worked. Also tested apps without docker, and ran unit tests. Author: Marcelo Vanzin Closes #15120 from vanzin/SPARK-4563.
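A hedged configuration sketch of how the new setting composes with the existing ones (all addresses and ports below are made up; only the `spark.driver.bindAddress` and `spark.driver.blockManager.port` keys are introduced by this patch):

```
import org.apache.spark.SparkConf

// Hypothetical driver-in-a-container setup: listen on the container's interfaces,
// but advertise the host's forwarded address and ports to executors.
val conf = new SparkConf()
  .setAppName("driver-behind-port-forwarding")
  .set("spark.driver.bindAddress", "0.0.0.0")      // where the driver's sockets bind (inside the container)
  .set("spark.driver.host", "192.0.2.10")          // address executors use to reach the driver (made up)
  .set("spark.driver.port", "38000")               // must fall inside the forwarded port range
  .set("spark.driver.blockManager.port", "38020")  // driver-only override of spark.blockManager.port
```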
--- .../scala/org/apache/spark/SparkConf.scala | 2 ++ .../scala/org/apache/spark/SparkContext.scala | 5 ++-- .../scala/org/apache/spark/SparkEnv.scala | 27 ++++++++++++++----- .../internal/config/ConfigProvider.scala | 2 +- .../spark/internal/config/package.scala | 20 ++++++++++++++ .../netty/NettyBlockTransferService.scala | 7 ++--- .../scala/org/apache/spark/rpc/RpcEnv.scala | 17 ++++++++++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 9 ++++--- .../scala/org/apache/spark/ui/WebUI.scala | 5 ++-- .../scala/org/apache/spark/util/Utils.scala | 6 ++--- .../NettyBlockTransferSecuritySuite.scala | 6 +++-- .../NettyBlockTransferServiceSuite.scala | 5 ++-- .../spark/rpc/netty/NettyRpcEnvSuite.scala | 16 +++++++++-- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 4 +-- docs/configuration.md | 23 +++++++++++++++- .../cluster/mesos/MesosSchedulerUtils.scala | 3 ++- ...osCoarseGrainedSchedulerBackendSuite.scala | 5 ++-- .../mesos/MesosSchedulerUtilsSuite.scala | 3 ++- .../spark/streaming/CheckpointSuite.scala | 4 ++- .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- 21 files changed, 133 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e85e5aa237384..51a699f41d15d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -422,6 +422,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria configsWithAlternatives.get(key).toSeq.flatten.exists { alt => contains(alt.key) } } + private[spark] def contains(entry: ConfigEntry[_]): Boolean = contains(entry.key) + /** Copy this object */ override def clone: SparkConf = { val cloned = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 35b6334832393..db84172e1680a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -383,8 +383,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo("Spark configuration:\n" + _conf.toDebugString) } - // Set Spark driver host and port system properties - _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + // Set Spark driver host and port system properties. This explicitly sets the configuration + // instead of relying on the default value of the config constant. 
+ _conf.set(DRIVER_HOST_ADDRESS, _conf.get(DRIVER_HOST_ADDRESS)) _conf.setIfMissing("spark.driver.port", "0") _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index cc8e3fdc97a91..1ffeb129880f9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -29,6 +29,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.netty.NettyBlockTransferService @@ -158,14 +159,17 @@ object SparkEnv extends Logging { listenerBus: LiveListenerBus, numCores: Int, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") + assert(conf.contains(DRIVER_HOST_ADDRESS), + s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") - val hostname = conf.get("spark.driver.host") + val bindAddress = conf.get(DRIVER_BIND_ADDRESS) + val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) val port = conf.get("spark.driver.port").toInt create( conf, SparkContext.DRIVER_IDENTIFIER, - hostname, + bindAddress, + advertiseAddress, port, isDriver = true, isLocal = isLocal, @@ -190,6 +194,7 @@ object SparkEnv extends Logging { conf, executorId, hostname, + hostname, port, isDriver = false, isLocal = isLocal, @@ -205,7 +210,8 @@ object SparkEnv extends Logging { private def create( conf: SparkConf, executorId: String, - hostname: String, + bindAddress: String, + advertiseAddress: String, port: Int, isDriver: Boolean, isLocal: Boolean, @@ -221,8 +227,8 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, hostname, port, conf, securityManager, - clientMode = !isDriver) + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, + securityManager, clientMode = !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. 
// In the non-driver case, the RPC env's address may be null since it may not be listening @@ -309,8 +315,15 @@ object SparkEnv extends Logging { UnifiedMemoryManager(conf, numUsableCores) } + val blockManagerPort = if (isDriver) { + conf.get(DRIVER_BLOCK_MANAGER_PORT) + } else { + conf.get(BLOCK_MANAGER_PORT) + } + val blockTransferService = - new NettyBlockTransferService(conf, securityManager, hostname, numUsableCores) + new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress, + blockManagerPort, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 4b546c847a49f..97f56a64d600f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -66,7 +66,7 @@ private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends Con findEntry(key) match { case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString) case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString) - case e: FallbackConfigEntry[_] => defaultValueString(e.fallback.key) + case e: FallbackConfigEntry[_] => get(e.fallback.key) case _ => None } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 02d7d182a48c2..d536cc5097b2d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -19,6 +19,7 @@ package org.apache.spark.internal import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit +import org.apache.spark.util.Utils package object config { @@ -143,4 +144,23 @@ package object config { .internal() .stringConf .createWithDefaultString("AES/CTR/NoPadding") + + private[spark] val DRIVER_HOST_ADDRESS = ConfigBuilder("spark.driver.host") + .doc("Address of driver endpoints.") + .stringConf + .createWithDefault(Utils.localHostName()) + + private[spark] val DRIVER_BIND_ADDRESS = ConfigBuilder("spark.driver.bindAddress") + .doc("Address where to bind network listen sockets on the driver.") + .fallbackConf(DRIVER_HOST_ADDRESS) + + private[spark] val BLOCK_MANAGER_PORT = ConfigBuilder("spark.blockManager.port") + .doc("Port to use for the block manager when a more specific setting is not provided.") + .intConf + .createWithDefault(0) + + private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port") + .doc("Port to use for the block managed on the driver.") + .fallbackConf(BLOCK_MANAGER_PORT) + } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 33a3219607749..dc70eb82d2b54 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -42,7 +42,9 @@ import org.apache.spark.util.Utils private[spark] class NettyBlockTransferService( conf: SparkConf, securityManager: SecurityManager, + bindAddress: String, override val hostName: String, + _port: Int, numCores: Int) extends BlockTransferService { @@ -75,12 +77,11 
@@ private[spark] class NettyBlockTransferService( /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(hostName, port, bootstraps.asJava) + val server = transportContext.createServer(bindAddress, port, bootstraps.asJava) (server, server.getPort) } - val portToTry = conf.getInt("spark.blockManager.port", 0) - Utils.startServiceOnPort(portToTry, startService, conf, getClass.getName)._1 + Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1 } override def fetchBlocks( diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 56683771335a6..579122868afc8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -40,7 +40,19 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) + create(name, host, host, port, conf, securityManager, clientMode) + } + + def create( + name: String, + bindAddress: String, + advertiseAddress: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager, + clientMode: Boolean): RpcEnv = { + val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager, + clientMode) new NettyRpcEnvFactory().create(config) } } @@ -186,7 +198,8 @@ private[spark] trait RpcEnvFileServer { private[spark] case class RpcEnvConfig( conf: SparkConf, name: String, - host: String, + bindAddress: String, + advertiseAddress: String, port: Int, securityManager: SecurityManager, clientMode: Boolean) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89d2fb9b47971..e51649a1ecce9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -108,14 +108,14 @@ private[netty] class NettyRpcEnv( } } - def startServer(port: Int): Unit = { + def startServer(bindAddress: String, port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) } else { java.util.Collections.emptyList() } - server = transportContext.createServer(host, port, bootstraps) + server = transportContext.createServer(bindAddress, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -441,10 +441,11 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { val javaSerializerInstance = new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] val nettyEnv = - new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) + new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress, + config.securityManager) if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => - nettyEnv.startServer(actualPort) + nettyEnv.startServer(config.bindAddress, actualPort) (nettyEnv, nettyEnv.address.port) } try { diff --git 
a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 38363800ec505..4118fcf46b428 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -28,6 +28,7 @@ import org.json4s.JsonAST.{JNothing, JValue} import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils @@ -50,8 +51,8 @@ private[spark] abstract class WebUI( protected val handlers = ArrayBuffer[ServletContextHandler]() protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostNameForURI() - protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) + protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS)) private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9b4274a27b3be..09896c4e2f502 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2079,9 +2079,9 @@ private[spark] object Utils extends Logging { case e: Exception if isBindCollision(e) => if (offset >= maxRetries) { val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + - s"$maxRetries retries! Consider explicitly setting the appropriate port for the " + - s"service$serviceString (for example spark.ui.port for SparkUI) to an available " + - "port or increasing spark.port.maxRetries." + s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + + s"the appropriate port for the service$serviceString (for example spark.ui.port " + + s"for SparkUI) to an available port or increasing spark.port.maxRetries." 
val exception = new BindException(exceptionMessage) // restore original stack trace exception.setStackTrace(e.getStackTrace) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index ed15e77ff1421..022fe91edade9 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -108,11 +108,13 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) - val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", numCores = 1) + val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", "localhost", 0, + 1) exec0.init(blockManager) val securityManager1 = new SecurityManager(conf1) - val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", numCores = 1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", "localhost", 0, + 1) exec1.init(blockManager) val result = fetchBlock(exec0, exec1, "1", blockId) match { diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index e7df7cb419339..121447a96529b 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Mockito.mock import org.scalatest._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ import org.apache.spark.network.BlockDataManager class NettyBlockTransferServiceSuite @@ -86,10 +87,10 @@ class NettyBlockTransferServiceSuite private def createService(port: Int): NettyBlockTransferService = { val conf = new SparkConf() .set("spark.app.id", s"test-${getClass.getName}") - .set("spark.blockManager.port", port.toString) val securityManager = new SecurityManager(conf) val blockDataManager = mock(classOf[BlockDataManager]) - val service = new NettyBlockTransferService(conf, securityManager, "localhost", numCores = 1) + val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost", + port, 1) service.init(blockDataManager) service } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 2d6543d328618..0409aa3a5dee1 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -27,8 +27,8 @@ class NettyRpcEnvSuite extends RpcEnvSuite { name: String, port: Int, clientMode: Boolean = false): RpcEnv = { - val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf), - clientMode) + val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port, + new SecurityManager(conf), clientMode) new NettyRpcEnvFactory().create(config) } @@ -41,4 +41,16 @@ class NettyRpcEnvSuite extends RpcEnvSuite { assert(e.getCause.getMessage.contains(uri)) } + test("advertise address different from bind address") { + val sparkConf = new SparkConf() + 
val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0, + new SecurityManager(sparkConf), false) + val env = new NettyRpcEnvFactory().create(config) + try { + assert(env.address.hostPort.startsWith("example.com:")) + } finally { + env.shutdown() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index b9e3a364ee221..e1c1787cbd15e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -67,7 +67,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) conf.set("spark.memory.offHeap.size", maxMem.toString) - val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val store = new BlockManager(name, rpcEnv, master, serializerManager, conf, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6d53d2e5f0ca6..1652fcdb964da 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -80,7 +80,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.memory.offHeap.size", maxMem.toString) val serializer = new KryoSerializer(conf) val transfer = transferService - .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)) + .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, @@ -854,7 +854,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. conf.set("spark.testing.memory", "1200") - val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memoryManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, diff --git a/docs/configuration.md b/docs/configuration.md index b50565367a98b..82ce232b336d9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1068,11 +1068,32 @@ Apart from these, the following properties are also available, and may be useful Port for all block managers to listen on. These exist on both the driver and the executors. + + spark.driver.blockManager.port + (value of spark.blockManager.port) + + Driver-specific port for the block manager to listen on, for cases where it cannot use the same + configuration as executors. + + + + spark.driver.bindAddress + (value of spark.driver.host) + +

+    Hostname or IP address where to bind listening sockets. This config overrides the SPARK_LOCAL_IP
+    environment variable (see below).
+
+    It also allows a different address from the local one to be advertised to executors or external systems.
+    This is useful, for example, when running containers with bridged networking. For this to properly work,
+    the different ports used by the driver (RPC, block manager and UI) need to be forwarded from the
+    container's host.

      + + spark.driver.host (local hostname) - Hostname or IP address for the driver to listen on. + Hostname or IP address for the driver. This is used for communicating with the executors and the standalone Master. diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e19d445137207..2963d161d6700 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -32,6 +32,7 @@ import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.TaskState import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils @@ -424,7 +425,7 @@ trait MesosSchedulerUtils extends Logging { } } - val managedPortNames = List("spark.executor.port", "spark.blockManager.port") + val managedPortNames = List("spark.executor.port", BLOCK_MANAGER_PORT.key) /** * The values of the non-zero ports to be used by the executor process. diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index bbc79dd1eda07..c3ab488e2aa69 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor @@ -221,7 +222,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } test("Port offer decline when there is no appropriate range") { - setBackend(Map("spark.blockManager.port" -> "30100")) + setBackend(Map(BLOCK_MANAGER_PORT.key -> "30100")) val offeredPorts = (31100L, 31200L) val (mem, cpu) = (backend.executorMemory(sc), 4) @@ -242,7 +243,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite test("Port offer accepted with user defined port numbers") { val port = 30100 - setBackend(Map("spark.blockManager.port" -> s"$port")) + setBackend(Map(BLOCK_MANAGER_PORT.key -> s"$port")) val offeredPorts = (30000L, 31000L) val (mem, cpu) = (backend.executorMemory(sc), 4) diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index e3d794931a5e3..ec47ab153177e 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -26,6 +26,7 @@ import org.scalatest._ import org.scalatest.mock.MockitoSugar import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.config._ class MesosSchedulerUtilsSuite 
extends SparkFunSuite with Matchers with MockitoSugar { @@ -179,7 +180,7 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("Port reservation is done correctly with user specified ports only") { val conf = new SparkConf() conf.set("spark.executor.port", "3000" ) - conf.set("spark.blockManager.port", "4000") + conf.set(BLOCK_MANAGER_PORT, 4000) val portResource = createTestPortResource((3000, 5000), Some("my_role")) val (resourcesLeft, resourcesToBeUsed) = utils diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index bd8f9950bf1c7..b79cc65d8b5e9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ @@ -406,7 +407,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester // explicitly. ssc = new StreamingContext(null, newCp, null) val restoredConf1 = ssc.conf - assert(restoredConf1.get("spark.driver.host") === "localhost") + val defaultConf = new SparkConf() + assert(restoredConf1.get("spark.driver.host") === defaultConf.get(DRIVER_HOST_ADDRESS)) assert(restoredConf1.get("spark.driver.port") !== "9999") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 7e665454a5400..f2241936000a0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -272,7 +272,7 @@ class ReceivedBlockHandlerSuite conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) - val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) From 9fcf1c51d518847eda7f5ea71337cfa7def3c45c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 21 Sep 2016 17:49:36 -0400 Subject: [PATCH 084/306] [SPARK-17623][CORE] Clarify type of TaskEndReason with a failed task. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In TaskResultGetter, enqueueFailedTask currently deserializes the result as a TaskEndReason. But the type is actually more specific, its a TaskFailedReason. This just leads to more blind casting later on – it would be more clear if the msg was cast to the right type immediately, so method parameter types could be tightened. ## How was this patch tested? Existing unit tests via jenkins. 
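For context, a simplified sketch of the existing hierarchy in `org.apache.spark` (nothing below is added by this patch; it only illustrates why the narrower type is safe): every end reason other than `Success` already extends `TaskFailedReason`.

```scala
// Simplified, illustrative sketch of the existing TaskEndReason hierarchy.
sealed trait TaskEndReason
case object Success extends TaskEndReason

sealed trait TaskFailedReason extends TaskEndReason {
  def toErrorString: String
  def countTowardsTaskFailures: Boolean = true
}

// With `reason: TaskFailedReason`, handleFailedTask can call toErrorString and
// countTowardsTaskFailures directly, without asInstanceOf casts.
```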
Note that the code was already performing a blind-cast to a TaskFailedReason before in any case, just in a different spot, so there shouldn't be any behavior change. Author: Imran Rashid Closes #15181 from squito/SPARK-17623. --- .../spark/executor/CommitDeniedException.scala | 4 ++-- .../scala/org/apache/spark/executor/Executor.scala | 4 ++-- .../apache/spark/scheduler/TaskResultGetter.scala | 4 ++-- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../org/apache/spark/scheduler/TaskSetManager.scala | 12 +++--------- .../apache/spark/shuffle/FetchFailedException.scala | 4 ++-- .../org/apache/spark/util/JsonProtocolSuite.scala | 2 +- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index 7d84889a2def0..326e042419774 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{TaskCommitDenied, TaskEndReason} +import org.apache.spark.{TaskCommitDenied, TaskFailedReason} /** * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. @@ -29,5 +29,5 @@ private[spark] class CommitDeniedException( attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) + def toTaskFailedReason: TaskFailedReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index fbf2b86db1a2e..668ec41153086 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -355,7 +355,7 @@ private[spark] class Executor( } catch { case ffe: FetchFailedException => - val reason = ffe.toTaskEndReason + val reason = ffe.toTaskFailedReason setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -370,7 +370,7 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) case CausedBy(cDE: CommitDeniedException) => - val reason = cDE.toTaskEndReason + val reason = cDE.toTaskFailedReason setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 685ef55c66876..1c3fcbd4612a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -118,14 +118,14 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { - var reason : TaskEndReason = UnknownReason + var reason : TaskFailedReason = UnknownReason try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( + reason = 
serializer.get().deserialize[TaskFailedReason]( serializedData, loader) } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ee5cbfeb47353..52a7186cbf45c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -431,7 +431,7 @@ private[spark] class TaskSchedulerImpl( taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, - reason: TaskEndReason): Unit = synchronized { + reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 2fef447b0a3c1..226bed284a40a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -696,7 +696,7 @@ private[spark] class TaskSetManager( * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. */ - def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) { + def handleFailedTask(tid: Long, state: TaskState, reason: TaskFailedReason) { val info = taskInfos(tid) if (info.failed || info.killed) { return @@ -707,7 +707,7 @@ private[spark] class TaskSetManager( copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + - reason.asInstanceOf[TaskFailedReason].toErrorString + reason.toErrorString val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) @@ -765,10 +765,6 @@ private[spark] class TaskSetManager( case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) None - - case e: TaskEndReason => - logError("Unknown TaskEndReason: " + e) - None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
@@ -784,9 +780,7 @@ private[spark] class TaskSetManager( addPendingTask(index) } - if (!isZombie && state != TaskState.KILLED - && reason.isInstanceOf[TaskFailedReason] - && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) { + if (!isZombie && state != TaskState.KILLED && reason.countTowardsTaskFailures) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index b2d050b218f53..498c12e196ce0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.{FetchFailed, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -45,7 +45,7 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index c89be22a34c9d..00314abf49fd4 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -146,7 +146,7 @@ class JsonProtocolSuite extends SparkFunSuite { val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, - 19, "metadata Fetch failed exception").toTaskEndReason + 19, "metadata Fetch failed exception").toTaskFailedReason val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo]) testTaskEndReason(Success) testTaskEndReason(Resubmitted) From 8c3ee2bc42e6320b9341cebdba51a00162c897ea Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 21 Sep 2016 17:57:21 -0400 Subject: [PATCH 085/306] [SPARK-17512][CORE] Avoid formatting to python path for yarn and mesos cluster mode ## What changes were proposed in this pull request? Yarn and mesos cluster mode support remote python path (HDFS/S3 scheme) by their own mechanism, it is not necessary to check and format the python when running on these modes. This is a potential regression compared to 1.6, so here propose to fix it. ## How was this patch tested? Unit test to verify SparkSubmit arguments, also with local cluster verification. Because of lack of `MiniDFSCluster` support in Spark unit test, there's no integration test added. Author: jerryshao Closes #15137 from jerryshao/SPARK-17512. 
--- .../org/apache/spark/deploy/SparkSubmit.scala | 13 ++++++++++--- .../spark/deploy/SparkSubmitSuite.scala | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7b6d5a394bc35..80611658a1640 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -311,7 +311,7 @@ object SparkSubmit { // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. if (args.isPython && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: $args.primaryResource") + printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") } val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") if (nonLocalPyFiles.nonEmpty) { @@ -322,7 +322,7 @@ object SparkSubmit { // Require all R files to be local if (args.isR && !isYarnCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } } @@ -633,7 +633,14 @@ object SparkSubmit { // explicitly sets `spark.submit.pyFiles` in his/her default properties file. sysProps.get("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) - val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { + PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + } else { + // Ignoring formatting python path in yarn and mesos cluster mode, these two modes + // support dealing with remote python files, they could distribute and add python files + // locally. 
+ resolvedPyFiles + } sysProps("spark.submit.pyFiles") = formattedPyFiles } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 961ece3e0004a..31c8fb26460df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -582,6 +582,25 @@ class SparkSubmitSuite val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + + // Test remote python files + val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val writer4 = new PrintWriter(f4) + val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + writer4.println("spark.submit.pyFiles " + remotePyFiles) + writer4.close() + val clArgs4 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--properties-file", f4.getPath, + "hdfs:///tmp/mister.py" + ) + val appArgs4 = new SparkSubmitArguments(clArgs4) + val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + // Should not format python path for yarn cluster mode + sysProps4("spark.submit.pyFiles") should be( + Utils.resolveURIs(remotePyFiles) + ) } test("user classpath first in driver") { From 7cbe2164499e83b6c009fdbab0fbfffe89a2ecc0 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 21 Sep 2016 17:12:52 -0700 Subject: [PATCH 086/306] [SPARK-17569] Make StructuredStreaming FileStreamSource batch generation faster ## What changes were proposed in this pull request? While getting the batch for a `FileStreamSource` in StructuredStreaming, we know which files we must take specifically. We already have verified that they exist, and have committed them to a metadata log. When creating the FileSourceRelation however for an incremental execution, the code checks the existence of every single file once again! When you have 100,000s of files in a folder, creating the first batch takes 2 hours+ when working with S3! This PR disables that check ## How was this patch tested? Added a unit test to `FileStreamSource`. Author: Burak Yavuz Closes #15122 from brkyvz/SPARK-17569. --- .../execution/datasources/DataSource.scala | 10 +++- .../streaming/FileStreamSource.scala | 3 +- .../streaming/FileStreamSourceSuite.scala | 53 ++++++++++++++++++- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 93154bd2ca69c..413976a7ef244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -316,8 +316,14 @@ case class DataSource( /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] + * + * @param checkFilesExist Whether to confirm that the files exist when generating the + * non-streaming file based datasource. StructuredStreaming jobs already + * list file existence, and when generating incremental jobs, the batch + * is considered as a non-streaming file based data source. Since we know + * that files already exist, we don't need to check them again. 
*/ - def resolveRelation(): BaseRelation = { + def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. @@ -368,7 +374,7 @@ case class DataSource( throw new AnalysisException(s"Path does not exist: $qualified") } // Sufficient to check head of the globPath seq for non-glob scenario - if (!fs.exists(globPath.head)) { + if (checkFilesExist && !fs.exists(globPath.head)) { throw new AnalysisException(s"Path does not exist: ${globPath.head}") } globPath diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0dc08b1467b14..5ebc083a7da92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -133,7 +133,8 @@ class FileStreamSource( userSpecifiedSchema = Some(schema), className = fileFormatClassName, options = sourceOptions.optionMapWithoutPath) - Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( + checkFilesExist = false))) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala index c6db2fd3f908e..e8fa6a59c57ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -17,9 +17,19 @@ package org.apache.spark.sql.execution.streaming +import java.io.File +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType -class FileStreamSourceSuite extends SparkFunSuite { +class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { import FileStreamSource._ @@ -73,4 +83,45 @@ class FileStreamSourceSuite extends SparkFunSuite { assert(map.isNewFile(FileEntry("b", 10))) } + testWithUninterruptibleThread("do not recheck that files exist during getBatch") { + withTempDir { temp => + spark.conf.set( + s"fs.$scheme.impl", + classOf[ExistsThrowsExceptionFileSystem].getName) + // add the metadata entries as a pre-req + val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir + val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) + assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L)))) + + val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), + dir.getAbsolutePath, Map.empty) + // this method should throw an exception if `fs.exists` is called during resolveRelation + newSource.getBatch(None, LongOffset(1)) + } + } +} + +/** Fake FileSystem to test whether the method `fs.exists` is called during + * `DataSource.resolveRelation`. 
+ */ +class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem { + override def getUri: URI = { + URI.create(s"$scheme:///") + } + + override def exists(f: Path): Boolean = { + throw new IllegalArgumentException("Exists shouldn't have been called!") + } + + /** Simply return an empty file for now. */ + override def listStatus(file: Path): Array[FileStatus] = { + val emptyFile = new FileStatus() + emptyFile.setPath(file) + Array(emptyFile) + } +} + +object ExistsThrowsExceptionFileSystem { + val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs" } From c133907c5d9a6e6411b896b5e0cff48b2beff09f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 21 Sep 2016 20:08:28 -0700 Subject: [PATCH 087/306] [SPARK-17577][SPARKR][CORE] SparkR support add files to Spark job and get by executors ## What changes were proposed in this pull request? Scala/Python users can add files to Spark job by submit options ```--files``` or ```SparkContext.addFile()```. Meanwhile, users can get the added file by ```SparkFiles.get(filename)```. We should also support this function for SparkR users, since they also have the requirements for some shared dependency files. For example, SparkR users can download third party R packages to driver firstly, add these files to the Spark job as dependency by this API and then each executor can install these packages by ```install.packages```. ## How was this patch tested? Add unit test. Author: Yanbo Liang Closes #15131 from yanboliang/spark-17577. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/context.R | 48 +++++++++++++++++++ R/pkg/inst/tests/testthat/test_context.R | 13 +++++ .../scala/org/apache/spark/SparkContext.scala | 6 +-- 4 files changed, 67 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a5e9cbdc37f06..267a38c21530b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -336,6 +336,9 @@ export("as.DataFrame", "read.parquet", "read.text", "spark.lapply", + "spark.addFile", + "spark.getSparkFilesRootDirectory", + "spark.getSparkFiles", "sql", "str", "tableToDF", diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 13ade49eabfa6..4793578ad684e 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -225,6 +225,54 @@ setCheckpointDir <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } +#' Add a file or directory to be downloaded with this Spark job on every node. +#' +#' The path passed can be either a local file, a file in HDFS (or other Hadoop-supported +#' filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, +#' use spark.getSparkFiles(fileName) to find its download location. +#' +#' @rdname spark.addFile +#' @param path The path of the file to be added +#' @export +#' @examples +#'\dontrun{ +#' spark.addFile("~/myfile") +#'} +#' @note spark.addFile since 2.1.0 +spark.addFile <- function(path) { + sc <- getSparkContext() + invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)))) +} + +#' Get the root directory that contains files added through spark.addFile. +#' +#' @rdname spark.getSparkFilesRootDirectory +#' @return the root directory that contains files added through spark.addFile +#' @export +#' @examples +#'\dontrun{ +#' spark.getSparkFilesRootDirectory() +#'} +#' @note spark.getSparkFilesRootDirectory since 2.1.0 +spark.getSparkFilesRootDirectory <- function() { + callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") +} + +#' Get the absolute path of a file added through spark.addFile. 
+#' +#' @rdname spark.getSparkFiles +#' @param fileName The name of the file added through spark.addFile +#' @return the absolute path of a file added through spark.addFile. +#' @export +#' @examples +#'\dontrun{ +#' spark.getSparkFiles("myfile") +#'} +#' @note spark.getSparkFiles since 2.1.0 +spark.getSparkFiles <- function(fileName) { + callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) +} + #' Run a function over a list of elements, distributing the computations with Spark #' #' Run a function over a list of elements, distributing the computations with Spark. Applies a diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 1ab7f319df9ff..0495418bb7779 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -166,3 +166,16 @@ test_that("spark.lapply should perform simple transforms", { expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() }) + +test_that("add and get file to be downloaded with Spark job on every node", { + sparkR.sparkContext() + path <- tempfile(pattern = "hello", fileext = ".txt") + filename <- basename(path) + words <- "Hello World!" + writeLines(words, path) + spark.addFile(path) + download_path <- spark.getSparkFiles(filename) + expect_equal(readLines(download_path), words) + unlink(path) + sparkR.session.stop() +}) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index db84172e1680a..1981ad5671093 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1427,7 +1427,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * supported for Hadoop-supported filesystems. */ def addFile(path: String, recursive: Boolean): Unit = { - val uri = new URI(path) + val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString case _ => path @@ -1458,8 +1458,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo(s"Added file $path at $key with timestamp $timestamp") // Fetch the file locally so that closures which are run on the driver can still use the // SparkFiles API to access files. - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, - hadoopConfiguration, timestamp, useCache = false) + Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() } } From 6902edab7e80e96e3f57cf80f26cefb209d4d63c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 21 Sep 2016 20:14:18 -0700 Subject: [PATCH 088/306] [SPARK-17315][FOLLOW-UP][SPARKR][ML] Fix print of Kolmogorov-Smirnov test summary ## What changes were proposed in this pull request? #14881 added Kolmogorov-Smirnov Test wrapper to SparkR. I found that ```print.summary.KSTest``` was implemented inappropriately and result in no effect. 
Running the following code for KSTest: ```Scala data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) df <- createDataFrame(data) testResult <- spark.kstest(df, "test", "norm") summary(testResult) ``` Before this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/18615016/b9a2823a-7d4f-11e6-934b-128beade355e.png) After this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/18615014/aafe2798-7d4f-11e6-8b99-c705bb9fe8f2.png) The new implementation is similar with [```print.summary.GeneralizedLinearRegressionModel```](https://github.com/apache/spark/blob/master/R/pkg/R/mllib.R#L284) of SparkR and [```print.summary.glm```](https://svn.r-project.org/R/trunk/src/library/stats/R/glm.R) of native R. BTW, I removed the comparison of ```print.summary.KSTest``` in unit test, since it's only wrappers of the summary output which has been checked. Another reason is that these comparison will output summary information to the test console, it will make the test output in a mess. ## How was this patch tested? Existing test. Author: Yanbo Liang Closes #15139 from yanboliang/spark-17315. --- R/pkg/R/mllib.R | 16 +++++++++------- R/pkg/inst/tests/testthat/test_mllib.R | 16 ++-------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 234b208166b54..98db367a856ee 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -1398,20 +1398,22 @@ setMethod("summary", signature(object = "KSTest"), distParams <- unlist(callJMethod(jobj, "distParams")) degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom") - list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, - nullHypothesis.name = distName, nullHypothesis.parameters = distParams, - degreesOfFreedom = degreesOfFreedom) + ans <- list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, + nullHypothesis.name = distName, nullHypothesis.parameters = distParams, + degreesOfFreedom = degreesOfFreedom, jobj = jobj) + class(ans) <- "summary.KSTest" + ans }) # Prints the summary of KSTest #' @rdname spark.kstest -#' @param x test result object of KSTest by \code{spark.kstest}. +#' @param x summary object of KSTest returned by \code{summary}. #' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) 
{ - jobj <- x@jobj + jobj <- x$jobj summaryStr <- callJMethod(jobj, "summary") - cat(summaryStr) - invisible(summaryStr) + cat(summaryStr, "\n") + invisible(x) } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 5b1404c621bd1..24c40a88231a7 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -760,13 +760,7 @@ test_that("spark.kstest", { expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - - printStr <- print.summary.KSTest(testResult) - expect_match(printStr, paste0("Kolmogorov-Smirnov test summary:\\n", - "degrees of freedom = 0 \\n", - "statistic = 0.38208[0-9]* \\n", - "pValue = 0.19849[0-9]* \\n", - ".*"), perl = TRUE) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") testResult <- spark.kstest(df, "test", "norm", -0.5) stats <- summary(testResult) @@ -775,13 +769,7 @@ test_that("spark.kstest", { expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - - printStr <- print.summary.KSTest(testResult) - expect_match(printStr, paste0("Kolmogorov-Smirnov test summary:\\n", - "degrees of freedom = 0 \\n", - "statistic = 0.44003[0-9]* \\n", - "pValue = 0.09470[0-9]* \\n", - ".*"), perl = TRUE) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) sparkR.session.stop() From 3497ebe511fee67e66387e9e737c843a2939ce45 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 21 Sep 2016 20:59:46 -0700 Subject: [PATCH 089/306] [SPARK-17627] Mark Streaming Providers Experimental All of structured streaming is experimental in its first release. We missed the annotation on two of the APIs. Author: Michael Armbrust Closes #15188 from marmbrus/experimentalApi. --- .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index a16d7ed0a7c29..6484c782b5d15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -112,8 +112,10 @@ trait SchemaRelationProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ +@Experimental trait StreamSourceProvider { /** Returns the name and schema of the source that can be used to continually read data. */ @@ -132,8 +134,10 @@ trait StreamSourceProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. */ +@Experimental trait StreamSinkProvider { def createSink( sqlContext: SQLContext, From 8bde03bf9a0896ea59ceaa699df7700351a130fb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Sep 2016 21:02:30 -0700 Subject: [PATCH 090/306] [SPARK-17494][SQL] changePrecision() on compact decimal should respect rounding mode ## What changes were proposed in this pull request? Floor()/Ceil() of decimal is implemented using changePrecision() by passing a rounding mode, but the rounding mode is not respected when the decimal is in compact mode (could fit within a Long). This Update the changePrecision() to respect rounding mode, which could be ROUND_FLOOR, ROUND_CEIL, ROUND_HALF_UP, ROUND_HALF_EVEN. 
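A minimal sketch of the effect, along the lines of the new regression test (the three-argument `changePrecision` is internal to `org.apache.spark.sql.types`, so treat this as illustrative rather than a public-API example):

```scala
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.Decimal._

// 0.1 stored in the compact (Long-backed) representation: unscaled 1, precision 8, scale 1.
val d = Decimal(1L, 8, 1)
// Dropping the fractional digit with CEILING should round toward positive infinity and give 1;
// previously the compact path always rounded half-up and produced 0 here.
d.changePrecision(10, 0, ROUND_CEILING)
assert(d.toString == "1")
```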
## How was this patch tested? Added regression tests. Author: Davies Liu Closes #15154 from davies/decimal_round. --- .../org/apache/spark/sql/types/Decimal.scala | 28 ++++++++++++++++--- .../apache/spark/sql/types/DecimalSuite.scala | 15 ++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index cc8175c0a366d..70859052872dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (scale < _scale) { // Easier case: we just need to divide our scale down val diff = _scale - scale - val droppedDigits = longVal % POW_10(diff) - longVal /= POW_10(diff) - if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { - longVal += (if (longVal < 0) -1L else 1L) + val pow10diff = POW_10(diff) + // % and / always round to 0 + val droppedDigits = longVal % pow10diff + longVal /= pow10diff + roundMode match { + case ROUND_FLOOR => + if (droppedDigits < 0) { + longVal += -1L + } + case ROUND_CEILING => + if (droppedDigits > 0) { + longVal += 1L + } + case ROUND_HALF_UP => + if (math.abs(droppedDigits) * 2 >= pow10diff) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case ROUND_HALF_EVEN => + val doubled = math.abs(droppedDigits) * 2 + if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case _ => + sys.error(s"Not supported rounding mode: $roundMode") } } else if (scale > _scale) { // We might be able to multiply longVal by a power of 10 and not overflow, but if not, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index a10c0e39eb687..52d0692524d0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal._ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ @@ -191,4 +192,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } + + test("changePrecision() on compact decimal should respect rounding mode") { + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => + Seq("", "-").foreach { sign => + val bd = BigDecimal(sign + n) + val unscaled = (bd * 10).toLongExact + val d = Decimal(unscaled, 8, 1) + assert(d.changePrecision(10, 0, mode)) + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + } + } + } + } } From b50b34f5611a1f182ba9b6eaf86c666bbd9f9eb0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 22 Sep 2016 12:52:09 +0800 Subject: [PATCH 091/306] [SPARK-17609][SQL] SessionCatalog.tableExists should not check temp view ## What changes were proposed in this pull request? 
After #15054 , there is no place in Spark SQL that need `SessionCatalog.tableExists` to check temp views, so this PR makes `SessionCatalog.tableExists` only check permanent table/view and removes some hacks. This PR also improves the `getTempViewOrPermanentTableMetadata` that is introduced in #15054 , to make the code simpler. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #15160 from cloud-fan/exists. --- .../sql/catalyst/catalog/SessionCatalog.scala | 70 +++++++++---------- .../catalog/SessionCatalogSuite.scala | 30 ++++---- .../apache/spark/sql/DataFrameWriter.scala | 9 +-- .../command/createDataSourceTables.scala | 15 ++-- .../spark/sql/execution/command/ddl.scala | 43 +++++------- .../spark/sql/execution/command/tables.scala | 17 +---- .../spark/sql/internal/CatalogImpl.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 4 +- 9 files changed, 81 insertions(+), 115 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index ef29c75c01898..8c01c7a3f2bd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -245,6 +245,16 @@ class SessionCatalog( externalCatalog.alterTable(newTableDefinition) } + /** + * Return whether a table/view with the specified name exists. If no database is specified, check + * with current database. + */ + def tableExists(name: TableIdentifier): Boolean = synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + externalCatalog.tableExists(db, table) + } + /** * Retrieve the metadata of an existing permanent table/view. If no database is specified, * assume the table/view is in the current database. If the specified table/view is not found @@ -270,24 +280,6 @@ class SessionCatalog( externalCatalog.getTableOption(db, table) } - /** - * Retrieve the metadata of an existing temporary view or permanent table/view. - * If the temporary view does not exist, tries to get the metadata an existing permanent - * table/view. If no database is specified, assume the table/view is in the current database. - * If the specified table/view is not found in the database then a [[NoSuchTableException]] is - * thrown. - */ - def getTempViewOrPermanentTableMetadata(name: String): CatalogTable = synchronized { - val table = formatTableName(name) - getTempView(table).map { plan => - CatalogTable( - identifier = TableIdentifier(table), - tableType = CatalogTableType.VIEW, - storage = CatalogStorageFormat.empty, - schema = plan.output.toStructType) - }.getOrElse(getTableMetadata(TableIdentifier(name))) - } - /** * Load files stored in given path into an existing metastore table. * If no database is specified, assume the table is in the current database. @@ -368,6 +360,30 @@ class SessionCatalog( // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- + /** + * Retrieve the metadata of an existing temporary view or permanent table/view. + * + * If a database is specified in `name`, this will return the metadata of table/view in that + * database. 
+ * If no database is specified, this will first attempt to get the metadata of a temporary view + * with the same name, then, if that does not exist, return the metadata of table/view in the + * current database. + */ + def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { + val table = formatTableName(name.table) + if (name.database.isDefined) { + getTableMetadata(name) + } else { + getTempView(table).map { plan => + CatalogTable( + identifier = TableIdentifier(table), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = plan.output.toStructType) + }.getOrElse(getTableMetadata(name)) + } + } + /** * Rename a table. * @@ -449,24 +465,6 @@ class SessionCatalog( } } - /** - * Return whether a table/view with the specified name exists. - * - * Note: If a database is explicitly specified, then this will return whether the table/view - * exists in that particular database instead. In that case, even if there is a temporary - * table with the same name, we will return false if the specified database does not - * contain the table/view. - */ - def tableExists(name: TableIdentifier): Boolean = synchronized { - val db = formatDatabaseName(name.database.getOrElse(currentDb)) - val table = formatTableName(name.table) - if (isTemporaryTable(name)) { - true - } else { - externalCatalog.tableExists(db, table) - } - } - /** * Return whether a table with the specified name is a temporary table. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 384a7308615e5..915ed8f8b1787 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -425,35 +425,37 @@ class SessionCatalogSuite extends SparkFunSuite { assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) // If database is explicitly specified, do not check temporary tables val tempTable = Range(1, 10, 1, 10) - catalog.createTempView("tbl3", tempTable, overrideIfExists = false) assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) // If database is not explicitly specified, check the current database catalog.setCurrentDatabase("db2") assert(catalog.tableExists(TableIdentifier("tbl1"))) assert(catalog.tableExists(TableIdentifier("tbl2"))) - assert(catalog.tableExists(TableIdentifier("tbl3"))) - } - test("tableExists on temporary views") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - assert(!catalog.tableExists(TableIdentifier("view1"))) - assert(!catalog.tableExists(TableIdentifier("view1", Some("default")))) - catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.tableExists(TableIdentifier("view1"))) - assert(!catalog.tableExists(TableIdentifier("view1", Some("default")))) + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. 
+ assert(!catalog.tableExists(TableIdentifier("tbl3"))) } test("getTempViewOrPermanentTableMetadata on temporary views") { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable = Range(1, 10, 2, 10) intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata("view1") + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) + }.getMessage + + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) }.getMessage catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.getTempViewOrPermanentTableMetadata("view1").identifier == - TableIdentifier("view1"), "the temporary view `view1` should exist") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") + + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage } test("list tables without pattern") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9e343b5d24986..64d3422cb4b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -361,12 +361,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { throw new AnalysisException("Cannot create hive serde table with saveAsTable API") } - val sessionState = df.sparkSession.sessionState - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = tableIdent.copy(database = Some(db)) - // Pass a table identifier with database part, so that `tableExists` won't check temp views - // unexpectedly. - val tableExists = sessionState.catalog.tableExists(tableIdentWithDB) + val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) (tableExists, mode) match { case (true, SaveMode.Ignore) => @@ -392,7 +387,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { bucketSpec = getBucketSpec ) val cmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) - sessionState.executePlan(cmd).toRdd + df.sparkSession.sessionState.executePlan(cmd).toRdd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index d8e20b09c1add..a04a13e698c43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -47,15 +47,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo assert(table.provider.isDefined) val sessionState = sparkSession.sessionState - val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = table.identifier.copy(database = Some(db)) - // Pass a table identifier with database part, so that `tableExists` won't check temp views - // unexpectedly. 
- if (sessionState.catalog.tableExists(tableIdentWithDB)) { + if (sessionState.catalog.tableExists(table.identifier)) { if (ignoreIfExists) { return Seq.empty[Row] } else { - throw new AnalysisException(s"Table ${tableIdentWithDB.unquotedString} already exists.") + throw new AnalysisException(s"Table ${table.identifier.unquotedString} already exists.") } } @@ -146,8 +142,6 @@ case class CreateDataSourceTableAsSelectCommand( var createMetastoreTable = false var existingSchema = Option.empty[StructType] - // Pass a table identifier with database part, so that `tableExists` won't check temp views - // unexpectedly. if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) { // Check if we need to throw an exception or just return. mode match { @@ -172,8 +166,9 @@ case class CreateDataSourceTableAsSelectCommand( // TODO: Check that options from the resolved relation match the relation that we are // inserting into (i.e. using the same compression). - EliminateSubqueryAliases( - sessionState.catalog.lookupRelation(tableIdentWithDB)) match { + // Pass a table identifier with database part, so that `lookupRelation` won't get temp + // views unexpectedly. + EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => // check if the file formats match l.relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index b57b2d280d8f8..01ac89868d100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -183,32 +183,25 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(tableName)) { - if (!ifExists) { - val objectName = if (isView) "View" else "Table" - throw new AnalysisException(s"$objectName to drop '$tableName' does not exist") - } - } else { - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") - case _ => - }) - try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) - } catch { - case NonFatal(e) => log.warn(e.toString, e) - } - catalog.refreshTable(tableName) - catalog.dropTable(tableName, ifExists, purge) + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead") + case _ => + }) + try { + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) } + catalog.refreshTable(tableName) + catalog.dropTable(tableName, ifExists, purge) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 94b46c5d97155..0f61629317c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -59,16 +59,7 @@ case class CreateTableLikeCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(sourceTable)) { - throw new AnalysisException( - s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") - } - - val sourceTableDesc = if (sourceTable.database.isDefined) { - catalog.getTableMetadata(sourceTable) - } else { - catalog.getTempViewOrPermanentTableMetadata(sourceTable.table) - } + val sourceTableDesc = catalog.getTempViewOrPermanentTableMetadata(sourceTable) // Storage format val newStorage = @@ -602,11 +593,7 @@ case class ShowColumnsCommand(tableName: TableIdentifier) extends RunnableComman override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val table = if (tableName.database.isDefined) { - catalog.getTableMetadata(tableName) - } else { - catalog.getTempViewOrPermanentTableMetadata(tableName.table) - } + val table = catalog.getTempViewOrPermanentTableMetadata(tableName) table.schema.map { c => Row(c.name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6fecda232ab88..f252535765899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -151,11 +151,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = { - val tableMetadata = if (tableIdentifier.database.isDefined) { - sessionCatalog.getTableMetadata(tableIdentifier) - } else { - sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdentifier.table) - } + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdentifier) val partitionColumnNames = tableMetadata.partitionColumnNames.toSet val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7143adf02b0e6..8ae6868c9848a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -515,7 +515,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert( intercept[AnalysisException] { sparkSession.catalog.createExternalTable("createdJsonTable", jsonFilePath.toString) - }.getMessage.contains("Table default.createdJsonTable already exists."), + }.getMessage.contains("Table createdJsonTable already exists."), "We should complain that 
createdJsonTable already exists") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 38482f66a38e9..c927e5d802c90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -678,8 +678,8 @@ class HiveDDLSuite .createTempView(sourceViewName) sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName") - val sourceTable = - spark.sessionState.catalog.getTempViewOrPermanentTableMetadata(sourceViewName) + val sourceTable = spark.sessionState.catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier(sourceViewName)) val targetTable = spark.sessionState.catalog.getTableMetadata( TableIdentifier(targetTabName, Some("default"))) From cb324f61150c962aeabf0a779f6a09797b3d5072 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Thu, 22 Sep 2016 13:04:42 +0800 Subject: [PATCH 092/306] [SPARK-17425][SQL] Override sameResult in HiveTableScanExec to make ReuseExchange work in text format table ## What changes were proposed in this pull request? The PR will override the `sameResult` in `HiveTableScanExec` to make `ReuseExchange` work in text format table. ## How was this patch tested? # SQL ```sql SELECT * FROM src t1 JOIN src t2 ON t1.key = t2.key JOIN src t3 ON t1.key = t3.key; ``` # Before ``` == Physical Plan == *BroadcastHashJoin [key#30], [key#34], Inner, BuildRight :- *BroadcastHashJoin [key#30], [key#32], Inner, BuildRight : :- *Filter isnotnull(key#30) : : +- HiveTableScan [key#30, value#31], MetastoreRelation default, src : +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) : +- *Filter isnotnull(key#32) : +- HiveTableScan [key#32, value#33], MetastoreRelation default, src +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) +- *Filter isnotnull(key#34) +- HiveTableScan [key#34, value#35], MetastoreRelation default, src ``` # After ``` == Physical Plan == *BroadcastHashJoin [key#2], [key#6], Inner, BuildRight :- *BroadcastHashJoin [key#2], [key#4], Inner, BuildRight : :- *Filter isnotnull(key#2) : : +- HiveTableScan [key#2, value#3], MetastoreRelation default, src : +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) : +- *Filter isnotnull(key#4) : +- HiveTableScan [key#4, value#5], MetastoreRelation default, src +- ReusedExchange [key#6, value#7], BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) ``` cc: davies cloud-fan Author: Yadong Qi Closes #14988 from watermen/SPARK-17425. 
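A quick way to confirm the reuse shown in the second plan above (a sketch, not taken from the patch; it assumes a Hive-enabled SparkSession named `spark` with the `src` table available):

```scala
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

val df = spark.sql(
  "SELECT * FROM src t1 JOIN src t2 ON t1.key = t2.key JOIN src t3 ON t1.key = t3.key")
val plan = df.queryExecution.executedPlan

// With sameResult overridden, the second broadcast of `src` is planned as a ReusedExchange node.
assert(plan.collect { case r: ReusedExchangeExec => r }.nonEmpty)
```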
--- .../sql/hive/execution/HiveTableScanExec.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index a716a3eab6219..231f204b12b47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -164,4 +164,19 @@ case class HiveTableScanExec( } override def output: Seq[Attribute] = attributes + + override def sameResult(plan: SparkPlan): Boolean = plan match { + case other: HiveTableScanExec => + val thisPredicates = partitionPruningPred.map(cleanExpression) + val otherPredicates = other.partitionPruningPred.map(cleanExpression) + + val result = relation.sameResult(other.relation) && + output.length == other.output.length && + output.zip(other.output) + .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) && + thisPredicates.length == otherPredicates.length && + thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) + result + case _ => false + } } From 3a80f92f8f4b91d0a85724bca7d81c6f5bbb78fd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 22 Sep 2016 13:19:06 +0800 Subject: [PATCH 093/306] [SPARK-17492][SQL] Fix Reading Cataloged Data Sources without Extending SchemaRelationProvider ### What changes were proposed in this pull request? For data sources without extending `SchemaRelationProvider`, we expect users to not specify schemas when they creating tables. If the schema is input from users, an exception is issued. Since Spark 2.1, for any data source, to avoid infer the schema every time, we store the schema in the metastore catalog. Thus, when reading a cataloged data source table, the schema could be read from metastore catalog. In this case, we also got an exception. For example, ```Scala sql( s""" |CREATE TABLE relationProvierWithSchema |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', | To '10' |) """.stripMargin) spark.table(tableName).show() ``` ``` org.apache.spark.sql.sources.SimpleScanSource does not allow user-specified schemas.; ``` This PR is to fix the above issue. When building a data source, we introduce a flag `isSchemaFromUsers` to indicate whether the schema is really input from users. If true, we issue an exception. Otherwise, we will call the `createRelation` of `RelationProvider` to generate the `BaseRelation`, in which it contains the actual schema. ### How was this patch tested? Added a few cases. Author: gatorsmile Closes #15046 from gatorsmile/tempViewCases. 
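For context, the distinction this fix hinges on is which `createRelation` overload a source implements; condensed from `org.apache.spark.sql.sources` (annotations and docs omitted), the two provider traits look roughly like this:

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType

trait RelationProvider {
  // The relation decides its own schema; there is no parameter to pass one in,
  // which is why a catalog-stored schema used to be rejected outright.
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}

trait SchemaRelationProvider {
  // The caller supplies the schema the relation should expose.
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation
}
```

With this change, when a schema is present for a plain `RelationProvider` (for example one read back from the metastore), `createRelation` is called without it and the relation is only rejected if its actual schema disagrees with the stored one.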
--- .../execution/datasources/DataSource.scala | 9 ++- .../spark/sql/sources/InsertSuite.scala | 20 ++++++ .../spark/sql/sources/TableScanSuite.scala | 64 ++++++++++++------- .../sql/test/DataFrameReaderWriterSuite.scala | 33 ++++++++++ 4 files changed, 102 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 413976a7ef244..32067011c3dff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -333,8 +333,13 @@ case class DataSource( dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") - case (_: RelationProvider, Some(_)) => - throw new AnalysisException(s"$className does not allow user-specified schemas.") + case (dataSource: RelationProvider, Some(schema)) => + val baseRelation = + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) + if (baseRelation.schema != schema) { + throw new AnalysisException(s"$className does not allow user-specified schemas.") + } + baseRelation // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 6454d716ec0db..5eb54643f204f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -65,6 +65,26 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } + test("insert into a temp view that does not point to an insertable data source") { + import testImplicits._ + withTempView("t1", "t2") { + sql( + """ + |CREATE TEMPORARY VIEW t1 + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10') + """.stripMargin) + sparkContext.parallelize(1 to 10).toDF("a").createOrReplaceTempView("t2") + + val message = intercept[AnalysisException] { + sql("INSERT INTO TABLE t1 SELECT a FROM t2") + }.getMessage + assert(message.contains("does not allow insertion")) + } + } + test("PreInsert casting and renaming") { sql( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index e8fed039fa993..86bcb4d4b00c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -348,31 +348,51 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("exceptions") { // Make sure we do throw correct exception when users use a relation provider that // only implements the RelationProvider or the SchemaRelationProvider. 
- val schemaNotAllowed = intercept[Exception] { - sql( - """ - |CREATE TEMPORARY VIEW relationProvierWithSchema (i int) - |USING org.apache.spark.sql.sources.SimpleScanSource - |OPTIONS ( - | From '1', - | To '10' - |) - """.stripMargin) + Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => + val schemaNotAllowed = intercept[Exception] { + sql( + s""" + |CREATE $tableType relationProvierWithSchema (i int) + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) + + val schemaNeeded = intercept[Exception] { + sql( + s""" + |CREATE $tableType schemaRelationProvierWithoutSchema + |USING org.apache.spark.sql.sources.AllDataTypesScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } - assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) + } - val schemaNeeded = intercept[Exception] { - sql( - """ - |CREATE TEMPORARY VIEW schemaRelationProvierWithoutSchema - |USING org.apache.spark.sql.sources.AllDataTypesScanSource - |OPTIONS ( - | From '1', - | To '10' - |) - """.stripMargin) + test("read the data source tables that do not extend SchemaRelationProvider") { + Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => + val tableName = "relationProvierWithSchema" + withTable (tableName) { + sql( + s""" + |CREATE $tableType $tableName + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + checkAnswer(spark.table(tableName), spark.range(1, 11).toDF()) + } } - assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } test("SPARK-5196 schema field with comment") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 7368dad62859b..a7fda01098560 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -293,6 +293,39 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("read a data source that does not extend SchemaRelationProvider") { + val dfReader = spark.read + .option("from", "1") + .option("TO", "10") + .format("org.apache.spark.sql.sources.SimpleScanSource") + + // when users do not specify the schema + checkAnswer(dfReader.load(), spark.range(1, 11).toDF()) + + // when users specify the schema + val inputSchema = new StructType().add("s", IntegerType, nullable = false) + val e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() } + assert(e.getMessage.contains( + "org.apache.spark.sql.sources.SimpleScanSource does not allow user-specified schemas")) + } + + test("read a data source that does not extend RelationProvider") { + val dfReader = spark.read + .option("from", "1") + .option("TO", "10") + .option("option_with_underscores", "someval") + .option("option.with.dots", "someval") + .format("org.apache.spark.sql.sources.AllDataTypesScanSource") + + // when users do not specify the schema + val e = intercept[AnalysisException] { dfReader.load() } + assert(e.getMessage.contains("A schema needs to be specified when using")) + + 
// when users specify the schema + val inputSchema = new StructType().add("s", StringType, nullable = false) + assert(dfReader.schema(inputSchema).load().count() == 10) + } + test("text - API and behavior regarding schema") { // Writer spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) From de7df7defc99e04fefd990974151a701f64b75b4 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 22 Sep 2016 14:48:49 +0800 Subject: [PATCH 094/306] [SPARK-17625][SQL] set expectedOutputAttributes when converting SimpleCatalogRelation to LogicalRelation ## What changes were proposed in this pull request? We should set expectedOutputAttributes when converting SimpleCatalogRelation to LogicalRelation, otherwise the outputs of LogicalRelation are different from outputs of SimpleCatalogRelation - they have different exprId's. ## How was this patch tested? add a test case Author: Zhenhua Wang Closes #15182 from wzhfy/expectedAttributes. --- .../execution/datasources/DataSourceStrategy.scala | 10 +++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c8ad5b303491f..63f01c5bb9e3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -197,7 +197,10 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { * source information. */ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private def readDataSourceTable(sparkSession: SparkSession, table: CatalogTable): LogicalPlan = { + private def readDataSourceTable( + sparkSession: SparkSession, + simpleCatalogRelation: SimpleCatalogRelation): LogicalPlan = { + val table = simpleCatalogRelation.catalogTable val dataSource = DataSource( sparkSession, @@ -209,16 +212,17 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] LogicalRelation( dataSource.resolveRelation(), + expectedOutputAttributes = Some(simpleCatalogRelation.output), catalogTable = Some(table)) } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) if DDLUtils.isDatasourceTable(s.metadata) => - i.copy(table = readDataSourceTable(sparkSession, s.metadata)) + i.copy(table = readDataSourceTable(sparkSession, s)) case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => - readDataSourceTable(sparkSession, s.metadata) + readDataSourceTable(sparkSession, s) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c2d256bdd335b..2c60a7dd9209b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,7 +26,8 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union} import org.apache.spark.sql.execution.QueryExecution import 
org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} @@ -1585,4 +1586,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect assert(d.size == d.distinct.size) } + + test("SPARK-17625: data source table in InMemoryCatalog should guarantee output consistency") { + val tableName = "tbl" + withTable(tableName) { + spark.range(10).select('id as 'i, 'id as 'j).write.saveAsTable(tableName) + val relation = spark.sessionState.catalog.lookupRelation(TableIdentifier(tableName)) + val expr = relation.resolve("i") + val qe = spark.sessionState.executePlan(Project(Seq(expr), relation)) + qe.assertAnalyzed() + } + } } From 646f383465c123062cbcce288a127e23984c7c7f Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 22 Sep 2016 10:31:15 +0100 Subject: [PATCH 095/306] [SPARK-17421][DOCS] Documenting the current treatment of MAVEN_OPTS. ## What changes were proposed in this pull request? Modified the documentation to clarify that `build/mvn` and `pom.xml` always add Java 7-specific parameters to `MAVEN_OPTS`, and that developers can safely ignore warnings about `-XX:MaxPermSize` that may result from compiling or running tests with Java 8. ## How was this patch tested? Rebuilt HTML documentation, made sure that building-spark.html displays correctly in a browser. Author: frreiss Closes #15005 from frreiss/fred-17421a. --- docs/building-spark.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 6908fc1ba74d0..75c304a3ccecd 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -16,11 +16,13 @@ Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. ### Setting up Maven's Memory Usage -You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`. We recommend the following settings: +You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`: - export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" + export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m" -If you don't run this, you may see errors like the following: +When compiling with Java 7, you will need to add the additional option "-XX:MaxPermSize=512M" to MAVEN_OPTS. + +If you don't add these parameters to `MAVEN_OPTS`, you may see errors and warnings like the following: [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... [ERROR] PermGen space -> [Help 1] @@ -28,12 +30,18 @@ If you don't run this, you may see errors like the following: [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... [ERROR] Java heap space -> [Help 1] -You can fix this by setting the `MAVEN_OPTS` variable as discussed before. + [INFO] Compiling 233 Scala sources and 41 Java sources to /Users/me/Development/spark/sql/core/target/scala-{site.SCALA_BINARY_VERSION}/classes... + OpenJDK 64-Bit Server VM warning: CodeCache is full. Compiler has been disabled. + OpenJDK 64-Bit Server VM warning: Try increasing the code cache size using -XX:ReservedCodeCacheSize= + +You can fix these problems by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* For Java 8 and above this step is not required. 
-* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automatically add the above options to the `MAVEN_OPTS` environment variable. +* The `test` phase of the Spark build will automatically add these options to `MAVEN_OPTS`, even when not using `build/mvn`. +* You may see warnings like "ignoring option MaxPermSize=1g; support was removed in 8.0" when building or running tests with Java 8 and `build/mvn`. These warnings are harmless. + ### build/mvn From 72d9fba26c19aae73116fd0d00b566967934c6fc Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 22 Sep 2016 04:35:54 -0700 Subject: [PATCH 096/306] [SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression ## What changes were proposed in this pull request? Add treeAggregateDepth parameter for AFTSurvivalRegression to keep consistent with LiR/LoR. ## How was this patch tested? Existing tests. Author: WeichenXu Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression. --- .../ml/regression/AFTSurvivalRegression.scala | 24 +++++++++++++++---- python/pyspark/ml/regression.py | 11 +++++---- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3179f4882fd49..9d5ba999781f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel */ private[regression] trait AFTSurvivalRegressionParams extends Params with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter - with HasTol with HasFitIntercept with Logging { + with HasTol with HasFitIntercept with HasAggregationDepth with Logging { /** * Param for censor column name. @@ -183,6 +183,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Suggested depth for treeAggregate (>= 2). + * If the dimensions of features or the number of partitions are large, + * this param could be adjusted to a larger size. + * Default is 2. + * @group expertSetParam + */ + @Since("2.1.0") + def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) + setDefault(aggregationDepth -> 2) + /** * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. 
@@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { c1.merge(c2) } - instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + instances.treeAggregate( + new MultivariateOnlineSummarizer + )(seqOp, combOp, $(aggregationDepth)) } val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) @@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val bcFeaturesStd = instances.context.broadcast(featuresStd) - val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd) + val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth)) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) /* @@ -591,7 +604,8 @@ private class AFTAggregator( private class AFTCostFun( data: RDD[AFTPoint], fitIntercept: Boolean, - bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] { + bcFeaturesStd: Broadcast[Array[Double]], + aggregationDepth: Int) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { @@ -604,7 +618,7 @@ private class AFTCostFun( }, combOp = (c1, c2) => (c1, c2) match { case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + }, depth = aggregationDepth) bcParameters.destroy(blocking = false) (aftAggregator.loss, aftAggregator.gradient) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 19afc723bb784..55d38033ef72a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1088,7 +1088,8 @@ def trees(self): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): + HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth, + JavaMLWritable, JavaMLReadable): """ .. 
note:: Experimental @@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None) + quantilesCol=None, aggregationDepth=2) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1174,12 +1175,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) From 8a02410a92429bff50d6ce082f873cea9e9fa91e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 22 Sep 2016 23:25:32 +0800 Subject: [PATCH 097/306] [SQL][MINOR] correct the comment of SortBasedAggregationIterator.safeProj ## What changes were proposed in this pull request? This comment went stale long time ago, this PR fixes it according to my understanding. ## How was this patch tested? N/A Author: Wenchen Fan Closes #15095 from cloud-fan/update-comment. --- .../aggregate/SortBasedAggregationIterator.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 3f7f84988594a..c2b1ef0fe3c2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,8 +86,15 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer - // A SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be - // compared to MutableRow (aggregation buffer) directly. + // This safe projection is used to turn the input row into safe row. This is necessary + // because the input row may be produced by unsafe projection in child operator and all the + // produced rows share one byte array. However, when we update the aggregate buffer according to + // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from + // input row via MutableProjection, `CollectList` will keep all values in an array via + // ImperativeAggregate framework. 
These values may get changed unexpectedly if the underlying + // unsafe projection update the shared byte array. By applying a safe projection to the input row, + // we can cut down the connection from input row to the shared byte array, and thus it's safe to + // cache values from input row while updating the aggregation buffer. private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) protected def initialize(): Unit = { From 17b72d31e0c59711eddeb525becb8085930eadcc Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Thu, 22 Sep 2016 10:10:37 -0700 Subject: [PATCH 098/306] [SPARK-17365][CORE] Remove/Kill multiple executors together to reduce RPC call time. ## What changes were proposed in this pull request? We are killing multiple executors together instead of iterating over expensive RPC calls to kill single executor. ## How was this patch tested? Executed sample spark job to observe executors being killed/removed with dynamic allocation enabled. Author: Dhruve Ashar Author: Dhruve Ashar Closes #15152 from dhruve/impr/SPARK-17365. --- .../spark/ExecutorAllocationClient.scala | 9 +- .../spark/ExecutorAllocationManager.scala | 86 ++++++++--- .../scala/org/apache/spark/SparkContext.scala | 24 ++-- .../CoarseGrainedSchedulerBackend.scala | 12 +- ...che.spark.scheduler.ExternalClusterManager | 3 +- .../ExecutorAllocationManagerSuite.scala | 135 ++++++++++++++++-- .../StandaloneDynamicAllocationSuite.scala | 6 +- project/MimaExcludes.scala | 3 + .../scheduler/ExecutorAllocationManager.scala | 2 +- .../streaming/scheduler/JobScheduler.scala | 9 +- .../ExecutorAllocationManagerSuite.scala | 5 +- 11 files changed, 239 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 8baddf45bfc31..5d47f624ac8a3 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -54,13 +54,16 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. - * @return whether the request is acknowledged by the cluster manager. + * @return the ids of the executors acknowledged by the cluster manager to be removed. */ - def killExecutors(executorIds: Seq[String]): Boolean + def killExecutors(executorIds: Seq[String]): Seq[String] /** * Request that the cluster manager kill the specified executor. * @return whether the request is acknowledged by the cluster manager. 
*/ - def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) + def killExecutor(executorId: String): Boolean = { + val killedExecutors = killExecutors(Seq(executorId)) + killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) + } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6f320c524201c..1366251d0618f 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -279,14 +280,18 @@ private[spark] class ExecutorAllocationManager( updateAndSyncNumExecutorsTarget(now) + val executorIdsToBeRemoved = ArrayBuffer[String]() removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { initializing = false - removeExecutor(executorId) + executorIdsToBeRemoved += executorId } !expired } + if (executorIdsToBeRemoved.nonEmpty) { + removeExecutors(executorIdsToBeRemoved) + } } /** @@ -391,11 +396,67 @@ private[spark] class ExecutorAllocationManager( } } + /** + * Request the cluster manager to remove the given executors. + * Returns the list of executors which are removed. + */ + private def removeExecutors(executors: Seq[String]): Seq[String] = synchronized { + val executorIdsToBeRemoved = new ArrayBuffer[String] + + logInfo("Request to remove executorIds: " + executors.mkString(", ")) + val numExistingExecutors = allocationManager.executorIds.size - executorsPendingToRemove.size + + var newExecutorTotal = numExistingExecutors + executors.foreach { executorIdToBeRemoved => + if (newExecutorTotal - 1 < minNumExecutors) { + logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + + s"$newExecutorTotal executor(s) left (limit $minNumExecutors)") + } else if (canBeKilled(executorIdToBeRemoved)) { + executorIdsToBeRemoved += executorIdToBeRemoved + newExecutorTotal -= 1 + } + } + + if (executorIdsToBeRemoved.isEmpty) { + return Seq.empty[String] + } + + // Send a request to the backend to kill this executor(s) + val executorsRemoved = if (testing) { + executorIdsToBeRemoved + } else { + client.killExecutors(executorIdsToBeRemoved) + } + // reset the newExecutorTotal to the existing number of executors + newExecutorTotal = numExistingExecutors + if (testing || executorsRemoved.nonEmpty) { + executorsRemoved.foreach { removedExecutorId => + newExecutorTotal -= 1 + logInfo(s"Removing executor $removedExecutorId because it has been idle for " + + s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + executorsPendingToRemove.add(removedExecutorId) + } + executorsRemoved + } else { + logWarning(s"Unable to reach the cluster manager to kill executor/s " + + "executorIdsToBeRemoved.mkString(\",\") or no executor eligible to kill!") + Seq.empty[String] + } + } + /** * Request the cluster manager to remove the given executor. - * Return whether the request is received. + * Return whether the request is acknowledged. 
*/ private def removeExecutor(executorId: String): Boolean = synchronized { + val executorsRemoved = removeExecutors(Seq(executorId)) + executorsRemoved.nonEmpty && executorsRemoved(0) == executorId + } + + /** + * Determine if the given executor can be killed. + */ + private def canBeKilled(executorId: String): Boolean = synchronized { // Do not kill the executor if we are not aware of it (should never happen) if (!executorIds.contains(executorId)) { logWarning(s"Attempted to remove unknown executor $executorId!") @@ -409,26 +470,7 @@ private[spark] class ExecutorAllocationManager( return false } - // Do not kill the executor if we have already reached the lower bound - val numExistingExecutors = executorIds.size - executorsPendingToRemove.size - if (numExistingExecutors - 1 < minNumExecutors) { - logDebug(s"Not removing idle executor $executorId because there are only " + - s"$numExistingExecutors executor(s) left (limit $minNumExecutors)") - return false - } - - // Send a request to the backend to kill this executor - val removeRequestAcknowledged = testing || client.killExecutor(executorId) - if (removeRequestAcknowledged) { - logInfo(s"Removing executor $executorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be ${numExistingExecutors - 1})") - executorsPendingToRemove.add(executorId) - true - } else { - logWarning(s"Unable to reach the cluster manager to kill executor $executorId," + - s"or no executor eligible to kill!") - false - } + true } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1981ad5671093..f58037e100989 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -73,7 +73,7 @@ import org.apache.spark.util._ * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ -class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient { +class SparkContext(config: SparkConf) extends Logging { // The call site where this SparkContext was constructed. private val creationSite: CallSite = Utils.getCallSite() @@ -534,7 +534,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) _executorAllocationManager = if (dynamicAllocationEnabled) { - Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + schedulerBackend match { + case b: ExecutorAllocationClient => + Some(new ExecutorAllocationManager( + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + case _ => + None + } } else { None } @@ -1473,7 +1479,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } - private[spark] override def getExecutorIds(): Seq[String] = { + private[spark] def getExecutorIds(): Seq[String] = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.getExecutorIds() @@ -1498,7 +1504,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @return whether the request is acknowledged by the cluster manager. 
*/ @DeveloperApi - override def requestTotalExecutors( + def requestTotalExecutors( numExecutors: Int, localityAwareTasks: Int, hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] @@ -1518,7 +1524,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @return whether the request is received. */ @DeveloperApi - override def requestExecutors(numAdditionalExecutors: Int): Boolean = { + def requestExecutors(numAdditionalExecutors: Int): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1540,10 +1546,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @return whether the request is received. */ @DeveloperApi - override def killExecutors(executorIds: Seq[String]): Boolean = { + def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(executorIds, replace = false, force = true) + b.killExecutors(executorIds, replace = false, force = true).nonEmpty case _ => logWarning("Killing executors is only supported in coarse-grained mode") false @@ -1562,7 +1568,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @return whether the request is received. */ @DeveloperApi - override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) + def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) /** * Request that the cluster manager kill the specified executor without adjusting the @@ -1581,7 +1587,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = true, force = true) + b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty case _ => logWarning("Killing executors is only supported in coarse-grained mode") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c6b3fdf439f5f..edc3c199376ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -528,7 +528,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @return whether the kill request is acknowledged. If list to kill is empty, it will return * false. 
*/ - final override def killExecutors(executorIds: Seq[String]): Boolean = { + final override def killExecutors(executorIds: Seq[String]): Seq[String] = { killExecutors(executorIds, replace = false, force = false) } @@ -548,7 +548,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp final def killExecutors( executorIds: Seq[String], replace: Boolean, - force: Boolean): Boolean = { + force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") val response = synchronized { @@ -564,6 +564,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp .filter { id => force || !scheduler.isExecutorBusy(id) } executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. @@ -583,7 +585,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp _ => Future.successful(false) } - adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread) + val killResponse = adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread) + + killResponse.flatMap(killSuccessful => + Future.successful (if (killSuccessful) executorsToKill else Seq.empty[String]) + )(ThreadUtils.sameThread) } defaultAskTimeout.awaitResult(response) diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 757c6d2296aff..cf8565c74e95e 100644 --- a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1,2 +1,3 @@ org.apache.spark.scheduler.DummyExternalClusterManager -org.apache.spark.scheduler.MockExternalClusterManager \ No newline at end of file +org.apache.spark.scheduler.MockExternalClusterManager +org.apache.spark.DummyLocalExternalClusterManager diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index c130649830416..ec409712b953c 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -23,7 +23,9 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.util.ManualClock /** @@ -49,7 +51,7 @@ class ExecutorAllocationManagerSuite test("verify min/max executors") { val conf = new SparkConf() - .setMaster("local") + .setMaster("myDummyLocalExternalClusterManager") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") @@ -263,6 +265,55 @@ class ExecutorAllocationManagerSuite assert(executorsPendingToRemove(manager).isEmpty) } + test("remove multiple executors") { + sc = 
createSparkContext(5, 10, 5) + val manager = sc.executorAllocationManager.get + (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } + + // Keep removing until the limit is reached + assert(executorsPendingToRemove(manager).isEmpty) + assert(removeExecutors(manager, Seq("1")) === Seq("1")) + assert(executorsPendingToRemove(manager).size === 1) + assert(executorsPendingToRemove(manager).contains("1")) + assert(removeExecutors(manager, Seq("2", "3")) === Seq("2", "3")) + assert(executorsPendingToRemove(manager).size === 3) + assert(executorsPendingToRemove(manager).contains("2")) + assert(executorsPendingToRemove(manager).contains("3")) + assert(!removeExecutor(manager, "100")) // remove non-existent executors + assert(removeExecutors(manager, Seq("101", "102")) !== Seq("101", "102")) + assert(executorsPendingToRemove(manager).size === 3) + assert(removeExecutor(manager, "4")) + assert(removeExecutors(manager, Seq("5")) === Seq("5")) + assert(!removeExecutor(manager, "6")) // reached the limit of 5 + assert(executorsPendingToRemove(manager).size === 5) + assert(executorsPendingToRemove(manager).contains("4")) + assert(executorsPendingToRemove(manager).contains("5")) + assert(!executorsPendingToRemove(manager).contains("6")) + + // Kill executors previously requested to remove + onExecutorRemoved(manager, "1") + assert(executorsPendingToRemove(manager).size === 4) + assert(!executorsPendingToRemove(manager).contains("1")) + onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 2) + assert(!executorsPendingToRemove(manager).contains("2")) + assert(!executorsPendingToRemove(manager).contains("3")) + onExecutorRemoved(manager, "2") // duplicates should not count + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 2) + onExecutorRemoved(manager, "4") + onExecutorRemoved(manager, "5") + assert(executorsPendingToRemove(manager).isEmpty) + + // Try removing again + // This should still fail because the number pending + running is still at the limit + assert(!removeExecutor(manager, "7")) + assert(executorsPendingToRemove(manager).isEmpty) + assert(removeExecutors(manager, Seq("8")) !== Seq("8")) + assert(executorsPendingToRemove(manager).isEmpty) + } + test ("interleaving add and remove") { sc = createSparkContext(5, 10, 5) val manager = sc.executorAllocationManager.get @@ -283,8 +334,7 @@ class ExecutorAllocationManagerSuite // Remove until limit assert(removeExecutor(manager, "1")) - assert(removeExecutor(manager, "2")) - assert(removeExecutor(manager, "3")) + assert(removeExecutors(manager, Seq("2", "3")) === Seq("2", "3")) assert(!removeExecutor(manager, "4")) // lower limit reached assert(!removeExecutor(manager, "5")) onExecutorRemoved(manager, "1") @@ -296,7 +346,7 @@ class ExecutorAllocationManagerSuite assert(addExecutors(manager) === 2) // upper limit reached assert(addExecutors(manager) === 0) assert(!removeExecutor(manager, "4")) // still at lower limit - assert(!removeExecutor(manager, "5")) + assert(removeExecutors(manager, Seq("5")) !== Seq("5")) onExecutorAdded(manager, "9") onExecutorAdded(manager, "10") onExecutorAdded(manager, "11") @@ -305,9 +355,7 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).size === 10) // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutor(manager, "4")) - assert(removeExecutor(manager, "5")) - assert(removeExecutor(manager, "6")) + assert(removeExecutors(manager, Seq("4", "5", "6")) === Seq("4", 
"5", "6")) assert(removeExecutor(manager, "7")) assert(executorIds(manager).size === 10) assert(addExecutors(manager) === 0) @@ -870,8 +918,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) removeExecutor(manager, "first") - removeExecutor(manager, "second") - assert(executorsPendingToRemove(manager) === Set("first", "second")) + removeExecutors(manager, Seq("second", "third")) + assert(executorsPendingToRemove(manager) === Set("first", "second", "third")) assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) @@ -895,7 +943,7 @@ class ExecutorAllocationManagerSuite maxExecutors: Int = 5, initialExecutors: Int = 1): SparkContext = { val conf = new SparkConf() - .setMaster("local") + .setMaster("myDummyLocalExternalClusterManager") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) @@ -953,6 +1001,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _updateAndSyncNumExecutorsTarget = PrivateMethod[Int]('updateAndSyncNumExecutorsTarget) private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) + private val _removeExecutors = PrivateMethod[Seq[String]]('removeExecutors) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) private val _onSchedulerBacklogged = PrivateMethod[Unit]('onSchedulerBacklogged) @@ -1008,6 +1057,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _removeExecutor(id) } + private def removeExecutors(manager: ExecutorAllocationManager, ids: Seq[String]): Seq[String] = { + manager invokePrivate _removeExecutors(ids) + } + private def onExecutorAdded(manager: ExecutorAllocationManager, id: String): Unit = { manager invokePrivate _onExecutorAdded(id) } @@ -1040,3 +1093,65 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _hostToLocalTaskCount() } } + +/** + * A cluster manager which wraps around the scheduler and backend for local mode. It is used for + * testing the dynamic allocation policy. + */ +private class DummyLocalExternalClusterManager extends ExternalClusterManager { + + def canCreate(masterURL: String): Boolean = masterURL == "myDummyLocalExternalClusterManager" + + override def createTaskScheduler( + sc: SparkContext, + masterURL: String): TaskScheduler = new TaskSchedulerImpl(sc, 1, isLocal = true) + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + val sb = new LocalSchedulerBackend(sc.getConf, scheduler.asInstanceOf[TaskSchedulerImpl], 1) + new DummyLocalSchedulerBackend(sc, sb) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + val sc = scheduler.asInstanceOf[TaskSchedulerImpl] + sc.initialize(backend) + } +} + +/** + * A scheduler backend which wraps around local scheduler backend and exposes the executor + * allocation client interface for testing dynamic allocation. 
+ */ +private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend) + extends SchedulerBackend with ExecutorAllocationClient { + + override private[spark] def getExecutorIds(): Seq[String] = sc.getExecutorIds() + + override private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean = + sc.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) + + override def requestExecutors(numAdditionalExecutors: Int): Boolean = + sc.requestExecutors(numAdditionalExecutors) + + override def killExecutors(executorIds: Seq[String]): Seq[String] = { + val response = sc.killExecutors(executorIds) + if (response) { + executorIds + } else { + Seq.empty[String] + } + } + + override def start(): Unit = sb.start() + + override def stop(): Unit = sb.stop() + + override def reviveOffers(): Unit = sb.reviveOffers() + + override def defaultParallelism(): Int = sb.defaultParallelism() +} diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 814027076d6fe..e29eb8552e134 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -438,12 +438,12 @@ class StandaloneDynamicAllocationSuite val executorIdToTaskCount = taskScheduler invokePrivate getMap() executorIdToTaskCount(executors.head) = 1 // kill the busy executor without force; this should fail - assert(!killExecutor(sc, executors.head, force = false)) + assert(killExecutor(sc, executors.head, force = false).isEmpty) apps = getApplications() assert(apps.head.executors.size === 2) // force kill busy executor - assert(killExecutor(sc, executors.head, force = true)) + assert(killExecutor(sc, executors.head, force = true).nonEmpty) apps = getApplications() // kill executor successfully assert(apps.head.executors.size === 1) @@ -518,7 +518,7 @@ class StandaloneDynamicAllocationSuite } /** Kill the given executor, specifying whether to force kill it. 
*/ - private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Boolean = { + private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Seq[String] = { syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f13f3ff789484..0a56a6b19e4ce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -818,6 +818,9 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted") + ) ++ Seq( + // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index fb5587edeccee..7b29b40668def 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -226,7 +226,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { conf: SparkConf, batchDurationMs: Long, clock: Clock): Option[ExecutorAllocationManager] = { - if (isDynamicAllocationEnabled(conf)) { + if (isDynamicAllocationEnabled(conf) && client != null) { Some(new ExecutorAllocationManager(client, receiverTracker, conf, batchDurationMs, clock)) } else None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 79d6254eb372b..dbc50da21c701 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -24,6 +24,7 @@ import scala.util.Failure import org.apache.commons.lang3.SerializationUtils +import org.apache.spark.ExecutorAllocationClient import org.apache.spark.internal.Logging import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ @@ -83,8 +84,14 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { listenerBus.start() receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) + + val executorAllocClient: ExecutorAllocationClient = ssc.sparkContext.schedulerBackend match { + case b: ExecutorAllocationClient => b.asInstanceOf[ExecutorAllocationClient] + case _ => null + } + executorAllocationManager = ExecutorAllocationManager.createIfEnabled( - ssc.sparkContext, + executorAllocClient, receiverTracker, ssc.conf, ssc.graph.batchDuration.milliseconds, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index 7630f4a75e336..b49e5790711cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -380,8 +380,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite } private def withStreamingContext(conf: 
SparkConf)(body: StreamingContext => Unit): Unit = { - conf.setMaster("local").setAppName(this.getClass.getSimpleName).set( - "spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation + conf.setMaster("myDummyLocalExternalClusterManager") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation var ssc: StreamingContext = null try { From 9f24a17c59b1130d97efa7d313c06577f7344338 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 22 Sep 2016 11:52:42 -0700 Subject: [PATCH 099/306] Skip building R vignettes if Spark is not built ## What changes were proposed in this pull request? When we build the docs separately we don't have the JAR files from the Spark build in the same tree. As the SparkR vignettes need to launch a SparkContext to be built, we skip building them if JAR files don't exist ## How was this patch tested? To test this we can run the following: ``` build/mvn -DskipTests -Psparkr clean ./R/create-docs.sh ``` You should see a line `Skipping R vignettes as Spark JARs not found` at the end Author: Shivaram Venkataraman Closes #15200 from shivaram/sparkr-vignette-skip. --- R/create-docs.sh | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/R/create-docs.sh b/R/create-docs.sh index 0dfba22463396..69ffc5f678c36 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -30,6 +30,13 @@ set -e # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +# Required for setting SPARK_SCALA_VERSION +. "${SPARK_HOME}"/bin/load-spark-env.sh + +echo "Using Scala $SPARK_SCALA_VERSION" + pushd $FWDIR # Install the package (this will also generate the Rd files) @@ -45,9 +52,21 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# render creates SparkR vignettes -Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' +# Find Spark jars. +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +# Only create vignettes if Spark JARs exist +if [ -d "$SPARK_JARS_DIR" ]; then + # render creates SparkR vignettes + Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' -find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" +fi popd From 85d609cf25c1da2df3cd4f5d5aeaf3cbcf0d674c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 22 Sep 2016 13:05:41 -0700 Subject: [PATCH 100/306] [SPARK-17613] S3A base paths with no '/' at the end return empty DataFrames ## What changes were proposed in this pull request? 
Consider you have a bucket as `s3a://some-bucket` and under it you have files: ``` s3a://some-bucket/file1.parquet s3a://some-bucket/file2.parquet ``` Getting the parent path of `s3a://some-bucket/file1.parquet` yields `s3a://some-bucket/` and the ListingFileCatalog uses this as the key in the hash map. When catalog.allFiles is called, we use `s3a://some-bucket` (no slash at the end) to get the list of files, and we're left with an empty list! This PR fixes this by adding a `/` at the end of the `URI` iff the given `Path` doesn't have a parent, i.e. is the root. This is a no-op if the path already had a `/` at the end, and is handled through the Hadoop Path, path merging semantics. ## How was this patch tested? Unit test in `FileCatalogSuite`. Author: Burak Yavuz Closes #15169 from brkyvz/SPARK-17613. --- .../PartitioningAwareFileCatalog.scala | 10 ++++- .../datasources/FileCatalogSuite.scala | 45 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala index d2d5b56c82946..702ba97222e34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -76,7 +76,15 @@ abstract class PartitioningAwareFileCatalog( paths.flatMap { path => // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). val fs = path.getFileSystem(hadoopConf) - val qualifiedPath = fs.makeQualified(path) + val qualifiedPathPre = fs.makeQualified(path) + val qualifiedPath: Path = if (qualifiedPathPre.isRoot && !qualifiedPathPre.isAbsolute) { + // SPARK-17613: Always append `Path.SEPARATOR` to the end of parent directories, + // because the `leafFile.getParent` would have returned an absolute path with the + // separator at the end. + new Path(qualifiedPathPre, Path.SEPARATOR) + } else { + qualifiedPathPre + } // There are three cases possible with each path // 1. The path is a directory and has children files in it. 
Then it must be present in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala index 5c8d3226e9e26..fa3abd0098f5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.datasources import java.io.File +import java.net.URI +import scala.collection.mutable import scala.language.reflectiveCalls -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.SharedSQLContext @@ -78,4 +80,45 @@ class FileCatalogSuite extends SharedSQLContext { assert(catalog1.listLeafFiles(catalog1.paths).isEmpty) } } + + test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { + class MockCatalog( + override val paths: Seq[Path]) extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + + override def refresh(): Unit = {} + + override def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = mutable.LinkedHashMap( + new Path("mockFs://some-bucket/file1.json") -> new FileStatus() + ) + + override def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = Map( + new Path("mockFs://some-bucket/") -> Array(new FileStatus()) + ) + + override def partitionSpec(): PartitionSpec = { + PartitionSpec.emptySpec + } + } + + withSQLConf( + "fs.mockFs.impl" -> classOf[FakeParentPathFileSystem].getName, + "fs.mockFs.impl.disable.cache" -> "true") { + val pathWithSlash = new Path("mockFs://some-bucket/") + assert(pathWithSlash.getParent === null) + val pathWithoutSlash = new Path("mockFs://some-bucket") + assert(pathWithoutSlash.getParent === null) + val catalog1 = new MockCatalog(Seq(pathWithSlash)) + val catalog2 = new MockCatalog(Seq(pathWithoutSlash)) + assert(catalog1.allFiles().nonEmpty) + assert(catalog2.allFiles().nonEmpty) + } + } +} + +class FakeParentPathFileSystem extends RawLocalFileSystem { + override def getScheme: String = "mockFs" + + override def getUri: URI = { + URI.create("mockFs://some-bucket") + } } From 3cdae0ff2f45643df7bc198cb48623526c7eb1a6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 22 Sep 2016 14:26:45 -0700 Subject: [PATCH 101/306] [SPARK-17638][STREAMING] Stop JVM StreamingContext when the Python process is dead ## What changes were proposed in this pull request? When the Python process is dead, the JVM StreamingContext is still running. Hence we will see a lot of Py4jException before the JVM process exits. It's better to stop the JVM StreamingContext to avoid those annoying logs. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15201 from zsxwing/stop-jvm-ssc. 
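For illustration, a minimal sketch of the detection logic described above, not part of the committed diff: it matches the two Py4J callback-failure messages and stops the active StreamingContext from a daemon thread so the reporting error handler is never blocked. The names `PythonProcessWatchdog` and `stopIfPythonDead` are hypothetical; the helper this patch actually adds is `PythonDStream.stopStreamingContextIfPythonProcessIsDead`, shown in full in the diff that follows.

```
import py4j.Py4JException

import org.apache.spark.streaming.StreamingContext

object PythonProcessWatchdog {
  // Hypothetical helper: true when a Py4J callback failure indicates that the
  // Python driver process is gone (the two messages come from py4j's CallbackClient).
  private def isPythonCallbackDead(e: Throwable): Boolean = e match {
    case p: Py4JException =>
      p.getMessage == "Cannot obtain a new communication channel" ||
        p.getMessage == "Error while obtaining a new communication channel"
    case _ => false
  }

  // Stop the active StreamingContext on a separate daemon thread so the caller
  // (e.g. the job scheduler's error handler) is never blocked or deadlocked.
  def stopIfPythonDead(e: Throwable): Unit = {
    if (isPythonCallbackDead(e)) {
      val stopper = new Thread("Stop-StreamingContext") {
        override def run(): Unit = {
          StreamingContext.getActive().foreach(_.stop(stopSparkContext = false))
        }
      }
      stopper.setDaemon(true)
      stopper.start()
    }
  }
}
```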
--- .../streaming/api/python/PythonDStream.scala | 33 +++++++++++++++++-- .../streaming/scheduler/JobGenerator.scala | 2 ++ .../streaming/scheduler/JobScheduler.scala | 2 ++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index aeff4d7a98e7a..46bfc60856453 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -24,11 +24,14 @@ import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.language.existentials +import py4j.Py4JException + import org.apache.spark.SparkException import org.apache.spark.api.java._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, Interval, Time} +import org.apache.spark.streaming.{Duration, Interval, StreamingContext, Time} import org.apache.spark.streaming.api.java._ import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils @@ -157,7 +160,7 @@ private[python] object PythonTransformFunctionSerializer { /** * Helper functions, which are called from Python via Py4J. */ -private[python] object PythonDStream { +private[streaming] object PythonDStream { /** * can not access PythonTransformFunctionSerializer.register() via Py4j @@ -184,6 +187,32 @@ private[python] object PythonDStream { rdds.asScala.foreach(queue.add) queue } + + /** + * Stop [[StreamingContext]] if the Python process crashes (E.g., OOM) in case the user cannot + * stop it in the Python side. + */ + def stopStreamingContextIfPythonProcessIsDead(e: Throwable): Unit = { + // These two special messages are from: + // scalastyle:off + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L218 + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L340 + // scalastyle:on + if (e.isInstanceOf[Py4JException] && + ("Cannot obtain a new communication channel" == e.getMessage || + "Error while obtaining a new communication channel" == e.getMessage)) { + // Start a new thread to stop StreamingContext to avoid deadlock. + new Thread("Stop-StreamingContext") with Logging { + setDaemon(true) + + override def run(): Unit = { + logError( + "Cannot connect to Python process. It's probably dead. 
Stopping StreamingContext.", e) + StreamingContext.getActive().foreach(_.stop(stopSparkContext = false)) + } + }.start() + } + } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 10d64f98ac71b..8d83dc8a8fc04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,6 +22,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -252,6 +253,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index dbc50da21c701..98e099354a7db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -28,6 +28,7 @@ import org.apache.spark.ExecutorAllocationClient import org.apache.spark.internal.Logging import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{EventLoop, ThreadUtils} @@ -217,6 +218,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleError(msg: String, e: Throwable) { logError(msg, e) ssc.waiter.notifyError(e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } private class JobHandler(job: Job) extends Runnable with Logging { From 0d634875026ccf1eaf984996e9460d7673561f80 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 22 Sep 2016 14:29:27 -0700 Subject: [PATCH 102/306] [SPARK-17616][SQL] Support a single distinct aggregate combined with a non-partial aggregate ## What changes were proposed in this pull request? We currently cannot execute an aggregate that contains a single distinct aggregate function and an one or more non-partially plannable aggregate functions, for example: ```sql select grp, collect_list(col1), count(distinct col2) from tbl_a group by 1 ``` This is a regression from Spark 1.6. This is caused by the fact that the single distinct aggregation code path assumes that all aggregates can be planned in two phases (is partially aggregatable). This PR works around this issue by triggering the `RewriteDistinctAggregates` in such cases (this is similar to the approach taken in 1.6). ## How was this patch tested? Created `RewriteDistinctAggregatesSuite` which checks if the aggregates with distinct aggregate functions get rewritten into two `Aggregates` and an `Expand`. Added a regression test to `DataFrameAggregateSuite`. Author: Herman van Hovell Closes #15187 from hvanhovell/SPARK-17616. 
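For illustration, the failing shape expressed through the DataFrame API, using the same data as the regression test added below; the `local[*]` SparkSession setup is an assumption of this sketch and not part of the committed diff. `collect_list` cannot be partially aggregated, so pairing it with a single `countDistinct` used to fall outside the supported single-distinct planning path; with this patch `RewriteDistinctAggregates` plans it as two Aggregates over an Expand.

```
// Paste into spark-shell (or wrap in a main); any existing SparkSession works.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, countDistinct, sort_array}

val spark = SparkSession.builder().master("local[*]").appName("spark-17616-example").getOrCreate()
import spark.implicits._

val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")).toDF("x", "y", "z")

// One distinct aggregate plus one non-partial aggregate in the same grouping;
// expected rows: (1, 2, [a, b]) and (3, 2, [c, c, d]).
df.groupBy($"x")
  .agg(countDistinct($"y"), sort_array(collect_list($"z")))
  .show()
```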
--- .../optimizer/RewriteDistinctAggregates.scala | 18 ++-- .../RewriteDistinctAggregatesSuite.scala | 94 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 8 ++ 3 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 0f43e7bb88733..d6a39ecf53b86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -119,14 +119,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Aggregation strategy can handle the query with single distinct - if (distinctAggGroups.size > 1) { + // Check if the aggregates contains functions that do not support partial aggregation. + val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial) + + // Aggregation strategy can handle queries with a single distinct group and partial aggregates. + if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) { // Create the attributes for the grouping id and the group by clause. - val gid = - new AttributeReference("gid", IntegerType, false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -135,9 +137,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Expression): AggregateFunction = { - af.withNewChildren(af.children.map { - case afc => attrs(afc) - }).asInstanceOf[AggregateFunction] + af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -265,5 +265,5 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.sql, e.dataType, true)() + e -> AttributeReference(e.sql, e.dataType, nullable = true)() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala new file mode 100644 index 0000000000000..0b973c3b659cf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class RewriteDistinctAggregatesSuite extends PlanTest { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + val nullInt = Literal(null, IntegerType) + val nullString = Literal(null, StringType) + val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) + + private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { + case Aggregate(_, _, Aggregate(_, _, _: Expand)) => + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + + test("single distinct group") { + val input = testRelation + .groupBy('a)(countDistinct('e)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + max('b).as('agg2)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with non-partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + CollectSet('b).toAggregateExpression().as('agg2)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with partial aggregates") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with non-partial aggregates") { + val input = testRelation + .groupBy('a)( + countDistinct('b, 'c), + countDistinct('d), + CollectSet('b).toAggregateExpression()) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 427390a90f1e6..0e172bee4f661 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -493,4 +493,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)), Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5)))) } + + test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { + val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) + .toDF("x", "y", "z") + checkAnswer( + df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), + Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) + } } From f4f6bd8c9884e3919509907307fda774f56b5ecc Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Thu, 22 Sep 2016 16:34:42 -0700 Subject: [PATCH 103/306] [SPARK-16240][ML] ML persistence backward compatibility for LDA ## What changes were proposed in this pull request? Allow Spark 2.x to load instances of LDA, LocalLDAModel, and DistributedLDAModel saved from Spark 1.6. ## How was this patch tested? I tested this manually, saving the 3 types from 1.6 and loading them into master (2.x). In the future, we can add generic tests for testing backwards compatibility across all ML models in SPARK-15573. Author: Joseph K. Bradley Closes #15034 from jkbradley/lda-backwards. --- .../org/apache/spark/ml/clustering/LDA.scala | 86 +++++++++++++++---- project/MimaExcludes.scala | 4 +- 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index b5a764b5863f1..7773802854c00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JObject +import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging @@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} import org.apache.spark.mllib.impl.PeriodicCheckpointer -import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector, - Vectors => OldVectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.VersionUtils private[clustering] trait LDAParams extends Params with HasFeaturesCol 
with HasMaxIter @@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } +private object LDAParams { + + /** + * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. + * + * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with + * [[Param]] values extracted from metadata. + * @param metadata Loaded model metadata + */ + def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = { + VersionUtils.majorMinorVersion(metadata.sparkVersion) match { + case (1, 6) => + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val origParam = + if (paramName == "topicDistribution") "topicDistributionCol" else paramName + val param = model.getParam(origParam) + val value = param.jsonDecode(compact(render(jsonValue))) + model.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + case _ => // 2.0+ + DefaultParamsReader.getAndSetParams(model, metadata) + } + } +} + /** * :: Experimental :: @@ -418,11 +458,11 @@ sealed abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. 
Set an output column" + " such as topicDistributionCol to produce results.") - dataset.toDF + dataset.toDF() } } @@ -578,18 +618,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", - "gammaShape") - .head() - val vocabSize = data.getAs[Int](0) - val topicsMatrix = data.getAs[Matrix](1) - val docConcentration = data.getAs[Vector](2) - val topicConcentration = data.getAs[Double](3) - val gammaShape = data.getAs[Double](4) + val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") + val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix") + val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector, + topicConcentration: Double, gammaShape: Double) = + matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration", + "topicConcentration", "gammaShape").head() val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession) - DefaultParamsReader.getAndSetParams(model, metadata) + LDAParams.getAndSetParams(model, metadata) model } } @@ -735,9 +773,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) - val model = new DistributedLDAModel( - metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None) - DefaultParamsReader.getAndSetParams(model, metadata) + val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize, + oldModel, sparkSession, None) + LDAParams.getAndSetParams(model, metadata) model } } @@ -885,7 +923,7 @@ class LDA @Since("1.6.0") ( } @Since("2.0.0") -object LDA extends DefaultParamsReadable[LDA] { +object LDA extends MLReadable[LDA] { /** Get dataset for spark.mllib LDA */ private[clustering] def getOldDataset( @@ -900,6 +938,20 @@ object LDA extends DefaultParamsReadable[LDA] { } } + private class LDAReader extends MLReader[LDA] { + + private val className = classOf[LDA].getName + + override def load(path: String): LDA = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val model = new LDA(metadata.uid) + LDAParams.getAndSetParams(model, metadata) + model + } + } + + override def read: MLReader[LDA] = new LDAReader + @Since("2.0.0") override def load(path: String): LDA = super.load(path) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0a56a6b19e4ce..b6f64e5a703ca 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,9 @@ object MimaExcludes { // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"), // [SPARK-16967] Move Mesos to Module - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), + // [SPARK-16240] ML persistence backward compatibility for LDA + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") ) } From a1661968310de35e710e3b6784f63a77c44453fc Mon Sep 17 
00:00:00 2001 From: Burak Yavuz Date: Thu, 22 Sep 2016 16:50:22 -0700 Subject: [PATCH 104/306] [SPARK-17569][SPARK-17569][TEST] Make the unit test added for work again ## What changes were proposed in this pull request? A [PR](https://github.com/apache/spark/commit/a6aade0042d9c065669f46d2dac40ec6ce361e63) was merged concurrently that made the unit test for PR #15122 not test anything anymore. This PR fixes the test. ## How was this patch tested? Changed line https://github.com/apache/spark/blob/0d634875026ccf1eaf984996e9460d7673561f80/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala#L137 from `false` to `true` and made sure the unit test failed. Author: Burak Yavuz Closes #15203 from brkyvz/fix-test. --- .../spark/sql/execution/streaming/FileStreamSourceSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala index e8fa6a59c57ae..0795a0527f13a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -92,7 +92,7 @@ class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) - assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L)))) + assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), dir.getAbsolutePath, Map.empty) From 79159a1e87f19fb08a36857fc30b600ee7fdc52b Mon Sep 17 00:00:00 2001 From: Yucai Yu Date: Thu, 22 Sep 2016 17:22:56 -0700 Subject: [PATCH 105/306] [SPARK-17635][SQL] Remove hardcode "agg_plan" in HashAggregateExec ## What changes were proposed in this pull request? "agg_plan" are hardcoded in HashAggregateExec, which have potential issue, so removing them. ## How was this patch tested? existing tests. Author: Yucai Yu Closes #15199 from yucai/agg_plan. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 59e132dfb252d..06199ef3e8243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -552,7 +552,7 @@ case class HashAggregateExec( } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"agg_plan.getTaskMemoryManager(), agg_plan.getEmptyAggregationBuffer());") + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", iterTermForFastHashMap, "") From a4aeb7677bc07d0b83f82de62dcffd7867d19d9b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 22 Sep 2016 21:35:25 -0700 Subject: [PATCH 106/306] [SPARK-17639][BUILD] Add jce.jar to buildclasspath when building. 
This was missing, preventing code that uses javax.crypto to properly compile in Spark. Author: Marcelo Vanzin Closes #15204 from vanzin/SPARK-17639. --- core/pom.xml | 4 +--- pom.xml | 7 ++++--- project/SparkBuild.scala | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 3c8138f974a56..9a4f234953a23 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -417,7 +417,6 @@ - \ .bat @@ -429,7 +428,6 @@ - / .sh @@ -450,7 +448,7 @@ - ..${path.separator}R${path.separator}install-dev${script.extension} + ..${file.separator}R${file.separator}install-dev${script.extension}
      diff --git a/pom.xml b/pom.xml index 8afc39bb46f80..8408f4b1fa5ed 100644 --- a/pom.xml +++ b/pom.xml @@ -2617,8 +2617,9 @@ -bootclasspath - ${env.JAVA_7_HOME}/jre/lib/rt.jar + ${env.JAVA_7_HOME}/jre/lib/rt.jar${path.separator}${env.JAVA_7_HOME}/jre/lib/jce.jar + true @@ -2633,7 +2634,7 @@ -javabootclasspath - ${env.JAVA_7_HOME}/jre/lib/rt.jar + ${env.JAVA_7_HOME}/jre/lib/rt.jar${path.separator}${env.JAVA_7_HOME}/jre/lib/jce.jar @@ -2642,7 +2643,7 @@ -javabootclasspath - ${env.JAVA_7_HOME}/jre/lib/rt.jar + ${env.JAVA_7_HOME}/jre/lib/rt.jar${path.separator}${env.JAVA_7_HOME}/jre/lib/jce.jar diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a39c93e9574fa..8e47e7f13d367 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -280,7 +280,7 @@ object SparkBuild extends PomBuild { "-target", javacJVMVersion.value ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => if (javacJVMVersion.value == "1.7") { - Seq("-bootclasspath", s"$jdk7/jre/lib/rt.jar") + Seq("-bootclasspath", s"$jdk7/jre/lib/rt.jar${File.pathSeparator}$jdk7/jre/lib/jce.jar") } else { Nil } @@ -291,7 +291,7 @@ object SparkBuild extends PomBuild { "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => if (javacJVMVersion.value == "1.7") { - Seq("-javabootclasspath", s"$jdk7/jre/lib/rt.jar") + Seq("-javabootclasspath", s"$jdk7/jre/lib/rt.jar${File.pathSeparator}$jdk7/jre/lib/jce.jar") } else { Nil } From 947b8c6e3acd671d501f0ed6c077aac8e51ccede Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 22 Sep 2016 22:27:28 -0700 Subject: [PATCH 107/306] [SPARK-16719][ML] Random Forests should communicate fewer trees on each iteration ## What changes were proposed in this pull request? RandomForest currently sends the entire forest to each worker on each iteration. This is because (a) the node queue is FIFO and (b) the closure references the entire array of trees (topNodes). (a) causes RFs to handle splits in many trees, especially early on in learning. (b) sends all trees explicitly. This PR: (a) Change the RF node queue to be FILO (a stack), so that RFs tend to focus on 1 or a few trees before focusing on others. (b) Change topNodes to pass only the trees required on that iteration. ## How was this patch tested? Unit tests: * Existing tests for correctness of tree learning * Manually modifying code and running tests to verify that a small number of trees are communicated on each iteration * This last item is hard to test via unit tests given the current APIs. Author: Joseph K. Bradley Closes #14359 from jkbradley/rfs-fewer-trees. --- .../spark/ml/tree/impl/RandomForest.scala | 54 +++++++++++-------- .../ml/tree/impl/RandomForestSuite.scala | 26 ++++----- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eba..0b7ad92b3cf30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -51,7 +51,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} * findSplits() method during initialization, after which each continuous feature becomes * an ordered discretized feature with at most maxBins possible values. * - * The main loop in the algorithm operates on a queue of nodes (nodeQueue). 
These nodes + * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes * lie at the periphery of the tree being trained. If multiple trees are being trained at once, * then this queue contains nodes from all of them. Each iteration works roughly as follows: * On the master node: @@ -161,31 +161,42 @@ private[spark] object RandomForest extends Logging { None } - // FIFO queue of nodes to train: (treeIndex, node) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + /* + Stack of nodes to train: (treeIndex, node) + The reason this is a stack is that we train many trees at once, but we want to focus on + completing trees, rather than training all simultaneously. If we are splitting nodes from + 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue + training the same tree in the next iteration. This focus allows us to send fewer trees to + workers on each iteration; see topNodesForGroup below. + */ + val nodeStack = new mutable.Stack[(Int, LearningNode)] val rng = new Random() rng.setSeed(seed) // Allocate and queue root nodes. val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) - Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex)))) timer.stop("init") - while (nodeQueue.nonEmpty) { + while (nodeStack.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) // Sanity check (should never occur): assert(nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + // Only send trees to worker if they contain nodes being split this iteration. + val topNodesForGroup: Map[Int, LearningNode] = + nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap + // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") - RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) + RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup, + treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache) timer.stop("findBestSplits") } @@ -334,13 +345,14 @@ private[spark] object RandomForest extends Logging { * * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata - * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param topNodesForGroup For each tree in group, tree index -> root node. + * Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, * where nodeIndexInfo stores the index in the group and the * feature subsets (if using feature subsets). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * @param nodeStack Queue of nodes to split, with values (treeIndex, node). 
* Updated with new non-leaf nodes which are created. * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where * each value in the array is the data point's node Id @@ -351,11 +363,11 @@ private[spark] object RandomForest extends Logging { private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, - topNodes: Array[LearningNode], + topNodesForGroup: Map[Int, LearningNode], nodesForGroup: Map[Int, Array[LearningNode]], treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], - nodeQueue: mutable.Queue[(Int, LearningNode)], + nodeStack: mutable.Stack[(Int, LearningNode)], timer: TimeTracker = new TimeTracker, nodeIdCache: Option[NodeIdCache] = None): Unit = { @@ -437,7 +449,8 @@ private[spark] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } agg @@ -593,10 +606,10 @@ private[spark] object RandomForest extends Logging { // enqueue left child and right child if they are not leaves if (!leftChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.leftChild.get)) + nodeStack.push((treeIndex, node.leftChild.get)) } if (!rightChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.rightChild.get)) + nodeStack.push((treeIndex, node.rightChild.get)) } logDebug("leftChildIndex = " + node.leftChild.get.id + @@ -1029,7 +1042,7 @@ private[spark] object RandomForest extends Logging { * will be needed; this allows an adaptive number of nodes since different nodes may require * different amounts of memory (if featureSubsetStrategy is not "all"). * - * @param nodeQueue Queue of nodes to split. + * @param nodeStack Queue of nodes to split. * @param maxMemoryUsage Bound on size of aggregate statistics. * @return (nodesForGroup, treeToNodeToIndexInfo). * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. @@ -1041,7 +1054,7 @@ private[spark] object RandomForest extends Logging { * The feature indices are None if not subsampling features. */ private[tree] def selectNodesToSplit( - nodeQueue: mutable.Queue[(Int, LearningNode)], + nodeStack: mutable.Stack[(Int, LearningNode)], maxMemoryUsage: Long, metadata: DecisionTreeMetadata, rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { @@ -1054,8 +1067,8 @@ private[spark] object RandomForest extends Logging { var numNodesInGroup = 0 // If maxMemoryInMB is set very small, we want to still try to split 1 node, // so we allow one iteration if memUsage == 0. - while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { - val (treeIndex, node) = nodeQueue.head + while (nodeStack.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { + val (treeIndex, node) = nodeStack.top // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, @@ -1066,7 +1079,7 @@ private[spark] object RandomForest extends Logging { // Check if enough memory remains to add this node to the group. 
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { - nodeQueue.dequeue() + nodeStack.pop() mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo @@ -1109,5 +1122,4 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } - } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dcc2f305df75a..79b19ea5ad206 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, + Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.collection.OpenHashMap @@ -239,12 +240,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val treeToNodeToIndexInfo = Map((0, Map( (topNode.id, new RandomForest.NodeIndexInfo(0, None)) ))) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() - RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + val nodeStack = new mutable.Stack[(Int, LearningNode)] + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) + assert(nodeStack.isEmpty) // set impurity and predict for topNode assert(topNode.stats !== null) @@ -281,12 +282,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val treeToNodeToIndexInfo = Map((0, Map( (topNode.id, new RandomForest.NodeIndexInfo(0, None)) ))) - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() - RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + val nodeStack = new mutable.Stack[(Int, LearningNode)] + RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) + assert(nodeStack.isEmpty) // set impurity and predict for topNode assert(topNode.stats !== null) @@ -393,16 +394,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val failString = s"Failed on test with:" + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" - val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + val nodeStack = new mutable.Stack[(Int, LearningNode)] val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees) Range(0, numTrees).foreach { treeIndex => topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1) - nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + 
nodeStack.push((treeIndex, topNodes(treeIndex))) } val rng = new scala.util.Random(seed = seed) val (nodesForGroup: Map[Int, Array[LearningNode]], treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) assert(nodesForGroup.size === numTrees, failString) assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree @@ -546,7 +547,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } - } private object RandomForestSuite { From 62ccf27ab4b55e734646678ae78b7e812262d14b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 22 Sep 2016 23:35:08 -0700 Subject: [PATCH 108/306] [SPARK-17640][SQL] Avoid using -1 as the default batchId for FileStreamSource.FileEntry ## What changes were proposed in this pull request? Avoid using -1 as the default batchId for FileStreamSource.FileEntry so that we can make sure not writing any FileEntry(..., batchId = -1) into the log. This also avoids people misusing it in future (#15203 is an example). ## How was this patch tested? Jenkins. Author: Shixiong Zhu Closes #15206 from zsxwing/cleanup. --- .../streaming/FileStreamSource.scala | 37 ++++++++++--------- .../streaming/FileStreamSourceSuite.scala | 24 ++++++------ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 5ebc083a7da92..be023273db2f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -59,7 +59,7 @@ class FileStreamSource( val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) metadataLog.allFiles().foreach { entry => - seenFiles.add(entry) + seenFiles.add(entry.path, entry.timestamp) } seenFiles.purge() @@ -73,14 +73,16 @@ class FileStreamSource( */ private def fetchMaxOffset(): LongOffset = synchronized { // All the new files found - ignore aged files and files that we have seen. - val newFiles = fetchAllFiles().filter(seenFiles.isNewFile) + val newFiles = fetchAllFiles().filter { + case (path, timestamp) => seenFiles.isNewFile(path, timestamp) + } // Obey user's setting to limit the number of files in this batch trigger. val batchFiles = if (maxFilesPerBatch.nonEmpty) newFiles.take(maxFilesPerBatch.get) else newFiles batchFiles.foreach { file => - seenFiles.add(file) + seenFiles.add(file._1, file._2) logDebug(s"New file: $file") } val numPurged = seenFiles.purge() @@ -95,7 +97,9 @@ class FileStreamSource( if (batchFiles.nonEmpty) { maxBatchId += 1 - metadataLog.add(maxBatchId, batchFiles.map(_.copy(batchId = maxBatchId)).toArray) + metadataLog.add(maxBatchId, batchFiles.map { case (path, timestamp) => + FileEntry(path = path, timestamp = timestamp, batchId = maxBatchId) + }.toArray) logInfo(s"Max batch id increased to $maxBatchId with ${batchFiles.size} new files") } @@ -140,12 +144,12 @@ class FileStreamSource( /** * Returns a list of files found, sorted by their timestamp. 
*/ - private def fetchAllFiles(): Seq[FileEntry] = { + private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType)) val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => - FileEntry(status.getPath.toUri.toString, status.getModificationTime) + (status.getPath.toUri.toString, status.getModificationTime) } val endTime = System.nanoTime val listingTimeMs = (endTime.toDouble - startTime) / 1000000 @@ -172,10 +176,7 @@ object FileStreamSource { /** Timestamp for file modification time, in ms since January 1, 1970 UTC. */ type Timestamp = Long - val NOT_SET = -1L - - case class FileEntry(path: String, timestamp: Timestamp, batchId: Long = NOT_SET) - extends Serializable + case class FileEntry(path: String, timestamp: Timestamp, batchId: Long) extends Serializable /** * A custom hash map used to track the list of files seen. This map is not thread-safe. @@ -196,10 +197,10 @@ object FileStreamSource { private var lastPurgeTimestamp: Timestamp = 0L /** Add a new file to the map. */ - def add(file: FileEntry): Unit = { - map.put(file.path, file.timestamp) - if (file.timestamp > latestTimestamp) { - latestTimestamp = file.timestamp + def add(path: String, timestamp: Timestamp): Unit = { + map.put(path, timestamp) + if (timestamp > latestTimestamp) { + latestTimestamp = timestamp } } @@ -207,10 +208,10 @@ object FileStreamSource { * Returns true if we should consider this file a new file. The file is only considered "new" * if it is new enough that we are still tracking, and we have not seen it before. */ - def isNewFile(file: FileEntry): Boolean = { + def isNewFile(path: String, timestamp: Timestamp): Boolean = { // Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that // is older than (latestTimestamp - maxAgeMs) but has not been purged yet. - file.timestamp >= lastPurgeTimestamp && !map.containsKey(file.path) + timestamp >= lastPurgeTimestamp && !map.containsKey(path) } /** Removes aged entries and returns the number of files removed. */ @@ -230,8 +231,8 @@ object FileStreamSource { def size: Int = map.size() - def allEntries: Seq[FileEntry] = { - map.entrySet().asScala.map(entry => FileEntry(entry.getKey, entry.getValue)).toSeq + def allEntries: Seq[(String, Timestamp)] = { + map.asScala.toSeq } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala index 0795a0527f13a..3e1e1126f9e6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -36,51 +36,51 @@ class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { test("SeenFilesMap") { val map = new SeenFilesMap(maxAgeMs = 10) - map.add(FileEntry("a", 5)) + map.add("a", 5) assert(map.size == 1) map.purge() assert(map.size == 1) // Add a new entry and purge should be no-op, since the gap is exactly 10 ms. - map.add(FileEntry("b", 15)) + map.add("b", 15) assert(map.size == 2) map.purge() assert(map.size == 2) // Add a new entry that's more than 10 ms than the first entry. We should be able to purge now. 
- map.add(FileEntry("c", 16)) + map.add("c", 16) assert(map.size == 3) map.purge() assert(map.size == 2) // Override existing entry shouldn't change the size - map.add(FileEntry("c", 25)) + map.add("c", 25) assert(map.size == 2) // Not a new file because we have seen c before - assert(!map.isNewFile(FileEntry("c", 20))) + assert(!map.isNewFile("c", 20)) // Not a new file because timestamp is too old - assert(!map.isNewFile(FileEntry("d", 5))) + assert(!map.isNewFile("d", 5)) // Finally a new file: never seen and not too old - assert(map.isNewFile(FileEntry("e", 20))) + assert(map.isNewFile("e", 20)) } test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { val map = new SeenFilesMap(maxAgeMs = 10) - map.add(FileEntry("a", 20)) + map.add("a", 20) assert(map.size == 1) // Timestamp 5 should still considered a new file because purge time should be 0 - assert(map.isNewFile(FileEntry("b", 9))) - assert(map.isNewFile(FileEntry("b", 10))) + assert(map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) // Once purge, purge time should be 10 and then b would be a old file if it is less than 10. map.purge() - assert(!map.isNewFile(FileEntry("b", 9))) - assert(map.isNewFile(FileEntry("b", 10))) + assert(!map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) } testWithUninterruptibleThread("do not recheck that files exist during getBatch") { From 5c5396cb4725ba5ceee26ed885e8b941d219757b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Sep 2016 09:41:50 +0100 Subject: [PATCH 109/306] [BUILD] Closes some stale PRs ## What changes were proposed in this pull request? This PR proposes to close some stale PRs and ones suggested to be closed by committer(s) Closes #12415 Closes #14765 Closes #15118 Closes #15184 Closes #15183 Closes #9440 Closes #15023 Closes #14643 Closes #14827 ## How was this patch tested? N/A Author: hyukjinkwon Closes #15198 from HyukjinKwon/stale-prs. From 90d5754212425d55f992c939a2bc7d9ac6ef92b8 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Sep 2016 09:44:30 +0100 Subject: [PATCH 110/306] [SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2 ## What changes were proposed in this pull request? Move the internals of the PySpark accumulator API from the old deprecated API on top of the new accumulator API. ## How was this patch tested? The existing PySpark accumulator tests (both unit tests and doc tests at the start of accumulator.py). Author: Holden Karau Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api. 
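For reference, the contract this refactor builds on is the `AccumulatorV2`/`CollectionAccumulator` API: executors get a zero-valued copy via `copyAndReset`, and the driver combines updates via `merge`. Below is a minimal sketch of that pattern; the class name and the `sink` callback are illustrative stand-ins for the socket transfer done by the real `PythonAccumulatorV2`, not part of this patch.

```
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.util.{AccumulatorV2, CollectionAccumulator}

// Sketch only: a CollectionAccumulator that forwards merged updates to an
// external sink, mirroring the copyAndReset/merge overrides used by
// PythonAccumulatorV2.
class ForwardingAccumulator(sink: Seq[Array[Byte]] => Unit)
  extends CollectionAccumulator[Array[Byte]] {

  // Must return a fresh, zero-valued accumulator of the same concrete type,
  // so each task starts from an empty list of updates.
  override def copyAndReset(): ForwardingAccumulator = new ForwardingAccumulator(sink)

  // Called on the driver with the updates accumulated on an executor.
  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = {
    sink(other.value.asScala.toList)
  }
}
```

Registering such an accumulator with `SparkContext.register` on the driver is what the Python side now does via `self._jsc.sc().register(self._javaAccumulator)`.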
--- .../apache/spark/api/python/PythonRDD.scala | 42 ++++++++++--------- python/pyspark/context.py | 5 +-- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d841091a316b1..0ca91b9bf86c6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets -import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util._ private[spark] class PythonRDD( @@ -75,7 +75,7 @@ private[spark] case class PythonFunction( pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]]) + accumulator: PythonAccumulatorV2) /** * A wrapper for chained Python functions (from bottom to top). @@ -200,7 +200,7 @@ private[spark] class PythonRunner( val updateLen = stream.readInt() val update = new Array[Byte](updateLen) stream.readFully(update) - accumulator += Collections.singletonList(update) + accumulator.add(update) } // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { @@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging { JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) try { - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + val objs = new mutable.ArrayBuffer[Array[Byte]] try { while (true) { val length = file.readInt() @@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By } /** - * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) - extends AccumulatorParam[JList[Array[Byte]]] { +private[spark] class PythonAccumulatorV2( + @transient private val serverHost: String, + private val serverPort: Int) + extends CollectionAccumulator[Array[Byte]] { Utils.checkHost(serverHost, "Expected hostname") @@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String, * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded RpcEndpoint anyway. 
*/ - @transient var socket: Socket = _ + @transient private var socket: Socket = _ - def openSocket(): Socket = synchronized { + private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) } socket } - override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + // Need to override so the types match with PythonFunction + override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) - override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = synchronized { + override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { + val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] + // This conditional isn't strictly speaking needed - merging only currently happens on the + // driver program - but that isn't gauranteed so incase this changes. if (serverHost == null) { - // This happens on the worker node, where we just want to remember all the updates - val1.addAll(val2) - val1 + // We are on the worker + super.merge(otherPythonAccumulator) } else { // This happens on the master, where we pass the updates to Python through a socket val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) - out.writeInt(val2.size) - for (array <- val2.asScala) { + val values = other.value + out.writeInt(values.size) + for (array <- values.asScala) { out.writeInt(array.length) out.write(array) } @@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String, if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - null } } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7a7f59cb50a86..a3dd1950a522f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,9 +173,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # they will be passed back to us through a TCP server self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jsc.accumulator( - self._jvm.java.util.ArrayList(), - self._jvm.PythonAccumulatorParam(host, port)) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] From f89808b0fdbc04e1bdff1489a6ec4c84ddb2adc4 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 23 Sep 2016 11:14:22 -0700 Subject: [PATCH 111/306] [SPARK-17499][SPARKR][ML][MLLIB] make the default params in sparkR spark.mlp consistent with MultilayerPerceptronClassifier ## What changes were proposed in this pull request? update `MultilayerPerceptronClassifierWrapper.fit` paramter type: `layers: Array[Int]` `seed: String` update several default params in sparkR `spark.mlp`: `tol` --> 1e-6 `stepSize` --> 0.03 `seed` --> NULL ( when seed == NULL, the scala-side wrapper regard it as a `null` value and the seed will use the default one ) r-side `seed` only support 32bit integer. remove `layers` default value, and move it in front of those parameters with default value. add `layers` parameter validation check. ## How was this patch tested? tests added. 
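As a closing reference, here is a minimal Scala sketch of the estimator these R defaults are being aligned with. The explicit values simply restate the signature used on the R side in this patch (blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1e-6, stepSize = 0.03); the layers vector is only an example, since `layers` has no default.

```
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// Sketch only: the Scala-side estimator whose settings spark.mlp now mirrors.
val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3)) // required; must have length > 1
  .setBlockSize(128)
  .setSolver("l-bfgs")
  .setMaxIter(100)
  .setTol(1e-6)
  .setStepSize(0.03)
// setSeed(...) is applied only when the caller passes a seed, matching the
// R side where seed = NULL leaves the estimator's own default in place.
```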
Author: WeichenXu Closes #15051 from WeichenXu123/update_py_mlp_default. --- R/pkg/R/mllib.R | 13 ++++++++++--- R/pkg/inst/tests/testthat/test_mllib.R | 19 +++++++++++++++++++ ...ultilayerPerceptronClassifierWrapper.scala | 8 ++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 98db367a856ee..971c16658fe9a 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -694,12 +694,19 @@ setMethod("predict", signature(object = "KMeansModel"), #' } #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, blockSize = 128, layers = c(3, 5, 2), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) { + function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, + tol = 1E-6, stepSize = 0.03, seed = NULL) { + layers <- as.integer(na.omit(layers)) + if (length(layers) <= 1) { + stop ("layers must be a integer vector with length > 1.") + } + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), as.integer(seed)) + as.numeric(stepSize), seed) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 24c40a88231a7..a1eaaf20916a2 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -391,6 +391,25 @@ test_that("spark.mlp", { unlink(modelPath) + # Test default parameter + model <- spark.mlp(df, layers = c(4, 5, 4, 3)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) + + # Test illegal parameter + expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") + + # Test random seed + # default seed + model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) + # seed equals 10 + model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) }) test_that("spark.naiveBayes", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index be51e74187faa..10673003534e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -53,26 +53,26 @@ private[r] object MultilayerPerceptronClassifierWrapper def fit( data: DataFrame, blockSize: Int, - layers: Array[Double], + layers: Array[Int], solver: String, maxIter: Int, tol: Double, stepSize: Double, - seed: Int + seed: String ): MultilayerPerceptronClassifierWrapper = { // get labels and feature names from output 
schema val schema = data.schema // assemble and fit the pipeline val mlp = new MultilayerPerceptronClassifier() - .setLayers(layers.map(_.toInt)) + .setLayers(layers) .setBlockSize(blockSize) .setSolver(solver) .setMaxIter(maxIter) .setTol(tol) .setStepSize(stepSize) - .setSeed(seed) .setPredictionCol(PREDICTED_LABEL_COL) + if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt) val pipeline = new Pipeline() .setStages(Array(mlp)) .fit(data) From f62ddc5983a08d4d54c0a9a8210dd6cbec555671 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 23 Sep 2016 11:37:43 -0700 Subject: [PATCH 112/306] [SPARK-17210][SPARKR] sparkr.zip is not distributed to executors when running sparkr in RStudio MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Spark will add sparkr.zip to archive only when it is yarn mode (SparkSubmit.scala). ``` if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString // Distribute the SparkR package. // Assigns a symbol link name "sparkr" to the shipped package. args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr") // Distribute the R package archive containing all the built R packages. if (!RUtils.rPackages.isEmpty) { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { printErrorAndExit("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString // Assigns a symbol link name "rpkg" to the shipped package. args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg") } } ``` So it is necessary to pass spark.master from R process to JVM. Otherwise sparkr.zip won't be distributed to executor. Besides that I also pass spark.yarn.keytab/spark.yarn.principal to spark side, because JVM process need them to access secured cluster. ## How was this patch tested? Verify it manually in R Studio using the following code. ``` Sys.setenv(SPARK_HOME="/Users/jzhang/github/spark") .libPaths(c(file.path(Sys.getenv(), "R", "lib"), .libPaths())) library(SparkR) sparkR.session(master="yarn-client", sparkConfig = list(spark.executor.instances="1")) df <- as.DataFrame(mtcars) head(df) ``` … Author: Jeff Zhang Closes #14784 from zjffdu/SPARK-17210. 
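As a small illustrative note (not part of this patch), the mechanism is just a key-to-flag lookup: any `sparkConfig` entry whose key appears in the table is turned into a pair of spark-submit arguments. A Scala analogue of the R-side `sparkConfToSubmitOps` handling might look like the sketch below; the real mapping lives in R/pkg/R/sparkR.R.

```
// Sketch only: Scala analogue of the R-side sparkConfToSubmitOps table.
val sparkConfToSubmitOps = Map(
  "spark.master" -> "--master",
  "spark.yarn.keytab" -> "--keytab",
  "spark.yarn.principal" -> "--principal")

def toSubmitArgs(sparkConfig: Map[String, String]): Seq[String] =
  sparkConfig.toSeq.flatMap { case (key, value) =>
    sparkConfToSubmitOps.get(key).toSeq.flatMap(flag => Seq(flag, value))
  }

// toSubmitArgs(Map("spark.master" -> "yarn-client"))
//   returns Seq("--master", "yarn-client")
```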
--- R/pkg/R/sparkR.R | 4 ++++ docs/sparkr.md | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 06015362e6bc1..cc6d591bb2f4c 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -491,6 +491,10 @@ sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" sparkConfToSubmitOps[["spark.driver.extraJavaOptions"]] <- "--driver-java-options" sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-path" +sparkConfToSubmitOps[["spark.master"]] <- "--master" +sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab" +sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal" + # Utility function that returns Spark Submit arguments as a string # diff --git a/docs/sparkr.md b/docs/sparkr.md index b881119731045..340e7f7cb1a0b 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -62,6 +62,21 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s + + + + + + + + + + + + + + + From 988c71457354b0a443471f501cef544a85b1a76a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 23 Sep 2016 12:17:59 -0700 Subject: [PATCH 113/306] [SPARK-17643] Remove comparable requirement from Offset For some sources, it is difficult to provide a global ordering based only on the data in the offset. Since we don't use comparison for correctness, lets remove it. Author: Michael Armbrust Closes #15207 from marmbrus/removeComparable. --- .../execution/streaming/CompositeOffset.scala | 30 -------------- .../sql/execution/streaming/LongOffset.scala | 6 --- .../sql/execution/streaming/Offset.scala | 19 ++------- .../execution/streaming/StreamExecution.scala | 9 +++-- .../spark/sql/streaming/OffsetSuite.scala | 39 ------------------- 5 files changed, 9 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala index 729c8462fed65..ebc6ee8184902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -23,36 +23,6 @@ package org.apache.spark.sql.execution.streaming * vector clock that must progress linearly forward. */ case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - override def compareTo(other: Offset): Int = other match { - case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => - val comparisons = offsets.zip(otherComposite.offsets).map { - case (Some(a), Some(b)) => a compareTo b - case (None, None) => 0 - case (None, _) => -1 - case (_, None) => 1 - } - val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet - nonZeroSigns.size match { - case 0 => 0 // if both empty or only 0s - case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) - case _ => // there are both 1s and -1s - throw new IllegalArgumentException( - s"Invalid comparison between non-linear histories: $this <=> $other") - } - case _ => - throw new IllegalArgumentException(s"Cannot compare $this <=> $other") - } - - private def sign(num: Int): Int = num match { - case i if i < 0 => -1 - case i if i == 0 => 0 - case i if i > 0 => 1 - } - /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of * sources. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index bb176408d8f59..c5e8827777792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -22,12 +22,6 @@ package org.apache.spark.sql.execution.streaming */ case class LongOffset(offset: Long) extends Offset { - override def compareTo(other: Offset): Int = other match { - case l: LongOffset => offset.compareTo(l.offset) - case _ => - throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") - } - def +(increment: Long): LongOffset = new LongOffset(offset + increment) def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala index 2cc012840dcaa..1f52abf277581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -19,19 +19,8 @@ package org.apache.spark.sql.execution.streaming /** * An offset is a monotonically increasing metric used to track progress in the computation of a - * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent - * with `equals` and `hashcode`. + * stream. Since offsets are retrieved from a [[Source]] by a single thread, we know the global + * ordering of two [[Offset]] instances. We do assume that if two offsets are `equal` then no + * new data has arrived. */ -trait Offset extends Serializable { - - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - def compareTo(other: Offset): Int - - def >(other: Offset): Boolean = compareTo(other) > 0 - def <(other: Offset): Boolean = compareTo(other) < 0 - def <=(other: Offset): Boolean = compareTo(other) <= 0 - def >=(other: Offset): Boolean = compareTo(other) >= 0 -} +trait Offset extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 220f77dc24ce0..9825f19b86a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -259,7 +259,7 @@ class StreamExecution( case (source, available) => committedOffsets .get(source) - .map(committed => committed < available) + .map(committed => committed != available) .getOrElse(true) } } @@ -318,7 +318,8 @@ class StreamExecution( // Request unprocessed data from all sources. val newData = availableOffsets.flatMap { - case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => + case (source, available) + if committedOffsets.get(source).map(_ != available).getOrElse(true) => val current = committedOffsets.get(source) val batch = source.getBatch(current, available) logDebug(s"Retrieving data from $source: $current -> $available") @@ -404,10 +405,10 @@ class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is indented for use primarily when writing tests. */ - def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala index 9590af4e7737d..b65a987770304 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -24,44 +24,12 @@ trait OffsetSuite extends SparkFunSuite { /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ def compare(one: Offset, two: Offset): Unit = { test(s"comparison $one <=> $two") { - assert(one < two) - assert(one <= two) - assert(one <= one) - assert(two > one) - assert(two >= one) - assert(one >= one) assert(one == one) assert(two == two) assert(one != two) assert(two != one) } } - - /** Creates test to check that non-equality comparisons throw exception. 
*/ - def compareInvalid(one: Offset, two: Offset): Unit = { - test(s"invalid comparison $one <=> $two") { - intercept[IllegalArgumentException] { - assert(one < two) - } - - intercept[IllegalArgumentException] { - assert(one <= two) - } - - intercept[IllegalArgumentException] { - assert(one > two) - } - - intercept[IllegalArgumentException] { - assert(one >= two) - } - - assert(!(one == two)) - assert(!(two == one)) - assert(one != two) - assert(two != one) - } - } } class LongOffsetSuite extends OffsetSuite { @@ -79,10 +47,6 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset(None :: Nil), two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compareInvalid( // sizes must be same - one = CompositeOffset(Nil), - two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compare( one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) @@ -91,8 +55,5 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) - compareInvalid( - one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent - two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) } From 90a30f46349182b6fc9d4123090c4712fdb425be Mon Sep 17 00:00:00 2001 From: jisookim Date: Fri, 23 Sep 2016 13:43:47 -0700 Subject: [PATCH 114/306] [SPARK-12221] add cpu time to metrics Currently task metrics don't support executor CPU time, so there's no way to calculate how much CPU time a stage/task took from History Server metrics. This PR enables reporting CPU time. Author: jisookim Closes #10212 from jisookim0513/add-cpu-time-metric. --- .../apache/spark/InternalAccumulator.scala | 2 + .../org/apache/spark/executor/Executor.scala | 15 +++ .../apache/spark/executor/TaskMetrics.scala | 18 ++++ .../apache/spark/scheduler/ResultTask.scala | 8 ++ .../spark/scheduler/ShuffleMapTask.scala | 8 ++ .../org/apache/spark/scheduler/Task.scala | 2 + .../status/api/v1/AllStagesResource.scala | 5 + .../org/apache/spark/status/api/v1/api.scala | 5 + .../spark/ui/jobs/JobProgressListener.scala | 4 + .../org/apache/spark/ui/jobs/UIData.scala | 5 + .../org/apache/spark/util/JsonProtocol.scala | 10 ++ .../complete_stage_list_json_expectation.json | 3 + .../failed_stage_list_json_expectation.json | 1 + .../one_stage_attempt_json_expectation.json | 17 +++ .../one_stage_json_expectation.json | 17 +++ .../stage_list_json_expectation.json | 4 + ...ist_with_accumulable_json_expectation.json | 1 + .../stage_task_list_expectation.json | 40 +++++++ ...multi_attempt_app_json_1__expectation.json | 16 +++ ...multi_attempt_app_json_2__expectation.json | 16 +++ ...k_list_w__offset___length_expectation.json | 100 ++++++++++++++++++ ...stage_task_list_w__sortBy_expectation.json | 40 +++++++ ...tBy_short_names___runtime_expectation.json | 40 +++++++ ...rtBy_short_names__runtime_expectation.json | 40 +++++++ ...mmary_w__custom_quantiles_expectation.json | 2 + ...sk_summary_w_shuffle_read_expectation.json | 2 + ...k_summary_w_shuffle_write_expectation.json | 2 + ...age_with_accumulable_json_expectation.json | 17 +++ .../apache/spark/util/JsonProtocolSuite.scala | 69 ++++++++---- project/MimaExcludes.scala | 4 + 30 files changed, 492 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 0b494c146fa1b..82d3098e2e055 100644 --- 
a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -31,7 +31,9 @@ private[spark] object InternalAccumulator { // Names of internal task level metrics val EXECUTOR_DESERIALIZE_TIME = METRICS_PREFIX + "executorDeserializeTime" + val EXECUTOR_DESERIALIZE_CPU_TIME = METRICS_PREFIX + "executorDeserializeCpuTime" val EXECUTOR_RUN_TIME = METRICS_PREFIX + "executorRunTime" + val EXECUTOR_CPU_TIME = METRICS_PREFIX + "executorCpuTime" val RESULT_SIZE = METRICS_PREFIX + "resultSize" val JVM_GC_TIME = METRICS_PREFIX + "jvmGCTime" val RESULT_SERIALIZATION_TIME = METRICS_PREFIX + "resultSerializationTime" diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 668ec41153086..9501dd9cd8e93 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -232,13 +232,18 @@ private[spark] class Executor( } override def run(): Unit = { + val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var taskStart: Long = 0 + var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() try { @@ -269,6 +274,9 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() + taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L var threwException = true val value = try { val res = task.run( @@ -302,6 +310,9 @@ private[spark] class Executor( } } val taskFinish = System.currentTimeMillis() + val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L // If the task has been killed, let's fail it. if (task.killed) { @@ -317,8 +328,12 @@ private[spark] class Executor( // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. 
task.metrics.setExecutorDeserializeTime( (taskStart - deserializeStartTime) + task.executorDeserializeTime) + task.metrics.setExecutorDeserializeCpuTime( + (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorCpuTime( + (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 52a349919e336..2956768c16417 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -47,7 +47,9 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, Accumulat class TaskMetrics private[spark] () extends Serializable { // Each metric is internally represented as an accumulator private val _executorDeserializeTime = new LongAccumulator + private val _executorDeserializeCpuTime = new LongAccumulator private val _executorRunTime = new LongAccumulator + private val _executorCpuTime = new LongAccumulator private val _resultSize = new LongAccumulator private val _jvmGCTime = new LongAccumulator private val _resultSerializationTime = new LongAccumulator @@ -61,11 +63,22 @@ class TaskMetrics private[spark] () extends Serializable { */ def executorDeserializeTime: Long = _executorDeserializeTime.sum + /** + * CPU Time taken on the executor to deserialize this task in nanoseconds. + */ + def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime.sum + /** * Time the executor spends actually running the task (including fetching shuffle data). */ def executorRunTime: Long = _executorRunTime.sum + /** + * CPU Time the executor spends actually running the task + * (including fetching shuffle data) in nanoseconds. + */ + def executorCpuTime: Long = _executorCpuTime.sum + /** * The number of bytes this task transmitted back to the driver as the TaskResult. 
*/ @@ -111,7 +124,10 @@ class TaskMetrics private[spark] () extends Serializable { // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) + private[spark] def setExecutorDeserializeCpuTime(v: Long): Unit = + _executorDeserializeCpuTime.setValue(v) private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v) + private[spark] def setExecutorCpuTime(v: Long): Unit = _executorCpuTime.setValue(v) private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v) private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v) private[spark] def setResultSerializationTime(v: Long): Unit = @@ -188,7 +204,9 @@ class TaskMetrics private[spark] () extends Serializable { import InternalAccumulator._ @transient private[spark] lazy val nameToAccums = LinkedHashMap( EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime, + EXECUTOR_DESERIALIZE_CPU_TIME -> _executorDeserializeCpuTime, EXECUTOR_RUN_TIME -> _executorRunTime, + EXECUTOR_CPU_TIME -> _executorCpuTime, RESULT_SIZE -> _resultSize, JVM_GC_TIME -> _jvmGCTime, RESULT_SERIALIZATION_TIME -> _resultSerializationTime, diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 75c6018e214d8..609f10aee940d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io._ +import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.Properties @@ -61,11 +62,18 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime + _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime + } else 0L func(context, rdd.iterator(partition, context)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 84b3e5ba6c1f3..448fe02084e0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.Properties @@ -66,11 +67,18 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. 
+ val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime + _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime + } else 0L var writer: ShuffleWriter[Any, Any] = null try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index ea9dc3988d934..48daa344f3c88 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -139,6 +139,7 @@ private[spark] abstract class Task[T]( @volatile @transient private var _killed = false protected var _executorDeserializeTime: Long = 0 + protected var _executorDeserializeCpuTime: Long = 0 /** * Whether the task has been killed. @@ -149,6 +150,7 @@ private[spark] abstract class Task[T]( * Returns the amount of time spent deserializing the RDD and function to be run. */ def executorDeserializeTime: Long = _executorDeserializeTime + def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime /** * Collect the latest values of accumulators used in this task. If the task failed, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 7d63a8f734f0e..acb7c23079681 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -101,6 +101,7 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + executorCpuTime = stageUiData.executorCpuTime, submissionTime = stageInfo.submissionTime.map(new Date(_)), firstTaskLaunchedTime, completionTime = stageInfo.completionTime.map(new Date(_)), @@ -220,7 +221,9 @@ private[v1] object AllStagesResource { new TaskMetricDistributions( quantiles = quantiles, executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), executorRunTime = metricQuantiles(_.executorRunTime), + executorCpuTime = metricQuantiles(_.executorCpuTime), resultSize = metricQuantiles(_.resultSize), jvmGcTime = metricQuantiles(_.jvmGCTime), resultSerializationTime = metricQuantiles(_.resultSerializationTime), @@ -241,7 +244,9 @@ private[v1] object AllStagesResource { def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { new TaskMetrics( executorDeserializeTime = internal.executorDeserializeTime, + executorDeserializeCpuTime = internal.executorDeserializeCpuTime, executorRunTime = internal.executorRunTime, + executorCpuTime = internal.executorCpuTime, resultSize = internal.resultSize, jvmGcTime = internal.jvmGCTime, resultSerializationTime = internal.resultSerializationTime, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala 
b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 32e332a9adb9d..44a929b310384 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -128,6 +128,7 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val executorCpuTime: Long, val submissionTime: Option[Date], val firstTaskLaunchedTime: Option[Date], val completionTime: Option[Date], @@ -166,7 +167,9 @@ class TaskData private[spark]( class TaskMetrics private[spark]( val executorDeserializeTime: Long, + val executorDeserializeCpuTime: Long, val executorRunTime: Long, + val executorCpuTime: Long, val resultSize: Long, val jvmGcTime: Long, val resultSerializationTime: Long, @@ -202,7 +205,9 @@ class TaskMetricDistributions private[spark]( val quantiles: IndexedSeq[Double], val executorDeserializeTime: IndexedSeq[Double], + val executorDeserializeCpuTime: IndexedSeq[Double], val executorRunTime: IndexedSeq[Double], + val executorCpuTime: IndexedSeq[Double], val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index d3a4f9d3223a7..83dc5d874589e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -503,6 +503,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val timeDelta = taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) stageData.executorRunTime += timeDelta + + val cpuTimeDelta = + taskMetrics.executorCpuTime - oldMetrics.map(_.executorCpuTime).getOrElse(0L) + stageData.executorCpuTime += cpuTimeDelta } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index c729f03b3c383..f4a04609c4c69 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -80,6 +80,7 @@ private[spark] object UIData { var numKilledTasks: Int = _ var executorRunTime: Long = _ + var executorCpuTime: Long = _ var inputBytes: Long = _ var inputRecords: Long = _ @@ -137,7 +138,9 @@ private[spark] object UIData { metrics.map { m => TaskMetricsUIData( executorDeserializeTime = m.executorDeserializeTime, + executorDeserializeCpuTime = m.executorDeserializeCpuTime, executorRunTime = m.executorRunTime, + executorCpuTime = m.executorCpuTime, resultSize = m.resultSize, jvmGCTime = m.jvmGCTime, resultSerializationTime = m.resultSerializationTime, @@ -179,7 +182,9 @@ private[spark] object UIData { case class TaskMetricsUIData( executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, executorRunTime: Long, + executorCpuTime: Long, resultSize: Long, jvmGCTime: Long, resultSerializationTime: Long, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 41d947c4428ad..f4fa7b4061640 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -348,7 +348,9 @@ private[spark] object JsonProtocol { ("Status" -> 
blockStatusToJson(status)) }) ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ + ("Executor Deserialize CPU Time" -> taskMetrics.executorDeserializeCpuTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ + ("Executor CPU Time" -> taskMetrics.executorCpuTime) ~ ("Result Size" -> taskMetrics.resultSize) ~ ("JVM GC Time" -> taskMetrics.jvmGCTime) ~ ("Result Serialization Time" -> taskMetrics.resultSerializationTime) ~ @@ -759,7 +761,15 @@ private[spark] object JsonProtocol { return metrics } metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) + metrics.setExecutorDeserializeCpuTime((json \ "Executor Deserialize CPU Time") match { + case JNothing => 0 + case x => x.extract[Long] + }) metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) + metrics.setExecutorCpuTime((json \ "Executor CPU Time") match { + case JNothing => 0 + case x => x.extract[Long] + }) metrics.setResultSize((json \ "Result Size").extract[Long]) metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long]) metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long]) diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 8f8067f86d57f..25c4fff77e0ad 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", "completionTime" : "2015-02-03T16:43:07.226GMT", @@ -31,6 +32,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -56,6 +58,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", "completionTime" : "2015-02-03T16:43:04.819GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index 08b692eda8028..b86ba1e65de12 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 477a2fec8b69b..0084339d24642 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 
8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -36,7 +37,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -77,7 +80,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -118,7 +123,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -159,7 +166,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -200,7 +209,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -241,7 +252,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 436, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 0, @@ -282,7 +295,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -323,7 +338,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 388e51f77a24d..63fe3b2f958e5 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -36,7 +37,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -77,7 +80,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -118,7 +123,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, 
"executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -159,7 +166,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 2, @@ -200,7 +209,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -241,7 +252,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 436, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 0, @@ -282,7 +295,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 434, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, @@ -323,7 +338,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 435, + "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, "resultSerializationTime" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 5b957ed549556..6509df1508b30 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", "completionTime" : "2015-02-03T16:43:07.226GMT", @@ -31,6 +32,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", "completionTime" : "2015-02-03T16:43:06.286GMT", @@ -56,6 +58,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", "completionTime" : "2015-02-03T16:43:04.819GMT", @@ -81,6 +84,7 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index afa425f8c27bb..8496863a93469 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", 
"completionTime" : "2015-03-16T19:25:36.579GMT", diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index 8e09aabbad7c9..e0661c464179d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 73, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 75, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 1dbf72b42a926..8492f19ab7a5f 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -15,7 +15,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -60,7 +62,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -105,7 +109,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -150,7 +156,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -195,7 +203,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 
697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -240,7 +250,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -285,7 +297,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -330,7 +344,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 483492282dd64..4de4c501a43ad 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -15,7 +15,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -60,7 +62,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -105,7 +109,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -150,7 +156,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -195,7 +203,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -240,7 +250,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -285,7 +297,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -330,7 +344,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 624f2bb16df48..d2eceeb3f97a9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 73, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 75, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 65, + "executorCpuTime" : 0, 
"resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 43, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 49, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 38, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 32, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 29, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 39, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -810,7 +850,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 34, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -850,7 +892,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 36, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 24, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -890,7 +934,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -930,7 +976,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 43, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -970,7 +1018,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 27, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1010,7 +1060,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 35, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1050,7 +1102,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 29, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1090,7 +1144,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 32, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" 
: 5, "resultSerializationTime" : 0, @@ -1130,7 +1186,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 31, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -1170,7 +1228,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1210,7 +1270,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 14, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1250,7 +1312,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1290,7 +1354,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1330,7 +1396,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1370,7 +1438,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1410,7 +1480,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 19, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1450,7 +1522,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 31, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1490,7 +1564,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1530,7 +1606,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 24, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1570,7 +1648,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 7, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 23, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 6, "resultSerializationTime" : 0, @@ -1610,7 +1690,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1650,7 +1732,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, 
"resultSerializationTime" : 0, @@ -1690,7 +1774,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1730,7 +1816,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1770,7 +1858,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1810,7 +1900,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 21, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1850,7 +1942,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 20, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1890,7 +1984,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1930,7 +2026,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -1970,7 +2068,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 96d86b7278ff1..f42c3a4ee5c38 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], 
"taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 96d86b7278ff1..f42c3a4ee5c38 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 29, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 351, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 30, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 1, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 0, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 348, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, "resultSerializationTime" : 2, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 93, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 92, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 91, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 1, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 88, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 6, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 80, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 77, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 9, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, + "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index e0e9e8140c717..db60ccccbf8c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -10,7 +10,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 14, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -50,7 +52,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -90,7 +94,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { 
"executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -130,7 +136,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -170,7 +178,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -210,7 +220,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -250,7 +262,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 16, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -290,7 +304,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -330,7 +346,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -370,7 +388,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -410,7 +430,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -450,7 +472,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -490,7 +514,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 20, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 5, "resultSerializationTime" : 0, @@ -530,7 +556,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -570,7 +598,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -610,7 +640,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 17, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -650,7 +682,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 3, + 
"executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -690,7 +724,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -730,7 +766,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, @@ -770,7 +808,9 @@ "accumulatorUpdates" : [ ], "taskMetrics" : { "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 18, + "executorCpuTime" : 0, "resultSize" : 2065, "jvmGcTime" : 0, "resultSerializationTime" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 76d1553bc8f77..5dcbc890438b2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.01, 0.5, 0.99 ], "executorDeserializeTime" : [ 1.0, 3.0, 36.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0 ], "executorRunTime" : [ 16.0, 28.0, 351.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0], "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index 7baffc5df0b0f..6d230ac653776 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.05, 0.25, 0.5, 0.75, 0.95 ], "executorDeserializeTime" : [ 1.0, 2.0, 2.0, 2.0, 3.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "executorRunTime" : [ 30.0, 74.0, 75.0, 76.0, 79.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index f8c4b7c128733..aea0f5413d8b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -1,7 +1,9 @@ { "quantiles" : [ 0.05, 0.25, 0.5, 0.75, 0.95 ], "executorDeserializeTime" : [ 2.0, 2.0, 3.0, 7.0, 31.0 ], + "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "executorRunTime" : [ 16.0, 18.0, 28.0, 49.0, 349.0 ], + "executorCpuTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], 
"resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index ce008bf40967d..aaeef1f2f582c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,7 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", "completionTime" : "2015-03-16T19:25:36.579GMT", @@ -45,7 +46,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -91,7 +94,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -137,7 +142,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 1, @@ -183,7 +190,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -229,7 +238,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -275,7 +286,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 13, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -321,7 +334,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, @@ -367,7 +382,9 @@ } ], "taskMetrics" : { "executorDeserializeTime" : 14, + "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, + "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, "resultSerializationTime" : 2, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 00314abf49fd4..d5146d70ebaa3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -606,6 +606,9 @@ private[spark] object JsonProtocolSuite extends Assertions { private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) + assert(metrics1.executorDeserializeCpuTime === metrics2.executorDeserializeCpuTime) + assert(metrics1.executorRunTime === metrics2.executorRunTime) + assert(metrics1.executorCpuTime === metrics2.executorCpuTime) assert(metrics1.resultSize === metrics2.resultSize) assert(metrics1.jvmGCTime === metrics2.jvmGCTime) assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime) @@ 
-816,8 +819,11 @@ private[spark] object JsonProtocolSuite extends Assertions { hasOutput: Boolean, hasRecords: Boolean = true) = { val t = TaskMetrics.empty + // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) + t.setExecutorDeserializeCpuTime(a) t.setExecutorRunTime(b) + t.setExecutorCpuTime(b) t.setResultSize(c) t.setJvmGCTime(d) t.setResultSerializationTime(a + b) @@ -1097,7 +1103,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1195,7 +1203,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1293,7 +1303,9 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Task Metrics": { | "Executor Deserialize Time": 300, + | "Executor Deserialize CPU Time": 300, | "Executor Run Time": 400, + | "Executor CPU Time": 400, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -1785,55 +1797,70 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 1, + | "Name": "$EXECUTOR_DESERIALIZE_CPU_TIME", + | "Update": 300, + | "Internal": true, + | "Count Failed Values": true + | }, + | + | { + | "ID": 2, | "Name": "$EXECUTOR_RUN_TIME", | "Update": 400, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 2, + | "ID": 3, + | "Name": "$EXECUTOR_CPU_TIME", + | "Update": 400, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 4, | "Name": "$RESULT_SIZE", | "Update": 500, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 3, + | "ID": 5, | "Name": "$JVM_GC_TIME", | "Update": 600, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 4, + | "ID": 6, | "Name": "$RESULT_SERIALIZATION_TIME", | "Update": 700, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 5, + | "ID": 7, | "Name": "$MEMORY_BYTES_SPILLED", | "Update": 800, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 6, + | "ID": 8, | "Name": "$DISK_BYTES_SPILLED", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 7, + | "ID": 9, | "Name": "$PEAK_EXECUTION_MEMORY", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 8, + | "ID": 10, | "Name": "$UPDATED_BLOCK_STATUSES", | "Update": [ | { @@ -1854,98 +1881,98 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Count Failed Values": true | }, | { - | "ID": 9, + | "ID": 11, | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 10, + | "ID": 12, | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 11, + | "ID": 13, | "Name": "${shuffleRead.REMOTE_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 12, + | "ID": 14, | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 13, + | "ID": 15, | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | 
"Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 14, + | "ID": 16, | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 15, + | "ID": 17, | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 16, + | "ID": 18, | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 17, + | "ID": 19, | "Name": "${shuffleWrite.WRITE_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 18, + | "ID": 20, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 19, + | "ID": 21, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 20, + | "ID": 22, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, + | "ID": 23, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 24, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b6f64e5a703ca..8024fbd21bbfc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -823,6 +823,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") + ) ++ Seq( + // [SPARK-12221] Add CPU time to metrics + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") ) } From 7c382524a959a2bc9b3d2fca44f6f0b41aba4e3c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 23 Sep 2016 14:35:18 -0700 Subject: [PATCH 115/306] [SPARK-17651][SPARKR] Set R package version number along with mvn ## What changes were proposed in this pull request? This PR sets the R package version while tagging releases. Note that since R doesn't accept `-SNAPSHOT` in version number field, we remove that while setting the next version ## How was this patch tested? Tested manually by running locally Author: Shivaram Venkataraman Closes #15223 from shivaram/sparkr-version-change. 
--- dev/create-release/release-tag.sh | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index d404939d1caee..b7e5100ca7408 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -60,12 +60,27 @@ git config user.email $GIT_EMAIL # Create release version $MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +# Set the release version in R/pkg/DESCRIPTION +sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION +# Set the release version in docs +sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +# Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers +R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` +sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION + +# Update docs with next version +sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +# Use R version for short version +sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing development version $NEXT_VERSION" # Push changes From f3fe55439e4c865c26502487a1bccf255da33f4a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 24 Sep 2016 08:06:41 +0100 Subject: [PATCH 116/306] [SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array ## What changes were proposed in this pull request? To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram ## How was this patch tested? Jenkins tests. Author: Sean Owen Closes #15179 from srowen/SPARK-10835. --- .../apache/spark/ml/feature/Word2Vec.scala | 3 ++- .../spark/ml/feature/Word2VecSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c05123c62ed..d53f3df514dff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0b441f8b80810..613cc3d60b227 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } From 248916f5589155c0c3e93c3874781f17b08d598d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 24 Sep 2016 08:15:55 +0100 Subject: [PATCH 117/306] [SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0 ## What changes were proposed in this pull request? Match ProbabilisticClassifer.thresholds requirements to R randomForest cutoff, requiring all > 0 ## How was this patch tested? Jenkins tests plus new test cases Author: Sean Owen Closes #15149 from srowen/SPARK-17057. --- .../classification/LogisticRegression.scala | 5 +-- .../ProbabilisticClassifier.scala | 20 +++++------ .../ml/param/shared/SharedParamsCodeGen.scala | 8 +++-- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../ProbabilisticClassifierSuite.scala | 35 +++++++++++++++---- .../ml/param/_shared_params_code_gen.py | 5 +-- python/pyspark/ml/param/shared.py | 4 +-- 7 files changed, 52 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 343d50c790e85..5ab63d1de95d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Set thresholds in multiclass (or binary) classification to adjust the probability of - * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * predicting each class. Array must have length equal to the number of classes, with values > 0, + * excepting that at most one value may be 0. * The class with largest value p/t is predicted, where p is the original probability of that - * class and t is the class' threshold. + * class and t is the class's threshold. 
* * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1b6e77542cc80..e89da6ff8bdd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[ if (!isDefined(thresholds)) { probability.argmax } else { - val thresholds: Array[Double] = getThresholds - val probabilities = probability.toArray + val thresholds = getThresholds var argMax = 0 var max = Double.NegativeInfinity var i = 0 val probabilitySize = probability.size while (i < probabilitySize) { - if (thresholds(i) == 0.0) { - max = Double.PositiveInfinity + // Thresholds are all > 0, excepting that at most one may be 0. + // The single class whose threshold is 0, if any, will always be predicted + // ('scaled' = +Infinity). However in the case that this class also has + // 0 probability, the class will not be selected ('scaled' is NaN). + val scaled = probability(i) / thresholds(i) + if (scaled > max) { + max = scaled argMax = i - } else { - val scaled = probabilities(i) / thresholds(i) - if (scaled > max) { - max = scaled - argMax = i - } } i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 480b03d0f35c4..c94b8b4e9dfda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + + " Array must have length equal to the number of classes, with values > 0" + + " excepting that at most one value may be 0." 
+ " The class with largest value p/t is predicted, where p is the original probability" + - " of that class and t is the class' threshold", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), + " of that class and t is the class's threshold", + isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1", + finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 9125d9e19bf09..fa4530927e8b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ - final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) /** @group getParam */ def getThresholds: Array[Double] = $(thresholds) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index b3bd2b3e57b36..172c64aab9d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel( rawPrediction } - def friendlyPredict(input: Vector): Double = { - predict(input) + def friendlyPredict(values: Double*): Double = { + predict(Vectors.dense(values.toArray)) } } @@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel( class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { - val thresholds = Array(0.5, 0.2) val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - .setThresholds(thresholds) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + .setThresholds(Array(0.5, 0.2)) + assert(testModel.friendlyPredict(1.0, 1.0) === 1.0) + assert(testModel.friendlyPredict(1.0, 0.2) === 0.0) } test("test thresholding not required") { val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + assert(testModel.friendlyPredict(1.0, 2.0) === 1.0) + } + + test("test tiebreak") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.4, 0.4)) + assert(testModel.friendlyPredict(0.6, 0.6) === 0.0) + } + + test("test one zero threshold") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.0, 0.1)) + assert(testModel.friendlyPredict(1.0, 10.0) === 0.0) + assert(testModel.friendlyPredict(0.0, 10.0) === 1.0) + } + + test("bad thresholds") { + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0)) + } + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) + } } } diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 4f4328bcadc6f..929591236d688 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -139,8 +139,9 @@ def get$Name(self): "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + - "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None, + "values > 0, excepting that at most one value may be 0. " + + "The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class's threshold.", None, "TypeConverters.toListFloat"), ("weightCol", "weight column name. 
If this is not set or empty, we treat " + "all instance weights as 1.0.", None, "TypeConverters.toString"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 24af07afc7d5c..cc596936d82f6 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -469,10 +469,10 @@ def getStandardization(self): class HasThresholds(Params): """ - Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. """ - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat) + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat) def __init__(self): super(HasThresholds, self).__init__() From 7945daed12542587d51ece8f07e5c828b40db14a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 24 Sep 2016 01:03:11 -0700 Subject: [PATCH 118/306] [MINOR][SPARKR] Add sparkr-vignettes.html to gitignore. ## What changes were proposed in this pull request? Add ```sparkr-vignettes.html``` to ```.gitignore```. ## How was this patch tested? No need test. Author: Yanbo Liang Closes #15215 from yanboliang/ignore. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index cfa8ad05f7da1..39d17e1793f77 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out +R/pkg/vignettes/sparkr-vignettes.html build/*.jar build/apache-maven* build/scala* From de333d121da4cb80d45819cbcf8b4246e48ec4d0 Mon Sep 17 00:00:00 2001 From: xin wu Date: Sun, 25 Sep 2016 16:46:12 -0700 Subject: [PATCH 119/306] [SPARK-17551][SQL] Add DataFrame API for null ordering ## What changes were proposed in this pull request? This pull request adds Scala/Java DataFrame API for null ordering (NULLS FIRST | LAST). Also did some minor clean up for related code (e.g. incorrect indentation), and renamed "orderby-nulls-ordering.sql" to be consistent with existing test files. ## How was this patch tested? Added a new test case in DataFrameSuite. Author: petermaxlee Author: Xin Wu Closes #15123 from petermaxlee/SPARK-17551. 
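Before the diff, a minimal usage sketch of the null-ordering API this patch adds, assuming a Spark build that includes it. The `SparkSession` setup and the sample data are hypothetical; the method and function names (`asc_nulls_first`, `desc_nulls_last`, and friends) are the ones introduced below.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.desc_nulls_last

// Usage sketch only: the SparkSession and sample data are hypothetical;
// the ordering methods/functions are the ones added in this patch.
object NullOrderingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("null-ordering").getOrCreate()
    import spark.implicits._

    val people = Seq(("alice", Some(30)), ("bob", None), ("carol", Some(25)))
      .toDF("name", "age")

    // Column API: ascending by age, nulls sorted before non-null values.
    people.orderBy($"age".asc_nulls_first).show()

    // functions API: descending by age, nulls sorted after non-null values.
    people.orderBy(desc_nulls_last("age")).show()

    spark.stop()
  }
}
```

The existing `asc`/`desc` defaults are left as they are (the new DataFrameSuite test below exercises nulls first for ascending and nulls last for descending), so code that does not use the new methods is unaffected.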
--- .../sql/catalyst/expressions/SortOrder.scala | 28 ++------ .../codegen/GenerateOrdering.scala | 16 ++--- .../scala/org/apache/spark/sql/Column.scala | 64 ++++++++++++++++++- .../org/apache/spark/sql/functions.scala | 51 ++++++++++++++- ...dering.sql => order-by-nulls-ordering.sql} | 0 ...ql.out => order-by-nulls-ordering.sql.out} | 0 .../org/apache/spark/sql/DataFrameSuite.scala | 18 ++++++ 7 files changed, 144 insertions(+), 33 deletions(-) rename sql/core/src/test/resources/sql-tests/inputs/{orderby-nulls-ordering.sql => order-by-nulls-ordering.sql} (100%) rename sql/core/src/test/resources/sql-tests/results/{orderby-nulls-ordering.sql.out => order-by-nulls-ordering.sql.out} (100%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d015125baccaf..3bebd552ef51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -54,10 +54,7 @@ case object NullsLast extends NullOrdering{ * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder( - child: Expression, - direction: SortDirection, - nullOrdering: NullOrdering) +case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ @@ -94,17 +91,9 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { case BooleanType | DateType | TimestampType | _: IntegralType => - if (nullAsSmallest) { - Long.MinValue - } else { - Long.MaxValue - } + if (nullAsSmallest) Long.MinValue else Long.MaxValue case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => - if (nullAsSmallest) { - Long.MinValue - } else { - Long.MaxValue - } + if (nullAsSmallest) Long.MinValue else Long.MaxValue case _: DecimalType => if (nullAsSmallest) { DoublePrefixComparator.computePrefix(Double.NegativeInfinity) @@ -112,16 +101,13 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { DoublePrefixComparator.computePrefix(Double.NaN) } case _ => - if (nullAsSmallest) { - 0L - } else { - -1L - } + if (nullAsSmallest) 0L else -1L } - private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) || + private def nullAsSmallest: Boolean = { + (child.isAscending && child.nullOrdering == NullsFirst) || (!child.isAscending && child.nullOrdering == NullsLast) - + } override def eval(input: InternalRow): Any = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index e7df95e1142ca..f1c30ef6c7fb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -100,16 +100,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR // Nothing } else if ($isNullA) { return ${ - order.nullOrdering match { - case NullsFirst => "-1" - case 
NullsLast => "1" - }}; + order.nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + }}; } else if ($isNullB) { return ${ - order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" - }}; + order.nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + }}; } else { int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 844ca7a8e99ca..63da501f18cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1007,7 +1007,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** * Returns an ordering used in sorting. * {{{ - * // Scala: sort a DataFrame by age column in descending order. + * // Scala * df.sort(df("age").desc) * * // Java @@ -1020,7 +1020,37 @@ class Column(protected[sql] val expr: Expression) extends Logging { def desc: Column = withExpr { SortOrder(expr, Descending) } /** - * Returns an ordering used in sorting. + * Returns a descending ordering used in sorting, where null values appear before non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in descending order and null values appearing first. + * df.sort(df("age").desc_nulls_first) + * + * // Java + * df.sort(df.col("age").desc_nulls_first()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + + /** + * Returns a descending ordering used in sorting, where null values appear after non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in descending order and null values appearing last. + * df.sort(df("age").desc_nulls_last) + * + * // Java + * df.sort(df.col("age").desc_nulls_last()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + + /** + * Returns an ascending ordering used in sorting. * {{{ * // Scala: sort a DataFrame by age column in ascending order. * df.sort(df("age").asc) @@ -1034,6 +1064,36 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def asc: Column = withExpr { SortOrder(expr, Ascending) } + /** + * Returns an ascending ordering used in sorting, where null values appear before non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. + * df.sort(df("age").asc_nulls_last) + * + * // Java + * df.sort(df.col("age").asc_nulls_last()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + + /** + * Returns an ordering used in sorting, where null values appear after non-null values. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. + * df.sort(df("age").asc_nulls_last) + * + * // Java + * df.sort(df.col("age").asc_nulls_last()); + * }}} + * + * @group expr_ops + * @since 2.1.0 + */ + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + /** * Prints the expression to the console for debugging purpose. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 960c87f60e624..47bf41a2da813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -109,7 +109,6 @@ object functions { /** * Returns a sort expression based on ascending order of the column. * {{{ - * // Sort by dept in ascending order, and then age in descending order. * df.sort(asc("dept"), desc("age")) * }}} * @@ -118,10 +117,33 @@ object functions { */ def asc(columnName: String): Column = Column(columnName).asc + /** + * Returns a sort expression based on ascending order of the column, + * and null values return before non-null values. + * {{{ + * df.sort(asc_nulls_last("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first + + /** + * Returns a sort expression based on ascending order of the column, + * and null values appear after non-null values. + * {{{ + * df.sort(asc_nulls_last("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last + /** * Returns a sort expression based on the descending order of the column. * {{{ - * // Sort by dept in ascending order, and then age in descending order. * df.sort(asc("dept"), desc("age")) * }}} * @@ -130,6 +152,31 @@ object functions { */ def desc(columnName: String): Column = Column(columnName).desc + /** + * Returns a sort expression based on the descending order of the column, + * and null values appear before non-null values. + * {{{ + * df.sort(asc("dept"), desc_nulls_first("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first + + /** + * Returns a sort expression based on the descending order of the column, + * and null values appear after non-null values. 
+ * {{{ + * df.sort(asc("dept"), desc_nulls_last("age")) + * }}} + * + * @group sort_funcs + * @since 2.1.0 + */ + def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last + + ////////////////////////////////////////////////////////////////////////////////////////////// // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql similarity index 100% rename from sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql rename to sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql diff --git a/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out similarity index 100% rename from sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out rename to sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2c60a7dd9209b..16cc368208485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -326,6 +326,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(6)) } + test("sorting with null ordering") { + val data = Seq[java.lang.Integer](2, 1, null).toDF("key") + + checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil) + checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil) + checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil) + + checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil) + checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil) + checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil) + } + test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), From 59d87d24079bc633e63ce032f0a5ddd18a3b02cb Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 25 Sep 2016 22:57:31 -0700 Subject: [PATCH 120/306] [SPARK-17650] malformed url's throw exceptions before bricking Executors ## What changes were proposed in this pull request? When a malformed URL was sent to Executors through `sc.addJar` and `sc.addFile`, the executors become unusable, because they constantly throw `MalformedURLException`s and can never acknowledge that the file or jar is just bad input. This PR tries to fix that problem by making sure MalformedURLs can never be submitted through `sc.addJar` and `sc.addFile`. Another solution would be to blacklist bad files and jars on Executors. 
Maybe fail the first time, and then ignore the second time (but print a warning message). ## How was this patch tested? Unit tests in SparkContextSuite Author: Burak Yavuz Closes #15224 from brkyvz/SPARK-17650. --- .../scala/org/apache/spark/SparkContext.scala | 16 ++++++++------ .../scala/org/apache/spark/util/Utils.scala | 20 +++++++++++++++++ .../org/apache/spark/SparkContextSuite.scala | 22 +++++++++++++++++++ 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f58037e100989..4694790c72cd8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.URI +import java.net.{MalformedURLException, URI} import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -36,18 +36,15 @@ import com.google.common.collect.MapMaker import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, - FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, - TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, - WholeTextFileInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec @@ -1452,6 +1449,9 @@ class SparkContext(config: SparkConf) extends Logging { throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + "turned on.") } + } else { + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) } val key = if (!isLocal && scheme == "file") { @@ -1711,6 +1711,8 @@ class SparkContext(config: SparkConf) extends Logging { key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 09896c4e2f502..e09666c6103c6 100644 --- 
a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -697,6 +697,26 @@ private[spark] object Utils extends Logging { } } + /** + * Validate that a given URI is actually a valid URL as well. + * @param uri The URI to validate + */ + @throws[MalformedURLException]("when the URI is an invalid URL") + def validateURL(uri: URI): Unit = { + Option(uri.getScheme).getOrElse("file") match { + case "http" | "https" | "ftp" => + try { + uri.toURL + } catch { + case e: MalformedURLException => + val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.") + ex.initCause(e) + throw ex + } + case _ => // will not be turned into a URL anyway + } + } + /** * Get the path of a temporary directory. Spark's local directories can be configured through * multiple settings, which are used with the following precedence: diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index f8d143dc610cb..c451c596b069a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.MalformedURLException import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -173,6 +174,27 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-17650: malformed url's throw exceptions before bricking Executors") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + Seq("http", "https", "ftp").foreach { scheme => + val badURL = s"$scheme://user:pwd/path" + val e1 = intercept[MalformedURLException] { + sc.addFile(badURL) + } + assert(e1.getMessage.contains(badURL)) + val e2 = intercept[MalformedURLException] { + sc.addJar(badURL) + } + assert(e2.getMessage.contains(badURL)) + assert(sc.addedFiles.isEmpty) + assert(sc.addedJars.isEmpty) + } + } finally { + sc.stop() + } + } + test("addFile recursive works") { val pluto = Utils.createTempDir() val neptune = Utils.createTempDir(pluto.getAbsolutePath) From ac65139be96dbf87402b9a85729a93afd3c6ff17 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 26 Sep 2016 09:45:33 +0100 Subject: [PATCH 121/306] [SPARK-17017][FOLLOW-UP][ML] Refactor of ChiSqSelector and add ML Python API. ## What changes were proposed in this pull request? #14597 modified ```ChiSqSelector``` to support ```fpr``` type selector, however, it left some issue need to be addressed: * We should allow users to set selector type explicitly rather than switching them by using different setting function, since the setting order will involves some unexpected issue. For example, if users both set ```numTopFeatures``` and ```percentile```, it will train ```kbest``` or ```percentile``` model based on the order of setting (the latter setting one will be trained). This make users confused, and we should allow users to set selector type explicitly. We handle similar issues at other place of ML code base such as ```GeneralizedLinearRegression``` and ```LogisticRegression```. * Meanwhile, if there are more than one parameter except ```alpha``` can be set for ```fpr``` model, we can not handle it elegantly in the existing framework. And similar issues for ```kbest``` and ```percentile``` model. Setting selector type explicitly can solve this issue also. 
* If setting selector type explicitly by users is allowed, we should handle param interaction such as if users set ```selectorType = percentile``` and ```alpha = 0.1```, we should notify users the parameter ```alpha``` will take no effect. We should handle complex parameter interaction checks at ```transformSchema```. (FYI #11620) * We should use lower case of the selector type names to follow MLlib convention. * Add ML Python API. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #15214 from yanboliang/spark-17017. --- .../spark/ml/feature/ChiSqSelector.scala | 86 ++++++++++--------- .../mllib/api/python/PythonMLLibAPI.scala | 38 +++----- .../spark/mllib/feature/ChiSqSelector.scala | 51 ++++++----- .../spark/ml/feature/ChiSqSelectorSuite.scala | 27 ++++-- .../mllib/feature/ChiSqSelectorSuite.scala | 2 +- python/pyspark/ml/feature.py | 71 +++++++++++++-- python/pyspark/mllib/feature.py | 59 ++++++------- 7 files changed, 206 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 0c6a37bab0aad..9c131a41850cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.feature.ChiSqSelectorType +import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector} import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.rdd.RDD @@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params /** * Number of features that selector will select (ordered by statistic value descending). If the * number of features is less than numTopFeatures, then this will select all features. + * Only applicable when selectorType = "kbest". * The default value of numTopFeatures is 50. + * * @group param */ final val numTopFeatures = new IntParam(this, "numTopFeatures", @@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getNumTopFeatures: Int = $(numTopFeatures) + /** + * Percentile of features that selector will select, ordered by statistics value descending. + * Only applicable when selectorType = "percentile". + * Default value is 0.1. + */ final val percentile = new DoubleParam(this, "percentile", "Percentile of features that selector will select, ordered by statistics value descending.", ParamValidators.inRange(0, 1)) @@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getPercentile: Double = $(percentile) - final val alpha = new DoubleParam(this, "alpha", - "The highest p-value for features to be kept.", + /** + * The highest p-value for features to be kept. + * Only applicable when selectorType = "fpr". + * Default value is 0.05. + */ + final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) setDefault(alpha -> 0.05) @@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params def getAlpha: Double = $(alpha) /** - * The ChiSqSelector supports KBest, Percentile, FPR selection, - * which is the same as ChiSqSelectorType defined in MLLIB. 
- * when call setNumTopFeatures, the selectorType is set to KBest - * when call setPercentile, the selectorType is set to Percentile - * when call setAlpha, the selectorType is set to FPR + * The selector type of the ChisqSelector. + * Supported options: "kbest" (default), "percentile" and "fpr". */ final val selectorType = new Param[String](this, "selectorType", - "ChiSqSelector Type: KBest, Percentile, FPR") - setDefault(selectorType -> ChiSqSelectorType.KBest.toString) + "The selector type of the ChisqSelector. " + + "Supported options: kbest (default), percentile and fpr.", + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) + setDefault(selectorType -> OldChiSqSelector.KBest) /** @group getParam */ - def getChiSqSelectorType: String = $(selectorType) + def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. - * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. 
*/ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) + /** @group setParam */ @Since("1.6.0") - def setNumTopFeatures(value: Int): this.type = { - set(selectorType, ChiSqSelectorType.KBest.toString) - set(numTopFeatures, value) - } + def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + /** @group setParam */ @Since("2.1.0") - def setPercentile(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.Percentile.toString) - set(percentile, value) - } + def setPercentile(value: Double): this.type = set(percentile, value) + /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = { - set(selectorType, ChiSqSelectorType.FPR.toString) - set(alpha, value) - } + def setAlpha(value: Double): this.type = set(alpha, value) /** @group setParam */ @Since("1.6.0") @@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str case Row(label: Double, features: Vector) => OldLabeledPoint(label, OldVectors.fromML(features)) } - var selector = new feature.ChiSqSelector() - ChiSqSelectorType.withName($(selectorType)) match { - case ChiSqSelectorType.KBest => - selector.setNumTopFeatures($(numTopFeatures)) - case ChiSqSelectorType.Percentile => - selector.setPercentile($(percentile)) - case ChiSqSelectorType.FPR => - selector.setAlpha($(alpha)) - case errorType => - throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") - } + val selector = new feature.ChiSqSelector() + .setSelectorType($(selectorType)) + .setNumTopFeatures($(numTopFeatures)) + .setPercentile($(percentile)) + .setAlpha($(alpha)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) + otherPairs.foreach { case (_, paramName: String) => + if (isSet(getParam(paramName))) { + logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") + } + } SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 5cffbf0892888..904000f50d0a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable { } /** - * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a + * Java stub for ChiSqSelector.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. 
*/ - def fitChiSqSelectorKBest(numTopFeatures: Int, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. - */ - def fitChiSqSelectorPercentile(percentile: Double, - data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setPercentile(percentile).fit(data.rdd) - } - - /** - * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a - * handle to the Java object instead of the content of the Java object. - * Extra care needs to be taken in the Python code to ensure it gets freed on - * exit; see the Py4J documentation. - */ - def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { - new ChiSqSelector().setAlpha(alpha).fit(data.rdd) + def fitChiSqSelector( + selectorType: String, + numTopFeatures: Int, + percentile: Double, + alpha: Double, + data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { + new ChiSqSelector() + .setSelectorType(selectorType) + .setNumTopFeatures(numTopFeatures) + .setPercentile(percentile) + .setAlpha(alpha) + .fit(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f68a017184b21..0f7c6e8bc04bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} -@Since("2.1.0") -private[spark] object ChiSqSelectorType extends Enumeration { - type SelectorType = Value - val KBest, Percentile, FPR = Value -} - /** * Chi Squared selector model. * @@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - * `KBest` chooses the `k` top features according to a chi-squared test. - * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `FPR` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `KBest`, the default number of top features is 50. - * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. + * `kbest` chooses the `k` top features according to a chi-squared test. + * `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * `fpr` chooses all features whose false positive rate meets some threshold. + * By default, the selection method is `kbest`, the default number of top features is 50. 
*/ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 var alpha: Double = 0.05 - var selectorType = ChiSqSelectorType.KBest + var selectorType = ChiSqSelector.KBest /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = { numTopFeatures = value - selectorType = ChiSqSelectorType.KBest this } @@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setPercentile(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]") percentile = value - selectorType = ChiSqSelectorType.Percentile this } @@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def setAlpha(value: Double): this.type = { require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") alpha = value - selectorType = ChiSqSelectorType.FPR this } @Since("2.1.0") - def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { + def setSelectorType(value: String): this.type = { + require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + s"ChiSqSelector Type: $value was not supported.") selectorType = value this } @@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val chiSqTestResult = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelectorType.KBest => chiSqTestResult + case ChiSqSelector.KBest => chiSqTestResult .take(numTopFeatures) - case ChiSqSelectorType.Percentile => chiSqTestResult + case ChiSqSelector.Percentile => chiSqTestResult .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelectorType.FPR => chiSqTestResult + case ChiSqSelector.FPR => chiSqTestResult .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") @@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } +@Since("2.1.0") +object ChiSqSelector { + + /** String name for `kbest` selector type. */ + private[spark] val KBest: String = "kbest" + + /** String name for `percentile` selector type. */ + private[spark] val Percentile: String = "percentile" + + /** String name for `fpr` selector type. */ + private[spark] val FPR: String = "fpr" + + /** Set of selector type and param pairs that ChiSqSelector supports. */ + private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", + Percentile -> "percentile", FPR -> "alpha") + + /** Set of selector types that ChiSqSelector supports. 
*/ + private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e0293dbc4b0b2..6b56e4200250c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext .toDF("label", "data", "preFilteredData") val selector = new ChiSqSelector() + .setSelectorType("kbest") .setNumTopFeatures(1) .setFeaturesCol("data") .setLabelCol("label") @@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(vec1 ~== vec2 absTol 1e-1) } - selector.setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + + val preFilteredData2 = Seq( + Vectors.dense(8.0, 7.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(0.0, 9.0), + Vectors.dense(8.0, 9.0) + ) + val df2 = sc.parallelize(data.zip(preFilteredData2)) + .map(x => (x._1.label, x._1.features, x._2)) + .toDF("label", "data", "preFilteredData") + + selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } } test("ChiSqSelector read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index e181a544f7159..ec23a4aa7364d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData) + val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) }.collect().toSet diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c45434f1a57ca..12a13849dc9bc 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja .. versionadded:: 2.0.0 """ + selectorType = Param(Params._dummy(), "selectorType", + "The selector type of the ChisqSelector. " + + "Supported options: kbest (default), percentile and fpr.", + typeConverter=TypeConverters.toString) + numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", "Number of features that selector will select, ordered by statistics value " + "descending. 
If the number of features is < numTopFeatures, then this will select " + "all features.", typeConverter=TypeConverters.toInt) + percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " + + "will select, ordered by statistics value descending.", + typeConverter=TypeConverters.toFloat) + + alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.", + typeConverter=TypeConverters.toFloat) + @keyword_only - def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05): """ - __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05) """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self._setDefault(numTopFeatures=50) + self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="labels"): + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05): """ - setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ - labelCol="labels") + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ + labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05) Sets params for this ChiSqSelector. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.1.0") + def setSelectorType(self, value): + """ + Sets the value of :py:attr:`selectorType`. + """ + return self._set(selectorType=value) + + @since("2.1.0") + def getSelectorType(self): + """ + Gets the value of selectorType or its default value. + """ + return self.getOrDefault(self.selectorType) + @since("2.0.0") def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. + Only applicable when selectorType = "kbest". """ return self._set(numTopFeatures=value) @@ -2629,6 +2658,36 @@ def getNumTopFeatures(self): """ return self.getOrDefault(self.numTopFeatures) + @since("2.1.0") + def setPercentile(self, value): + """ + Sets the value of :py:attr:`percentile`. + Only applicable when selectorType = "percentile". + """ + return self._set(percentile=value) + + @since("2.1.0") + def getPercentile(self): + """ + Gets the value of percentile or its default value. + """ + return self.getOrDefault(self.percentile) + + @since("2.1.0") + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + Only applicable when selectorType = "fpr". + """ + return self._set(alpha=value) + + @since("2.1.0") + def getAlpha(self): + """ + Gets the value of alpha or its default value. 
+ """ + return self.getOrDefault(self.alpha) + def _create_model(self, java_model): return ChiSqSelectorModel(java_model) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 077c11370eb3f..4aea81840a162 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -271,22 +271,14 @@ def transform(self, vector): return JavaVectorTransformer.transform(self, vector) -class ChiSqSelectorType: - """ - This class defines the selector types of Chi Square Selector. - """ - KBest, Percentile, FPR = range(3) - - class ChiSqSelector(object): """ Creates a ChiSquared feature selector. The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - `KBest` chooses the `k` top features according to a chi-squared test. - `Percentile` is similar but chooses a fraction of all features instead of a fixed number. - `FPR` chooses all features whose false positive rate meets some threshold. - By default, the selection method is `KBest`, the default number of top features is 50. - User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. + `kbest` chooses the `k` top features according to a chi-squared test. + `percentile` is similar but chooses a fraction of all features instead of a fixed number. + `fpr` chooses all features whose false positive rate meets some threshold. + By default, the selection method is `kbest`, the default number of top features is 50. >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), @@ -299,7 +291,8 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) - >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( + ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) @@ -310,41 +303,52 @@ class ChiSqSelector(object): ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]), ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0]) ... ] - >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data)) + >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data)) >>> model.transform(DenseVector([1.0,2.0,3.0,4.0])) DenseVector([4.0]) .. versionadded:: 1.4.0 """ - def __init__(self, numTopFeatures=50): + def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05): self.numTopFeatures = numTopFeatures - self.selectorType = ChiSqSelectorType.KBest + self.selectorType = selectorType + self.percentile = percentile + self.alpha = alpha @since('2.1.0') def setNumTopFeatures(self, numTopFeatures): """ - set numTopFeature for feature selection by number of top features + set numTopFeature for feature selection by number of top features. + Only applicable when selectorType = "kbest". """ self.numTopFeatures = int(numTopFeatures) - self.selectorType = ChiSqSelectorType.KBest return self @since('2.1.0') def setPercentile(self, percentile): """ - set percentile [0.0, 1.0] for feature selection by percentile + set percentile [0.0, 1.0] for feature selection by percentile. + Only applicable when selectorType = "percentile". 
""" self.percentile = float(percentile) - self.selectorType = ChiSqSelectorType.Percentile return self @since('2.1.0') def setAlpha(self, alpha): """ - set alpha [0.0, 1.0] for feature selection by FPR + set alpha [0.0, 1.0] for feature selection by FPR. + Only applicable when selectorType = "fpr". """ self.alpha = float(alpha) - self.selectorType = ChiSqSelectorType.FPR + return self + + @since('2.1.0') + def setSelectorType(self, selectorType): + """ + set the selector type of the ChisqSelector. + Supported options: "kbest" (default), "percentile" and "fpr". + """ + self.selectorType = str(selectorType) return self @since('1.4.0') @@ -357,15 +361,8 @@ def fit(self, data): treated as categorical for each distinct value. Apply feature discretizer before using this function. """ - if self.selectorType == ChiSqSelectorType.KBest: - jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data) - elif self.selectorType == ChiSqSelectorType.Percentile: - jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data) - elif self.selectorType == ChiSqSelectorType.FPR: - jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data) - else: - raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and" - " FPR(2), the current value is: %s" % self.selectorType) + jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, + self.percentile, self.alpha, data) return ChiSqSelectorModel(jmodel) From 50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 Mon Sep 17 00:00:00 2001 From: Justin Pihony Date: Mon, 26 Sep 2016 09:54:22 +0100 Subject: [PATCH 122/306] [SPARK-14525][SQL] Make DataFrameWrite.save work for jdbc ## What changes were proposed in this pull request? This change modifies the implementation of DataFrameWriter.save such that it works with jdbc, and the call to jdbc merely delegates to save. ## How was this patch tested? This was tested via unit tests in the JDBCWriteSuite, of which I added one new test to cover this scenario. ## Additional details rxin This seems to have been most recently touched by you and was also commented on in the JIRA. This contribution is my original work and I license the work to the project under the project's open source license. Author: Justin Pihony Author: Justin Pihony Closes #12601 from JustinPihony/jdbc_reconciliation. 
--- docs/sql-programming-guide.md | 6 +- .../sql/JavaSQLDataSourceExample.java | 21 ++++ examples/src/main/python/sql/datasource.py | 19 ++++ examples/src/main/r/RSparkSQLExample.R | 4 + .../examples/sql/SQLDataSourceExample.scala | 22 +++++ .../apache/spark/sql/DataFrameWriter.scala | 59 +----------- .../datasources/jdbc/JDBCOptions.scala | 11 ++- .../jdbc/JdbcRelationProvider.scala | 95 ++++++++++++++++--- .../spark/sql/jdbc/JDBCWriteSuite.scala | 82 ++++++++++++++++ 9 files changed, 246 insertions(+), 73 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4ac5fae566abe..71bdd19c16dbb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1100,9 +1100,13 @@ CREATE TEMPORARY VIEW jdbcTable USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:postgresql:dbserver", - dbtable "schema.tablename" + dbtable "schema.tablename", + user 'username', + password 'password' ) +INSERT INTO TABLE jdbcTable +SELECT * FROM resultTable {% endhighlight %} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index f9087e059385e..1860594e8e547 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; // $example off:schema_merging$ +import java.util.Properties; // $example on:basic_parquet_example$ import org.apache.spark.api.java.JavaRDD; @@ -235,6 +236,8 @@ private static void runJsonDatasetExample(SparkSession spark) { private static void runJdbcDatasetExample(SparkSession spark) { // $example on:jdbc_dataset$ + // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + // Loading data from a JDBC source Dataset jdbcDF = spark.read() .format("jdbc") .option("url", "jdbc:postgresql:dbserver") @@ -242,6 +245,24 @@ private static void runJdbcDatasetExample(SparkSession spark) { .option("user", "username") .option("password", "password") .load(); + + Properties connectionProperties = new Properties(); + connectionProperties.put("user", "username"); + connectionProperties.put("password", "password"); + Dataset jdbcDF2 = spark.read() + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); + + // Saving data to a JDBC source + jdbcDF.write() + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .save(); + + jdbcDF2.write() + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b36c901d2b403..e9aa9d9ac2583 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -143,6 +143,8 @@ def json_dataset_example(spark): def jdbc_dataset_example(spark): # $example on:jdbc_dataset$ + # Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + # Loading data from a JDBC source jdbcDF = spark.read \ .format("jdbc") \ .option("url", "jdbc:postgresql:dbserver") \ @@ -150,6 +152,23 @@ def jdbc_dataset_example(spark): .option("user", "username") \ .option("password", "password") \ .load() + + jdbcDF2 = spark.read \ + 
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) + + # Saving data to a JDBC source + jdbcDF.write \ + .format("jdbc") \ + .option("url", "jdbc:postgresql:dbserver") \ + .option("dbtable", "schema.tablename") \ + .option("user", "username") \ + .option("password", "password") \ + .save() + + jdbcDF2.write \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) # $example off:jdbc_dataset$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 4e0267a03851b..373a36dba14f0 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -204,7 +204,11 @@ results <- collect(sql("FROM src SELECT key, value")) # $example on:jdbc_dataset$ +# Loading data from a JDBC source df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") + +# Saving data to a JDBC source +write.jdbc(df, "jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") # $example off:jdbc_dataset$ # Stop the SparkSession now diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index dc3915a4882b0..66f7cb1b53f48 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.examples.sql +import java.util.Properties + import org.apache.spark.sql.SparkSession object SQLDataSourceExample { @@ -148,6 +150,8 @@ object SQLDataSourceExample { private def runJdbcDatasetExample(spark: SparkSession): Unit = { // $example on:jdbc_dataset$ + // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods + // Loading data from a JDBC source val jdbcDF = spark.read .format("jdbc") .option("url", "jdbc:postgresql:dbserver") @@ -155,6 +159,24 @@ object SQLDataSourceExample { .option("user", "username") .option("password", "password") .load() + + val connectionProperties = new Properties() + connectionProperties.put("user", "username") + connectionProperties.put("password", "password") + val jdbcDF2 = spark.read + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) + + // Saving data to a JDBC source + jdbcDF.write + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .save() + + jdbcDF2.write + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) // $example off:jdbc_dataset$ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 64d3422cb4b54..7374a8e045035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") - - // to add required options like URL and dbtable - val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table) - val jdbcOptions = new JDBCOptions(params) - 
val jdbcUrl = jdbcOptions.url - val jdbcTable = jdbcOptions.table - - val props = new Properties() - extraOptions.foreach { case (key, value) => - props.put(key, value) - } // connectionProperties should override settings in extraOptions - props.putAll(connectionProperties) - val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)() - - try { - var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable) - - if (mode == SaveMode.Ignore && tableExists) { - return - } - - if (mode == SaveMode.ErrorIfExists && tableExists) { - sys.error(s"Table $jdbcTable already exists.") - } - - if (mode == SaveMode.Overwrite && tableExists) { - if (jdbcOptions.isTruncate && - JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) { - JdbcUtils.truncateTable(conn, jdbcTable) - } else { - JdbcUtils.dropTable(conn, jdbcTable) - tableExists = false - } - } - - // Create the table if the table didn't exist. - if (!tableExists) { - val schema = JdbcUtils.schemaString(df, jdbcUrl) - // To allow certain options to append when create a new table, which can be - // table_options or partition_options. - // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" - val createtblOptions = jdbcOptions.createTableOptions - val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions" - val statement = conn.createStatement - try { - statement.executeUpdate(sql) - } finally { - statement.close() - } - } - } finally { - conn.close() - } - - JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props) + this.extraOptions = this.extraOptions ++ (connectionProperties.asScala) + // explicit url and dbtable should override all + this.extraOptions += ("url" -> url, "dbtable" -> table) + format("jdbc").save() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 1db090eaf9c9e..bcf65e53afa73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,10 +27,12 @@ class JDBCOptions( // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ + require(parameters.isDefinedAt("url"), "Option 'url' is required.") + require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.") // a JDBC URL - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val url = parameters("url") // name of table - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val table = parameters("dbtable") // ------------------------------------------------------------ // Optional parameter list @@ -44,6 +46,11 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.getOrElse("numPartitions", null) + require(partitionColumn == null || + (lowerBound != null && upperBound != null && numPartitions != null), + "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + + " and 'numPartitions' are required.") + // ------------------------------------------------------------ // The options for DataFrameWriter // ------------------------------------------------------------ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 106ed1d440102..ae04af2479c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} +import scala.collection.JavaConverters.mapAsJavaMapConverter -class JdbcRelationProvider extends RelationProvider with DataSourceRegister { +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} + +class JdbcRelationProvider extends CreatableRelationProvider + with RelationProvider with DataSourceRegister { override def shortName(): String = "jdbc" - /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) - if (jdbcOptions.partitionColumn != null - && (jdbcOptions.lowerBound == null - || jdbcOptions.upperBound == null - || jdbcOptions.numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } + val partitionColumn = jdbcOptions.partitionColumn + val lowerBound = jdbcOptions.lowerBound + val upperBound = jdbcOptions.upperBound + val numPartitions = jdbcOptions.numPartitions - val partitionInfo = if (jdbcOptions.partitionColumn == null) { + val partitionInfo = if (partitionColumn == null) { null } else { JDBCPartitioningInfo( - jdbcOptions.partitionColumn, - jdbcOptions.lowerBound.toLong, - jdbcOptions.upperBound.toLong, - jdbcOptions.numPartitions.toInt) + partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } + + /* + * The following structure applies to this code: + * | tableExists | !tableExists + *------------------------------------------------------------------------------------ + * Ignore | BaseRelation | CreateTable, saveTable, BaseRelation + * ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation + * Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation + * | saveTable, BaseRelation | + * Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation + * + * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate + */ + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val jdbcOptions = new JDBCOptions(parameters) + val url = jdbcOptions.url + val table = jdbcOptions.table + + val props = new Properties() + props.putAll(parameters.asJava) + val conn = JdbcUtils.createConnectionFactory(url, props)() + + try { + val tableExists = JdbcUtils.tableExists(conn, url, table) + + val (doCreate, doSave) = (mode, tableExists) match { + case (SaveMode.Ignore, true) => (false, false) + case 
(SaveMode.ErrorIfExists, true) => throw new AnalysisException( + s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.") + case (SaveMode.Overwrite, true) => + if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) { + JdbcUtils.truncateTable(conn, table) + (false, true) + } else { + JdbcUtils.dropTable(conn, table) + (true, true) + } + case (SaveMode.Append, true) => (false, true) + case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," + + " for handling existing tables.") + case (_, false) => (true, true) + } + + if (doCreate) { + val schema = JdbcUtils.schemaString(data, url) + // To allow certain options to append when create a new table, which can be + // table_options or partition_options. + // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" + val createtblOptions = jdbcOptions.createTableOptions + val sql = s"CREATE TABLE $table ($schema) $createtblOptions" + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } + } + if (doSave) JdbcUtils.saveTable(data, url, table, props) + } finally { + conn.close() + } + + createRelation(sqlContext, parameters) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ff3309874f2e1..506971362f867 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties +import scala.collection.JavaConverters.propertiesAsScalaMapConverter + import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException @@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } + + test("save works for format(\"jdbc\") if url and dbtable are set") { + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + df.write.format("jdbc") + .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST")) + .save() + + assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length) + } + + test("save API with SaveMode.Overwrite") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + df2.write.mode(SaveMode.Overwrite).format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length) + } + + test("save errors if url is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'url' is required")) + } + + test("save 
errors if dbtable is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'dbtable' is required")) + } + + test("save errors if wrong user/password combination") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[org.h2.jdbc.JdbcSQLException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .save() + }.getMessage + assert(e.contains("Wrong user name or password")) + } + + test("save errors if partitionColumn and numPartitions and bounds not set") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[java.lang.IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option("partitionColumn", "foo") + .save() + }.getMessage + assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + + " and 'numPartitions' are required.")) + } } From f234b7cd795dd9baa3feff541c211b4daf39ccc6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 26 Sep 2016 04:19:39 -0700 Subject: [PATCH 123/306] [SPARK-16356][ML] Add testImplicits for ML unit tests and promote toDF() ## What changes were proposed in this pull request? This was suggested in https://github.com/apache/spark/commit/101663f1ae222a919fc40510aa4f2bad22d1be6f#commitcomment-17114968. This PR adds `testImplicits` to `MLlibTestSparkContext` so that some implicits such as `toDF()` can be used across ml tests. This PR also changes all the usages of `spark.createDataFrame( ... )` to `toDF()` where applicable in ml tests in Scala. ## How was this patch tested? Existing tests should work. Author: hyukjinkwon Closes #14035 from HyukjinKwon/minor-ml-test.
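To make the testing pattern concrete, here is a small Scala sketch of a shared test base that exposes the session's implicits as `testImplicits`, so suites can `import testImplicits._` and build DataFrames with `toDF()` instead of `spark.createDataFrame(...)`. The trait and object names are invented for illustration; the actual `MLlibTestSparkContext` wiring in this patch may differ in detail.

```scala
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}

// Hypothetical stand-in for a shared ML test base such as MLlibTestSparkContext.
trait SparkTestBase {
  @transient lazy val spark: SparkSession =
    SparkSession.builder().master("local[2]").appName("ml-test").getOrCreate()

  // Expose the session's implicit conversions; suites do `import testImplicits._`
  // and can then call toDF() on local Seqs instead of spark.createDataFrame(...).
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = spark.sqlContext
  }
}

object TestImplicitsDemo extends SparkTestBase {
  def main(args: Array[String]): Unit = {
    import testImplicits._
    val df = Seq((1.0, 2.0), (0.0, 3.0)).toDF("label", "value")
    df.show()
    spark.stop()
  }
}
```

Binding the implicits to the suite's own SparkSession keeps the conversions consistent with whatever session the test harness created, which is what lets the suites below drop their explicit createDataFrame calls.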
--- .../org/apache/spark/ml/PipelineSuite.scala | 13 +- .../ml/classification/ClassifierSuite.scala | 16 +-- .../DecisionTreeClassifierSuite.scala | 3 +- .../classification/GBTClassifierSuite.scala | 6 +- .../LogisticRegressionSuite.scala | 43 +++--- .../MultilayerPerceptronClassifierSuite.scala | 26 ++-- .../ml/classification/NaiveBayesSuite.scala | 20 +-- .../ml/classification/OneVsRestSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 6 +- .../BinaryClassificationEvaluatorSuite.scala | 14 +- .../evaluation/RegressionEvaluatorSuite.scala | 8 +- .../spark/ml/feature/BinarizerSuite.scala | 16 +-- .../spark/ml/feature/BucketizerSuite.scala | 15 +-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 3 +- .../ml/feature/CountVectorizerSuite.scala | 30 +++-- .../apache/spark/ml/feature/DCTSuite.scala | 10 +- .../spark/ml/feature/HashingTFSuite.scala | 10 +- .../apache/spark/ml/feature/IDFSuite.scala | 6 +- .../spark/ml/feature/InteractionSuite.scala | 53 ++++---- .../spark/ml/feature/MaxAbsScalerSuite.scala | 5 +- .../spark/ml/feature/MinMaxScalerSuite.scala | 13 +- .../apache/spark/ml/feature/NGramSuite.scala | 35 +++-- .../spark/ml/feature/NormalizerSuite.scala | 4 +- .../spark/ml/feature/OneHotEncoderSuite.scala | 10 +- .../apache/spark/ml/feature/PCASuite.scala | 4 +- .../ml/feature/PolynomialExpansionSuite.scala | 11 +- .../spark/ml/feature/RFormulaSuite.scala | 126 ++++++++---------- .../ml/feature/SQLTransformerSuite.scala | 8 +- .../ml/feature/StandardScalerSuite.scala | 12 +- .../ml/feature/StopWordsRemoverSuite.scala | 29 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 32 ++--- .../spark/ml/feature/TokenizerSuite.scala | 17 +-- .../ml/feature/VectorAssemblerSuite.scala | 10 +- .../spark/ml/feature/VectorIndexerSuite.scala | 15 ++- .../AFTSurvivalRegressionSuite.scala | 26 ++-- .../ml/regression/GBTRegressorSuite.scala | 7 +- .../GeneralizedLinearRegressionSuite.scala | 115 ++++++++-------- .../regression/IsotonicRegressionSuite.scala | 14 +- .../ml/regression/LinearRegressionSuite.scala | 62 ++++----- .../tree/impl/GradientBoostedTreesSuite.scala | 6 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 12 +- .../ml/tuning/TrainValidationSplitSuite.scala | 13 +- .../spark/mllib/util/MLUtilsSuite.scala | 18 +-- .../mllib/util/MLlibTestSparkContext.scala | 13 +- 45 files changed, 462 insertions(+), 460 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 3b490cdf56018..6413ca1f8b19e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -36,6 +36,8 @@ import org.apache.spark.sql.types.StructType class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + abstract class MyModel extends Model[MyModel] test("pipeline") { @@ -183,12 +185,11 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("pipeline validateParams") { - val df = spark.createDataFrame( - Seq( - (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), - (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), - (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + val df = Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "features", "label") 
intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 4db5f03fb00b4..de712079329da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -29,12 +29,13 @@ import org.apache.spark.sql.{DataFrame, Dataset} class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { - test("extractLabeledPoints") { - def getTestData(labels: Seq[Double]): DataFrame = { - val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) - } + import testImplicits._ + + private def getTestData(labels: Seq[Double]): DataFrame = { + labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() + } + test("extractLabeledPoints") { val c = new MockClassifier // Valid dataset val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) @@ -70,11 +71,6 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } test("getNumClasses") { - def getTestData(labels: Seq[Double]): DataFrame = { - val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) - } - val c = new MockClassifier // Valid dataset val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 089d30abb5ef9..c711e7fa9dc67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -34,6 +34,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs + import testImplicits._ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ @@ -345,7 +346,7 @@ class DecisionTreeClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val dt = new DecisionTreeClassifier().setMaxDepth(1) dt.fit(df) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 8d588ccfd3545..3492709677d4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ import GBTClassifierSuite.compareAPIs // Combinations for estimators, learning rates and subsamplingRate @@ -134,15 +135,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext */ test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val gbt = new 
GBTClassifier().setMaxDepth(1).setMaxIter(1) gbt.fit(df) } test("extractLabeledPoints with bad data") { def getTestData(labels: Seq[Double]): DataFrame = { - val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - spark.createDataFrame(data) + labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF() } val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2623759f24d91..8451e60144981 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var smallBinaryDataset: Dataset[_] = _ @transient var smallMultinomialDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ @@ -46,8 +48,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - smallBinaryDataset = - spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) + smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42).toDF() smallMultinomialDataset = { val nPoints = 100 @@ -61,7 +62,7 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput( coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - val df = spark.createDataFrame(sc.parallelize(testData, 4)) + val df = sc.parallelize(testData, 4).toDF() df.cache() df } @@ -76,7 +77,7 @@ class LogisticRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - spark.createDataFrame(sc.parallelize(testData, 4)) + sc.parallelize(testData, 4).toDF() } multinomialDataset = { @@ -91,7 +92,7 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput( coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - val df = spark.createDataFrame(sc.parallelize(testData, 4)) + val df = sc.parallelize(testData, 4).toDF() df.cache() df } @@ -430,10 +431,10 @@ class LogisticRegressionSuite val model = new LogisticRegressionModel("mLogReg", Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)), Vectors.dense(0.0, 0.0, 0.0), 3, true) - val overFlowData = spark.createDataFrame(Seq( + val overFlowData = Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) - )) + ).toDF() val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() // probabilities are correct when margins have to be adjusted @@ -1795,9 +1796,9 @@ class LogisticRegressionSuite val numPoints = 40 val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, numClasses, numPoints) - val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i => + val testData = Array.tabulate[LabeledPoint](numClasses) { i => LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }) + }.toSeq.toDF() val lr = new LogisticRegression().setFamily("binomial").setWeightCol("weight") val model = lr.fit(outlierData) val results = model.transform(testData).select("label", "prediction").collect() @@ -1819,9 +1820,9 @@ class LogisticRegressionSuite val numPoints = 40 
val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, numClasses, numPoints) - val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i => + val testData = Array.tabulate[LabeledPoint](numClasses) { i => LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }) + }.toSeq.toDF() val mlr = new LogisticRegression().setFamily("multinomial").setWeightCol("weight") val model = mlr.fit(outlierData) val results = model.transform(testData).select("label", "prediction").collect() @@ -1945,11 +1946,10 @@ class LogisticRegressionSuite } test("multiclass logistic regression with all labels the same") { - val constantData = spark.createDataFrame(Seq( + val constantData = Seq( LabeledPoint(4.0, Vectors.dense(0.0)), LabeledPoint(4.0, Vectors.dense(1.0)), - LabeledPoint(4.0, Vectors.dense(2.0))) - ) + LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) val results = model.transform(constantData) @@ -1961,11 +1961,10 @@ class LogisticRegressionSuite } // force the model to be trained with only one class - val constantZeroData = spark.createDataFrame(Seq( + val constantZeroData = Seq( LabeledPoint(0.0, Vectors.dense(0.0)), LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0))) - ) + LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) val resultsZero = modelZeroLabel.transform(constantZeroData) resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { @@ -1990,20 +1989,18 @@ class LogisticRegressionSuite } test("compressed storage") { - val moreClassesThanFeatures = spark.createDataFrame(Seq( + val moreClassesThanFeatures = Seq( LabeledPoint(4.0, Vectors.dense(0.0, 0.0, 0.0)), LabeledPoint(4.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))) - ) + LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(moreClassesThanFeatures) assert(model.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(model.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 4) - val moreFeaturesThanClasses = spark.createDataFrame(Seq( + val moreFeaturesThanClasses = Seq( LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))) - ) + LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() val model2 = mlr.fit(moreFeaturesThanClasses) assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(model2.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 3) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index e809dd4092afa..c08cb695806d0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -33,16 +33,18 @@ import org.apache.spark.sql.{Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = spark.createDataFrame(Seq( - 
(Vectors.dense(0.0, 0.0), 0.0), - (Vectors.dense(0.0, 1.0), 1.0), - (Vectors.dense(1.0, 0.0), 1.0), - (Vectors.dense(1.0, 1.0), 0.0)) + dataset = Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0) ).toDF("features", "label") } @@ -80,11 +82,11 @@ class MultilayerPerceptronClassifierSuite } test("Test setWeights by training restart") { - val dataFrame = spark.createDataFrame(Seq( + val dataFrame = Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), - (Vectors.dense(1.0, 1.0), 0.0)) + (Vectors.dense(1.0, 1.0), 0.0) ).toDF("features", "label") val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() @@ -114,9 +116,9 @@ class MultilayerPerceptronClassifierSuite val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) // the input seed is somewhat magic, to make this test pass - val rdd = sc.parallelize(generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 1), 2) - val dataFrame = spark.createDataFrame(rdd).toDF("label", "features") + val data = generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 1).toDS() + val dataFrame = data.toDF("label", "features") val numClasses = 3 val numIterations = 100 val layers = Array[Int](4, 5, 4, numClasses) @@ -137,9 +139,9 @@ class MultilayerPerceptronClassifierSuite .setNumClasses(numClasses) lr.optimizer.setRegParam(0.0) .setNumIterations(numIterations) - val lrModel = lr.run(rdd.map(OldLabeledPoint.fromML)) + val lrModel = lr.run(data.rdd.map(OldLabeledPoint.fromML)) val lrPredictionAndLabels = - lrModel.predict(rdd.map(p => OldVectors.fromML(p.features))).zip(rdd.map(_.label)) + lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. 
val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 04c010bd13e1e..99099324284dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { @@ -47,7 +49,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + dataset = generateNaiveBayesInput(pi, theta, 100, 42).toDF() } def validatePrediction(predictionAndLabels: DataFrame): Unit = { @@ -131,16 +133,16 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) - val testDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 42, "multinomial")) + val testDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 17, "multinomial")) + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) @@ -161,16 +163,16 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) - val testDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 45, "bernoulli")) + val testDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 45, "bernoulli").toDF() val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = spark.createDataFrame(generateNaiveBayesInput( - piArray, thetaArray, nPoints, 20, "bernoulli")) + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 99dd5854ff649..3f9bcec427399 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with 
DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ @@ -55,7 +57,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 42), 2) - dataset = spark.createDataFrame(rdd) + dataset = rdd.toDF() } test("params") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2e99ee157ae95..44e1585ee514b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -39,6 +39,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ @@ -158,7 +159,7 @@ class RandomForestClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) rf.fit(df) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index ddfa87555427b..3f39deddf20b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -62,6 +62,8 @@ object LDASuite { class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + val k: Int = 5 val vocabSize: Int = 30 @transient var dataset: Dataset[_] = _ @@ -140,8 +142,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead new LDA().setTopicConcentration(-1.1) } - val dummyDF = spark.createDataFrame(Seq( - (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") + val dummyDF = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features") + // validate parameters lda.transformSchema(dummyDF.schema) lda.setDocConcentration(1.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 9ee3df5eb5e33..ede284712b1c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class BinaryClassificationEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } @@ -42,25 +44,25 @@ class BinaryClassificationEvaluatorSuite val evaluator = new BinaryClassificationEvaluator() .setMetricName("areaUnderPR") - val vectorDF = spark.createDataFrame(Seq( + val vectorDF = Seq( (0d, Vectors.dense(12, 2.5)), (1d, Vectors.dense(1, 3)), (0d, 
Vectors.dense(10, 2)) - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") assert(evaluator.evaluate(vectorDF) === 1.0) - val doubleDF = spark.createDataFrame(Seq( + val doubleDF = Seq( (0d, 0d), (1d, 1d), (0d, 0d) - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") assert(evaluator.evaluate(doubleDF) === 1.0) - val stringDF = spark.createDataFrame(Seq( + val stringDF = Seq( (0d, "0d"), (1d, "1d"), (0d, "0d") - )).toDF("label", "rawPrediction") + ).toDF("label", "rawPrediction") val thrown = intercept[IllegalArgumentException] { evaluator.evaluate(stringDF) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 42ff8adf6bd65..c1a156959618e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.mllib.util.TestingUtils._ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new RegressionEvaluator) } @@ -42,9 +44,9 @@ class RegressionEvaluatorSuite * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) * .saveAsTextFile("path") */ - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1) + .map(_.asML).toDF() /** * Using the following R code to load the data, train the model and evaluate metrics. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9cb84a6ee9b87..4455d35210878 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.{DataFrame, Row} class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Double] = _ override def beforeAll(): Unit = { @@ -39,8 +41,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame( - data.zip(defaultBinarized)).toDF("feature", "expected") + val dataFrame: DataFrame = data.zip(defaultBinarized).toSeq.toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -55,8 +56,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with setter") { val threshold: Double = 0.2 val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame( - data.zip(thresholdBinarized)).toDF("feature", "expected") + val dataFrame: DataFrame = data.zip(thresholdBinarized).toSeq.toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -71,9 +71,9 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame(Seq( + val dataFrame: DataFrame = Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) - )).toDF("feature", "expected") + ).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -88,9 +88,9 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with setter") { val threshold: Double = 0.2 val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = spark.createDataFrame(Seq( + val dataFrame: DataFrame = Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) - )).toDF("feature", "expected") + ).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index c7f5093e74740..87cdceb267387 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.{DataFrame, Row} class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new Bucketizer) } @@ -38,8 +40,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(-0.5, 0.0, 0.5) val validData = Array(-0.5, -0.3, 0.0, 0.2) val expectedBuckets = Array(0.0, 0.0, 1.0, 
1.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -55,13 +56,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa // Check for exceptions when using a set of invalid feature values. val invalidData1: Array[Double] = Array(-0.9) ++ validData val invalidData2 = Array(0.51) ++ validData - val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") + val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx") withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF1).collect() } } - val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") + val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx") withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF2).collect() @@ -73,8 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -92,8 +92,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) - val dataFrame: DataFrame = - spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6b56e4200250c..dfebfc87ea1d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -29,8 +29,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val spark = this.spark - import spark.implicits._ + import testImplicits._ val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 863b66bf497fe..69d3033bb2189 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.Row class 
CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) @@ -35,7 +37,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext private def split(s: String): Seq[String] = s.split("\\s+") test("CountVectorizerModel common cases") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("a b b c d a"), @@ -44,7 +46,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (3, split(""), Vectors.sparse(4, Seq())), // empty string (4, split("a notInDict d"), Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary - )).toDF("id", "words", "expected") + ).toDF("id", "words", "expected") val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") @@ -55,13 +57,13 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer common cases") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), - (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") @@ -76,11 +78,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer vocabSize and minDF") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), - (3, split("a"), Vectors.sparse(2, Seq((0, 1.0))))) + (3, split("a"), Vectors.sparse(2, Seq((0, 1.0)))) ).toDF("id", "words", "expected") val cvModel = new CountVectorizer() .setInputCol("words") @@ -118,9 +120,9 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a b b c c")), - (1, split("aa bb cc"))) + (1, split("aa bb cc")) ).toDF("id", "words") val cvModel = new CountVectorizer() .setInputCol("words") @@ -132,11 +134,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF count") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq())), - (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + (3, split("e e e e e"), Vectors.sparse(4, Seq())) ).toDF("id", "words", "expected") // minTF: count @@ -151,11 +153,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF freq") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c 
c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), - (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + (3, split("e e e e e"), Vectors.sparse(4, Seq())) ).toDF("id", "words", "expected") // minTF: set frequency @@ -170,12 +172,12 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel and CountVectorizer with binary") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, split("a a a a b b b b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))) - )).toDF("id", "words", "expected") + ).toDF("id", "words", "expected") // CountVectorizer test val cv = new CountVectorizer() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index c02e9610418bf..8dd3dd75e1be5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -32,6 +32,8 @@ case class DCTTestData(vec: Vector, wantedVec: Vector) class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) val inverse = false @@ -57,15 +59,13 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { - (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + new DoubleDCT_1D(data.size).inverse(expectedResultBuffer, true) } else { - (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + new DoubleDCT_1D(data.size).forward(expectedResultBuffer, true) } val expectedResult = Vectors.dense(expectedResultBuffer) - val dataset = spark.createDataFrame(Seq( - DCTTestData(data, expectedResult) - )) + val dataset = Seq(DCTTestData(data, expectedResult)).toDF() val transformer = new DCT() .setInputCol("vec") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 99b800776bb64..1d14866cc933b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -29,14 +29,14 @@ import org.apache.spark.util.Utils class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new HashingTF) } test("hashingTF") { - val df = spark.createDataFrame(Seq( - (0, "a a b b c d".split(" ").toSeq) - )).toDF("id", "words") + val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") val n = 100 val hashingTF = new HashingTF() .setInputCol("words") @@ -54,9 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("applying binary term freqs") { - val df = spark.createDataFrame(Seq( - (0, "a a b c c c".split(" ").toSeq) - )).toDF("id", "words") + val df = Seq((0, "a a b c c c".split(" ").toSeq)).toDF("id", "words") val n = 100 val hashingTF = new HashingTF() .setInputCol("words") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 09dc8b9b932fd..5325d95526a50 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.Row class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { case data: DenseVector => @@ -61,7 +63,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") @@ -87,7 +89,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 3429172a8c903..54f059e5f143e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -59,11 +62,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("numeric interaction") { - val data = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0))) - ).toDF("a", "b") + val data = Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0)) + ).toDF("a", "b") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -74,11 +76,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) - ).toDF("a", "b", "features") + val expected = Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( @@ -90,11 +91,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("nominal interaction") { - val data = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0))) - ).toDF("a", "b") + val data = Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0)) + ).toDF("a", "b") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -106,11 +106,10 @@ class InteractionSuite extends 
SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) - ).toDF("a", "b", "features") + val expected = Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( @@ -126,10 +125,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("default attr names") { - val data = spark.createDataFrame( - Seq( + val data = Seq( (2, Vectors.dense(0.0, 4.0), 1.0), - (1, Vectors.dense(1.0, 5.0), 10.0)) + (1, Vectors.dense(1.0, 5.0), 10.0) ).toDF("a", "b", "c") val groupAttr = new AttributeGroup( "b", @@ -142,11 +140,10 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") val res = trans.transform(df) - val expected = spark.createDataFrame( - Seq( - (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), - (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) - ).toDF("a", "b", "c", "features") + val expected = Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0)) + ).toDF("a", "b", "c", "features") assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index d6400ee02f951..a12174493b867 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -23,6 +23,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("MaxAbsScaler fit basic case") { val data = Array( Vectors.dense(1, 0, 100), @@ -36,7 +39,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(-1, -1)), Vectors.sparse(3, Array(0), Array(-0.75))) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new MaxAbsScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 9f376b70035c5..b79eeb2d75ef0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.Row class MinMaxScalerSuite extends 
SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("MinMaxScaler fit basic case") { val data = Array( Vectors.dense(1, 0, Long.MinValue), @@ -38,7 +40,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(5, 5)), Vectors.sparse(3, Array(0), Array(-2.5))) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") @@ -57,14 +59,13 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { - val dummyDF = spark.createDataFrame(Seq( - (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") + val dummyDF = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features") intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") + val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("features") scaler.transformSchema(dummyDF.schema) } intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature") + val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("features") scaler.transformSchema(dummyDF.schema) } } @@ -104,7 +105,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.dense(-1.0, Double.NaN, -5.0, -5.0), Vectors.dense(5.0, 0.0, 5.0, Double.NaN)) - val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = data.zip(expected).toSeq.toDF("features", "expected") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5288d9259d3c..d4975c0b4e20e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -28,17 +28,18 @@ import org.apache.spark.sql.{Dataset, Row} case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.NGramSuite._ + import testImplicits._ test("default behavior yields bigram features") { val nGram = new NGram() .setInputCol("inputTokens") .setOutputCol("nGrams") - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("Test", "for", "ngram", "."), - Array("Test for", "for ngram", "ngram .") - ))) + val dataset = Seq(NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + )).toDF() testNGram(nGram, dataset) } @@ -47,11 +48,10 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("a", "b", "c", "d", "e"), - Array("a b c d", "b c d e") - ))) + val dataset = Seq(NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + )).toDF() testNGram(nGram, dataset) } @@ -60,11 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val 
dataset = spark.createDataFrame(Seq( - NGramTestData( - Array(), - Array() - ))) + val dataset = Seq(NGramTestData(Array(), Array())).toDF() testNGram(nGram, dataset) } @@ -73,11 +69,10 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(6) - val dataset = spark.createDataFrame(Seq( - NGramTestData( - Array("a", "b", "c", "d", "e"), - Array() - ))) + val dataset = Seq(NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + )).toDF() testNGram(nGram, dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index b692831714466..c75027fb4553d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.{DataFrame, Row} class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @transient var normalizer: Normalizer = _ @@ -61,7 +63,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index d41eeec1329c5..c44c6813a94be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.types._ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + def stringIndexed(): DataFrame = { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -83,7 +85,7 @@ class OneHotEncoderSuite test("input column with ML attribute") { val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", attr.toMetadata())) val encoder = new OneHotEncoder() .setInputCol("size") @@ -96,7 +98,7 @@ class OneHotEncoderSuite } test("input column without ML attribute") { - val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index ddb51fb1706a7..a60e87590f060 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.Row class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] @@ -50,7 +52,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pc = mat.computePrincipalComponents(3) val expected = mat.multiply(pc).rows.map(_.asML) - val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + val df = dataRDD.zip(expected).toDF("features", "expected") val pca = new PCA() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 9ecd321b128f6..e4b0ddf98bfad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.Row class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new PolynomialExpansion) } @@ -59,7 +61,7 @@ class PolynomialExpansionSuite Vectors.sparse(19, Array.empty, Array.empty)) test("Polynomial expansion with default parameter") { - val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -76,7 +78,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with setter") { - val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + val df = data.zip(threeDegreeExpansion).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -94,7 +96,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with degree 1 is identity on vectors") { - val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected") + val df = data.zip(data).toSeq.toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -124,8 +126,7 @@ class PolynomialExpansionSuite (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) ) - val df = spark.createDataFrame(data) - .toDF("features", "expectedPoly10size", "expectedPoly11size") + val df = data.toSeq.toDF("features", "expectedPoly10size", "expectedPoly11size") val t = new PolynomialExpansion() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 0794a049d9cd8..97c268f3d5c97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -26,22 +26,23 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("params") { ParamsSuite.checkParams(new RFormula()) } test("transform numeric data") { val formula = new RFormula().setFormula("id ~ v1 
+ v2") - val original = spark.createDataFrame( - Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( - (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), - (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) - ).toDF("id", "v1", "v2", "features", "label") + val expected = Seq( + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) + ).toDF("id", "v1", "v2", "features", "label") // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) @@ -50,7 +51,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") intercept[IllegalArgumentException] { formula.fit(original) } @@ -58,7 +59,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -67,7 +68,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") + val original = Seq((0, true), (2, false)).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -79,7 +80,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") - val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "_not_y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -88,37 +89,32 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("allow empty label") { - val original = spark.createDataFrame( - Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)) - ).toDF("id", "a", "b") + val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( - (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), - (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), - (7, 8.0, 9.0, Vectors.dense(8.0, 9.0))) - ).toDF("id", "a", "b", "features") + val expected = Seq( + (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), + (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), + (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) + 
).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq( + val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -126,17 +122,16 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( + val original = Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) - ).toDF("id", "a", "b") + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( + val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), - ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)) + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -144,9 +139,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = spark.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val attrs = AttributeGroup.fromStructField(result.schema("features")) @@ -161,9 +155,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation") { val formula = new RFormula().setFormula("id ~ vec") - val original = spark.createDataFrame( - Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) - ).toDF("id", "vec") + val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + .toDF("id", "vec") val model = formula.fit(original) val result = model.transform(original) val attrs = AttributeGroup.fromStructField(result.schema("features")) @@ -177,9 +170,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation with unnamed input attrs") { val formula = new RFormula().setFormula("id ~ vec2") - val base = spark.createDataFrame( - Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) - ).toDF("id", "vec") + val base = Seq((1, Vectors.dense(0.0, 1.0)), (2, 
Vectors.dense(1.0, 2.0))) + .toDF("id", "vec") val metadata = new AttributeGroup( "vec2", Array[Attribute]( @@ -199,16 +191,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") - val original = spark.createDataFrame( - Seq((1, 2, 4, 2), (2, 3, 4, 1)) - ).toDF("a", "b", "c", "d") + val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, 2, 4, 2, Vectors.dense(16.0), 1.0), - (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) - ).toDF("a", "b", "c", "d", "features", "label") + val expected = Seq( + (1, 2, 4, 2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0) + ).toDF("a", "b", "c", "d", "features", "label") assert(result.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -219,20 +208,19 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor numeric interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = spark.createDataFrame( + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") + .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), - (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), - (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) - ).toDF("id", "a", "b", "features", "label") + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) + ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -246,17 +234,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor factor interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = spark.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") + val original = + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = spark.createDataFrame( - Seq( - (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), - (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), - (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)) - ).toDF("id", "a", "b", "features", "label") + val expected = Seq( + (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) + ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) val attrs = 
AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( @@ -295,9 +281,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - val dataset = spark.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") + val dataset = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val rFormula = new RFormula().setFormula("id ~ a:b") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 1401ea9c4b431..23464073e6edb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -26,19 +26,19 @@ import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new SQLTransformer()) } test("transform numeric data") { - val original = spark.createDataFrame( - Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") val result = sqlTrans.transform(original) val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = spark.createDataFrame( - Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 827ecb0fadbee..a928f93633011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, Row} class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @transient var resWithMean: Array[Vector] = _ @@ -73,7 +75,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with default parameter") { - val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected") val standardScaler0 = new StandardScaler() .setInputCol("features") @@ -84,9 +86,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with setter") { - val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") - val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") - val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected") + val df1 = data.zip(resWithBoth).toSeq.toDF("features", "expected") + val df2 = data.zip(resWithMean).toSeq.toDF("features", "expected") + val df3 = data.zip(data).toSeq.toDF("features", "expected") val standardScaler1 = new StandardScaler() 
.setInputCol("features") @@ -120,7 +122,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0)), Vectors.dense(1.7, -0.6, 3.3) ) - val df = spark.createDataFrame(someSparseData.zip(resWithMean)).toDF("features", "expected") + val df = someSparseData.zip(resWithMean).toSeq.toDF("features", "expected") val standardScaler = new StandardScaler() .setInputCol("features") .setOutputCol("standardized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 125ad02ebcc02..957cf58a68f85 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -37,19 +37,20 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import StopWordsRemoverSuite._ + import testImplicits._ test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("test", "test"), Seq("test", "test")), (Seq("a", "b", "c", "d"), Seq("b", "c")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), (Seq(null), Seq(null)), (Seq(), Seq()) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -60,14 +61,14 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("test", "test"), Seq()), (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), (Seq(null), Seq(null)), (Seq(), Seq()) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -77,10 +78,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setCaseSensitive(true) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("A"), Seq("A")), (Seq("The", "the"), Seq("The")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -98,10 +99,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("acaba", "ama", "biri"), Seq()), (Seq("hep", "her", "scala"), Seq("scala")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -112,10 +113,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("python", "scala", "a"), Seq("python", "scala", "a")), (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -126,10 +127,10 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords.toArray) - val dataSet = spark.createDataFrame(Seq( + val dataSet = Seq( (Seq("python", "scala", "a"), Seq()), (Seq("Python", "Scala", "swift"), Seq("swift")) - )).toDF("raw", "expected") + ).toDF("raw", "expected") testStopWordsRemover(remover, dataSet) } @@ -148,9 +149,7 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() 
.setInputCol("raw") .setOutputCol(outputCol) - val dataSet = spark.createDataFrame(Seq( - (Seq("The", "the", "swift"), Seq("swift")) - )).toDF("raw", outputCol) + val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) val thrown = intercept[IllegalArgumentException] { testStopWordsRemover(remover, dataSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b478fea5e74ec..a6bbb944a1bd7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructTy class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) @@ -38,8 +40,8 @@ class StringIndexerSuite } test("StringIndexer") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -61,10 +63,10 @@ class StringIndexerSuite } test("StringIndexerUnseen") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) - val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") - val df2 = spark.createDataFrame(data2).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (4, "b")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -92,8 +94,8 @@ class StringIndexerSuite } test("StringIndexer with a numeric input column") { - val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -119,7 +121,7 @@ class StringIndexerSuite } test("StringIndexerModel can't overwrite output column") { - val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val df = Seq((1, 2), (3, 4)).toDF("input", "output") intercept[IllegalArgumentException] { new StringIndexer() .setInputCol("input") @@ -161,9 +163,7 @@ class StringIndexerSuite test("IndexToString.transform") { val labels = Array("a", "b", "c") - val df0 = spark.createDataFrame(Seq( - (0, "a"), (1, "b"), (2, "c"), (0, "a") - )).toDF("index", "expected") + val df0 = Seq((0, "a"), (1, "b"), (2, "c"), (0, "a")).toDF("index", "expected") val idxToStr0 = new IndexToString() .setInputCol("index") @@ -187,8 +187,8 @@ class StringIndexerSuite } test("StringIndexer, IndexToString are inverses") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", 
"label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -220,8 +220,8 @@ class StringIndexerSuite } test("StringIndexer metadata") { - val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = spark.createDataFrame(data).toDF("id", "label") + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index f30bdc3ddc0d7..c895659a2d8be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -46,6 +46,7 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.RegexTokenizerSuite._ + import testImplicits._ test("params") { ParamsSuite.checkParams(new RegexTokenizer) @@ -57,26 +58,26 @@ class RegexTokenizerSuite .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = spark.createDataFrame(Seq( + val dataset0 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) - )) + ).toDF() testRegexTokenizer(tokenizer0, dataset0) - val dataset1 = spark.createDataFrame(Seq( + val dataset1 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) - )) + ).toDF() tokenizer0.setMinTokenLength(3) testRegexTokenizer(tokenizer0, dataset1) val tokenizer2 = new RegexTokenizer() .setInputCol("rawText") .setOutputCol("tokens") - val dataset2 = spark.createDataFrame(Seq( + val dataset2 = Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), TokenizerTestData("Te,st. 
punct", Array("te,st.", "punct")) - )) + ).toDF() testRegexTokenizer(tokenizer2, dataset2) } @@ -85,10 +86,10 @@ class RegexTokenizerSuite .setInputCol("rawText") .setOutputCol("tokens") .setToLowercase(false) - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), TokenizerTestData("java scala", Array("java", "scala")) - )) + ).toDF() testRegexTokenizer(tokenizer, dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 561493fbafd6c..46cced3a9a6e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.functions.col class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("params") { ParamsSuite.checkParams(new VectorAssembler) } @@ -57,9 +59,9 @@ class VectorAssemblerSuite } test("VectorAssembler") { - val df = spark.createDataFrame(Seq( + val df = Seq( (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) - )).toDF("id", "x", "y", "name", "z", "n") + ).toDF("id", "x", "y", "name", "z", "n") val assembler = new VectorAssembler() .setInputCols(Array("x", "y", "z", "n")) .setOutputCol("features") @@ -70,7 +72,7 @@ class VectorAssemblerSuite } test("transform should throw an exception in case of unsupported type") { - val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val df = Seq(("a", "b", "c")).toDF("a", "b", "c") val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") @@ -87,7 +89,7 @@ class VectorAssemblerSuite NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), NumericAttribute.defaultAttr.withName("salary"))) val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) - val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + val df = Seq(row).toDF("browser", "hour", "count", "user", "ad") .select( col("browser").as("browser", browser.toMetadata()), col("hour").as("hour", hour.toMetadata()), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 707142332349c..4da1b133e8cd5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.DataFrame class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { + import testImplicits._ import VectorIndexerSuite.FeatureData // identical, of length 3 @@ -85,11 +86,13 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) - sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) - densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) - sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) - badPoints = 
spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + densePoints1 = densePoints1Seq.map(FeatureData).toDF() + sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() + // TODO: If we directly use `toDF` without parallelize, the test in + // "Throws error when given RDDs with different size vectors" fails for an unknown reason. + densePoints2 = sc.parallelize(densePoints2Seq, 2).map(FeatureData).toDF() + sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() + badPoints = badPointsSeq.map(FeatureData).toDF() } private def getIndexer: VectorIndexer = @@ -102,7 +105,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Cannot fit an empty DataFrame") { - val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val rdd = Array.empty[Vector].map(FeatureData).toSeq.toDF() val vectorIndexer = getIndexer intercept[IllegalArgumentException] { vectorIndexer.fit(rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 1c70b702de063..0fdfdf37cf38d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -31,23 +31,22 @@ import org.apache.spark.sql.{DataFrame, Row} class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() - datasetUnivariate = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) - datasetMultivariate = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) - datasetUnivariateScaled = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => - AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) - }) + datasetUnivariate = generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0).toDF() + datasetMultivariate = generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0).toDF() + datasetUnivariateScaled = sc.parallelize( + generateAFTInput(1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }.toDF() } /** @@ -396,9 +395,8 @@ class AFTSurvivalRegressionSuite // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s // being merged incorrectly when it has an empty partition, running the codes below // should not throw an exception. 
- val dataset = spark.createDataFrame( - sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3)) + val dataset = sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3).toDF() val trainer = new AFTSurvivalRegression() trainer.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 7b5df8f31bb38..dcf3f9a1ea9b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -37,6 +37,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs + import testImplicits._ // Combinations for estimators, learning rates and subsamplingRate private val testCombinations = @@ -76,14 +77,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext } test("GBTRegressor behaves reasonably on toy data") { - val df = spark.createDataFrame(Seq( + val df = Seq( LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) - )) + ).toDF() val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(2) @@ -103,7 +104,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString sc.setCheckpointDir(path) - val df = spark.createDataFrame(data) + val df = data.toDF() val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d8032c4e1705b..937aa7d3c2045 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.functions._ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private val seed: Int = 42 @transient var datasetGaussianIdentity: DataFrame = _ @transient var datasetGaussianLog: DataFrame = _ @@ -52,23 +54,20 @@ class GeneralizedLinearRegressionSuite import GeneralizedLinearRegressionSuite._ - datasetGaussianIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "identity"), 2)) + datasetGaussianIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity").toDF() - datasetGaussianLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "log"), 2)) + 
datasetGaussianLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log").toDF() - datasetGaussianInverse = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gaussian", link = "inverse"), 2)) + datasetGaussianInverse = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse").toDF() datasetBinomial = { val nPoints = 10000 @@ -80,44 +79,38 @@ class GeneralizedLinearRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - spark.createDataFrame(sc.parallelize(testData, 2)) + testData.toDF() } - datasetPoissonLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "log"), 2)) - - datasetPoissonIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "identity"), 2)) - - datasetPoissonSqrt = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "poisson", link = "sqrt"), 2)) - - datasetGammaInverse = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "inverse"), 2)) - - datasetGammaIdentity = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "identity"), 2)) - - datasetGammaLog = spark.createDataFrame( - sc.parallelize(generateGeneralizedLinearRegressionInput( - intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, - family = "gamma", link = "log"), 2)) + datasetPoissonLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log").toDF() + + datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity").toDF() + + datasetPoissonSqrt = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 
0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt").toDF() + + datasetGammaInverse = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse").toDF() + + datasetGammaIdentity = generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity").toDF() + + datasetGammaLog = generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log").toDF() } /** @@ -540,12 +533,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -668,12 +661,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) - ), 2)) + ).toDF() /* R code: @@ -782,12 +775,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -899,12 +892,12 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + ).toDF() /* R code: @@ -1054,12 +1047,12 @@ class GeneralizedLinearRegressionSuite [1] 12.92681 [1] 13.32836 */ - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( LabeledPoint(1, Vectors.dense(5, 0)), LabeledPoint(0, Vectors.dense(2, 1)), LabeledPoint(1, Vectors.dense(1, 2)), LabeledPoint(0, Vectors.dense(3, 3)) - )) + ).toDF() val expected = Seq(12.88188, 12.92681, 13.32836) var idx = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 14d8a4e4e3345..c2c79476e8b2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -27,15 +27,15 @@ import org.apache.spark.sql.{DataFrame, Row} class 
IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - spark.createDataFrame( - labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } - ).toDF("label", "features", "weight") + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + .toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - spark.createDataFrame(features.map(Tuple1.apply)) - .toDF("features") + features.map(Tuple1.apply).toDF("features") } test("isotonic regression predictions") { @@ -145,10 +145,10 @@ class IsotonicRegressionSuite } test("vector features column with feature index") { - val dataset = spark.createDataFrame(Seq( + val dataset = Seq( (4.0, Vectors.dense(0.0, 1.0)), (3.0, Vectors.dense(0.0, 2.0)), - (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + (5.0, Vectors.sparse(2, Array(1), Array(3.0))) ).toDF("label", "features") val ir = new IsotonicRegression() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 265f2f45c45fe..5ae371b489aa5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @@ -42,29 +44,27 @@ class LinearRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetWithDenseFeature = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) + datasetWithDenseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() /* datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept */ - datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( + datasetWithDenseFeatureWithoutIntercept = sc.parallelize( + LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() val r = new Random(seed) // When feature size is larger than 4096, normal optimizer is choosed // as the solver of linear regression in the case of "auto" mode. 
val featureSize = 4100 - datasetWithSparseFeature = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( + datasetWithSparseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, - seed, eps = 0.1, sparsity = 0.7), 2).map(_.asML)) + seed, eps = 0.1, sparsity = 0.7), 2).map(_.asML).toDF() /* R code: @@ -74,13 +74,12 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - datasetWithWeight = spark.createDataFrame( - sc.parallelize(Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + datasetWithWeight = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() /* R code: @@ -90,20 +89,18 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df.const.label <- as.data.frame(cbind(A, b.const)) */ - datasetWithWeightConstantLabel = spark.createDataFrame( - sc.parallelize(Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) - datasetWithWeightZeroLabel = spark.createDataFrame( - sc.parallelize(Seq( - Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) - ), 2)) + datasetWithWeightConstantLabel = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() + datasetWithWeightZeroLabel = sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2).toDF() } /** @@ -839,8 +836,7 @@ class LinearRegressionSuite } val data2 = weightedSignedData ++ weightedNoiseData - (spark.createDataFrame(sc.parallelize(data1, 4)), - spark.createDataFrame(sc.parallelize(data2, 4))) + (sc.parallelize(data1, 4).toDF(), sc.parallelize(data2, 4).toDF()) } val trainer1a = (new LinearRegression).setFitIntercept(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 5c50a88c8314a..4109a299091dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -32,13 +32,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext */ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + import testImplicits._ + test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. 
val numIterations = 20 val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML) val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML) - val trainDF = spark.createDataFrame(trainRdd) - val validateDF = spark.createDataFrame(validateRdd) + val trainDF = trainRdd.toDF() + val validateDF = validateRdd.toDF() val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 750dc5bf01e6a..7116265474f22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.types.StructType class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = spark.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF() } test("cross validation with logistic regression") { @@ -67,9 +68,10 @@ class CrossValidatorSuite } test("cross validation with linear regression") { - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = sc.parallelize( + LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2) + .map(_.asML).toDF() val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9971371e47288..87100ae2e342f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -33,9 +33,11 @@ import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + test("train validation with logistic regression") { - val dataset = spark.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF() val lr = new LogisticRegression val lrParamMaps = new ParamGridBuilder() @@ -58,9 +60,10 @@ class TrainValidationSplitSuite } test("train validation with linear regression") { - val dataset = spark.createDataFrame( - sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) + val dataset = sc.parallelize( + LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2) + .map(_.asML).toDF() val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 6aa93c9076007..e4e9be39ff6f9 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { + import testImplicits._ + test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.") @@ -255,9 +257,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Vectors.dense(4.0) val p = (5.0, z) val w = Vectors.dense(6.0).asML - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertVectorColumnsToML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -282,9 +282,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Vectors.dense(4.0).asML val p = (5.0, z) val w = Vectors.dense(6.0) - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertVectorColumnsFromML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -309,9 +307,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Matrices.ones(1, 1) val p = (5.0, z) val w = Matrices.dense(1, 1, Array(4.5)).asML - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertMatrixColumnsToML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") @@ -336,9 +332,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val z = Matrices.ones(1, 1).asML val p = (5.0, z) val w = Matrices.dense(1, 1, Array(4.5)) - val df = spark.createDataFrame(Seq( - (0, x, y, p, w) - )).toDF("id", "x", "y", "p", "w") + val df = Seq((0, x, y, p, w)).toDF("id", "x", "y", "p", "w") .withColumn("x", col("x"), metadata) val newDF1 = convertMatrixColumnsFromML(df) assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index db56aff63102c..6bb7ed9c9513c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -23,7 +23,7 @@ import org.scalatest.Suite import org.apache.spark.SparkContext import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} import org.apache.spark.util.Utils trait MLlibTestSparkContext extends TempDirectory { self: Suite => @@ -55,4 +55,15 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => super.afterAll() } } + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. 
+ */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } } From bde85f8b70138a51052b613664facbc981378c38 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 26 Sep 2016 10:44:35 -0700 Subject: [PATCH 124/306] [SPARK-17649][CORE] Log how many Spark events got dropped in LiveListenerBus ## What changes were proposed in this pull request? Log how many Spark events got dropped in LiveListenerBus so that the user can get insights on how to set a correct event queue size. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15220 from zsxwing/SPARK-17649. --- .../spark/scheduler/LiveListenerBus.scala | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index bfa3c408f2284..5533f7b1f2363 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.util.DynamicVariable @@ -57,6 +57,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa // Indicate if `stop()` is called private val stopped = new AtomicBoolean(false) + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + // Indicate if we are processing some event // Guarded by `self` private var processingEvent = false @@ -123,6 +129,24 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa eventLock.release() } else { onDropEvent(event) + droppedEventsCounter.incrementAndGet() + } + + val droppedEvents = droppedEventsCounter.get + if (droppedEvents > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + } + } } } From 8135e0e5ebdb9c7f5ac41c675dc8979a5127a31a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 26 Sep 2016 13:07:11 -0700 Subject: [PATCH 125/306] [SPARK-17153][SQL] Should read partition data when reading new files in filestream without globbing ## What changes were proposed in this pull request? When reading file stream with non-globbing path, the results return data with all `null`s for the partitioned columns. 
E.g., case class A(id: Int, value: Int) val data = spark.createDataset(Seq( A(1, 1), A(2, 2), A(2, 3)) ) val url = "/tmp/test" data.write.partitionBy("id").parquet(url) spark.read.parquet(url).show +-----+---+ |value| id| +-----+---+ | 2| 2| | 3| 2| | 1| 1| +-----+---+ val s = spark.readStream.schema(spark.read.load(url).schema).parquet(url) s.writeStream.queryName("test").format("memory").start() sql("SELECT * FROM test").show +-----+----+ |value| id| +-----+----+ | 2|null| | 3|null| | 1|null| +-----+----+ ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #14803 from viirya/filestreamsource-option. --- .../structured-streaming-programming-guide.md | 6 ++ .../execution/datasources/DataSource.scala | 7 +- .../streaming/FileStreamSource.scala | 9 +- .../sql/streaming/FileStreamSourceSuite.scala | 83 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 8 ++ 5 files changed, 110 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index c7ed3b04bced1..2e6df94823d38 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -512,6 +512,12 @@ csvDF = spark \ These examples generate streaming DataFrames that are untyped, meaning that the schema of the DataFrame is not checked at compile time, only checked at runtime when the query is submitted. Some operations like `map`, `flatMap`, etc. need the type to be known at compile time. To do those, you can convert these untyped streaming DataFrames to typed streaming Datasets using the same methods as static DataFrame. See the [SQL Programming Guide](sql-programming-guide.html) for more details. Additionally, more details on the supported streaming sources are discussed later in the document. +### Schema inference and partition of streaming DataFrames/Datasets + +By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. + +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). + ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. 
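As a hedged, editorial illustration of the partition-discovery behavior described above (this sketch is not part of the patch), the following Scala snippet reads a non-globbed, partitioned directory as a file stream with a user-provided schema that lists the partition column; the base path, directory layout, and application name are assumptions for illustration only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType}

object PartitionedFileStreamSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("partitioned-file-stream").getOrCreate()

    // The user-provided schema lists the partition column explicitly; Spark fills it in
    // from the /partition=<value>/ directory names of each input file.
    val schema = new StructType()
      .add("value", StringType)
      .add("partition", StringType)

    // Hypothetical base directory containing subdirectories such as
    // /tmp/stream-input/partition=foo/*.json and /tmp/stream-input/partition=bar/*.json
    val stream = spark.readStream
      .schema(schema)
      .json("/tmp/stream-input")

    // Write to the console sink just to observe that the partition column is populated.
    val query = stream.writeStream
      .format("console")
      .start()

    query.awaitTermination()
  }
}
```

With the fix above, rows read from `/tmp/stream-input/partition=foo/` carry `"foo"` in the `partition` column rather than `null`.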
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 32067011c3dff..e75e7d2770b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -197,10 +197,15 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) - format.inferSchema( + val partitionCols = fileCatalog.partitionSpec().partitionColumns.fields + val inferred = format.inferSchema( sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) + + inferred.map { inferredSchema => + StructType(inferredSchema ++ partitionCols) + } }.getOrElse { throw new AnalysisException("Unable to infer schema. It must be specified manually.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index be023273db2f2..614a6261e7c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -47,6 +47,13 @@ class FileStreamSource( fs.makeQualified(new Path(path)) // can contains glob patterns } + private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { + if (!SparkHadoopUtil.get.isGlobPath(new Path(path)) && options.contains("path")) { + Map("basePath" -> path) + } else { + Map() + }} + private val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) @@ -136,7 +143,7 @@ class FileStreamSource( paths = files.map(_.path), userSpecifiedSchema = Some(schema), className = fileFormatClassName, - options = sourceOptions.optionMapWithoutPath) + options = optionsWithPartitionBasePath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( checkFilesExist = false))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 55c95ae285c1b..3157afe5a56c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -102,6 +102,12 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext with Private } } + case class DeleteFile(file: File) extends ExternalAction { + def runAction(): Unit = { + Utils.deleteRecursively(file) + } + } + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ def createFileStream( format: String, @@ -608,6 +614,81 @@ class FileStreamSourceSuite extends FileStreamSourceTest { // =============== other tests ================ + test("read new files in partitioned table without globbing, should read partition data") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + val schema = new StructType().add("value", StringType).add("partition", StringType) + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}", Some(schema)) + val 
filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Create new partition=foo sub dir and write to it + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")) + ) + } + } + + test("when schema inference is turned on, should read partition data") { + def createFile(content: String, src: File, tmp: File): Unit = { + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + src.mkdirs() + require(stringToFile(tempFile, content).renameTo(finalFile)) + } + + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + // Create file in partition, so we can infer the schema. + createFile("{'value': 'drop0'}", partitionFooSubDir, tmp) + + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}") + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")), + + // Delete the two partition dirs + DeleteFile(partitionFooSubDir), + DeleteFile(partitionBarSubDir), + + AddTextFileData("{'value': 'keep6'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar"), + ("keep6", "bar")) + ) + } + } + } + test("fault tolerance") { withTempDirs { case (src, tmp) => val fileStream = createFileStream("text", src.getCanonicalPath) @@ -792,7 +873,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } assert(src.listFiles().size === numFiles) - val files = spark.readStream.text(root.getCanonicalPath).as[String] + val files = spark.readStream.text(root.getCanonicalPath).as[(String, Int)] // Note this query will use constant folding to eliminate the file scan. 
// This is to avoid actually running a Spark job with 10000 tasks diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 6c5b170d9c7c3..aa6515bc7a909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -95,6 +95,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def addData(query: Option[StreamExecution]): (Source, Offset) } + /** A trait that can be extended when testing a source. */ + trait ExternalAction extends StreamAction { + def runAction(): Unit + } + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" @@ -429,6 +434,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { failTest("Error adding data", e) } + case e: ExternalAction => + e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => verify(currentStream != null, "stream not running") // Get the map of source index to the current source objects From 7c7586aef9243081d02ea5065435234b5950ab66 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 26 Sep 2016 13:21:08 -0700 Subject: [PATCH 126/306] [SPARK-17652] Fix confusing exception message while reserving capacity ## What changes were proposed in this pull request? This minor patch fixes a confusing exception message while reserving additional capacity in the vectorized parquet reader. ## How was this patch tested? Exisiting Unit Tests Author: Sameer Agarwal Closes #15225 from sameeragarwal/error-msg. --- .../sql/execution/vectorized/ColumnVector.java | 14 +++++++------- .../execution/vectorized/ColumnarBatchSuite.scala | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a7cb3b11f687a..ff07940422a0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -285,19 +285,19 @@ public void reserve(int requiredCapacity) { try { reserveInternal(newCapacity); } catch (OutOfMemoryError outOfMemoryError) { - throwUnsupportedException(newCapacity, requiredCapacity, outOfMemoryError); + throwUnsupportedException(requiredCapacity, outOfMemoryError); } } else { - throwUnsupportedException(newCapacity, requiredCapacity, null); + throwUnsupportedException(requiredCapacity, null); } } } - private void throwUnsupportedException(int newCapacity, int requiredCapacity, Throwable cause) { - String message = "Cannot reserve more than " + newCapacity + - " bytes in the vectorized reader (requested = " + requiredCapacity + " bytes). As a" + - " workaround, you can disable the vectorized reader by setting " - + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " to false."; + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). 
As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; if (cause != null) { throw new RuntimeException(message, cause); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca875..e3943f31a48ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -802,8 +802,8 @@ class ColumnarBatchSuite extends SparkFunSuite { // Over-allocating beyond MAX_CAPACITY throws an exception column.appendBytes(10, 0.toByte) } - assert(ex.getMessage.contains(s"Cannot reserve more than ${column.MAX_CAPACITY} bytes in " + - s"the vectorized reader")) + assert(ex.getMessage.contains(s"Cannot reserve additional contiguous bytes in the " + + s"vectorized reader")) } } } From 00be16df642317137f17d2d7d2887c41edac3680 Mon Sep 17 00:00:00 2001 From: Andrew Mills Date: Mon, 26 Sep 2016 16:41:10 -0400 Subject: [PATCH 127/306] [Docs] Update spark-standalone.md to fix link Corrected a link to the configuration.html page, it was pointing to a page that does not exist (configurations.html). Documentation change, verified in preview. Author: Andrew Mills Closes #15244 from ammills01/master. --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1097f1fabef6c..7b82b957d5299 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -348,7 +348,7 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] +For more information about these configurations please refer to the [configuration doc](configuration.html#deploy) Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). From 93c743f1aca433144611b11d4e1b169d66e0f57b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 26 Sep 2016 16:47:57 -0700 Subject: [PATCH 128/306] [SPARK-17577][FOLLOW-UP][SPARKR] SparkR spark.addFile supports adding directory recursively ## What changes were proposed in this pull request? #15140 exposed ```JavaSparkContext.addFile(path: String, recursive: Boolean)``` to Python/R, then we can update SparkR ```spark.addFile``` to support adding directory recursively. ## How was this patch tested? Added unit test. Author: Yanbo Liang Closes #15216 from yanboliang/spark-17577-2. --- R/pkg/R/context.R | 9 +++++++-- R/pkg/inst/tests/testthat/test_context.R | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 4793578ad684e..fe2f3e3d10a9b 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -231,17 +231,22 @@ setCheckpointDir <- function(sc, dirName) { #' filesystems), or an HTTP, HTTPS or FTP URI. 
To access the file in Spark jobs, #' use spark.getSparkFiles(fileName) to find its download location. #' +#' A directory can be given if the recursive option is set to true. +#' Currently directories are only supported for Hadoop-supported filesystems. +#' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. +#' #' @rdname spark.addFile #' @param path The path of the file to be added +#' @param recursive Whether to add files recursively from the path. Default is FALSE. #' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") #'} #' @note spark.addFile since 2.1.0 -spark.addFile <- function(path) { +spark.addFile <- function(path, recursive = FALSE) { sc <- getSparkContext() - invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)))) + invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)), recursive)) } #' Get the root directory that contains files added through spark.addFile. diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 0495418bb7779..caca06933952b 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -169,6 +169,7 @@ test_that("spark.lapply should perform simple transforms", { test_that("add and get file to be downloaded with Spark job on every node", { sparkR.sparkContext() + # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) words <- "Hello World!" @@ -177,5 +178,26 @@ test_that("add and get file to be downloaded with Spark job on every node", { download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) unlink(path) + + # Test add directory recursively. + path <- paste0(tempdir(), "/", "recursive_dir") + dir.create(path) + dir_name <- basename(path) + path1 <- paste0(path, "/", "hello.txt") + file.create(path1) + sub_path <- paste0(path, "/", "sub_hello") + dir.create(sub_path) + path2 <- paste0(sub_path, "/", "sub_hello.txt") + file.create(path2) + words <- "Hello World!" + sub_words <- "Sub Hello World!" + writeLines(words, path1) + writeLines(sub_words, path2) + spark.addFile(path, recursive = TRUE) + download_path1 <- spark.getSparkFiles(paste0(dir_name, "/", "hello.txt")) + expect_equal(readLines(download_path1), words) + download_path2 <- spark.getSparkFiles(paste0(dir_name, "/", "sub_hello/sub_hello.txt")) + expect_equal(readLines(download_path2), sub_words) + unlink(path, recursive = TRUE) sparkR.session.stop() }) From 6ee28423ad1b2e6089b82af64a31d77d3552bb38 Mon Sep 17 00:00:00 2001 From: Ding Fei Date: Mon, 26 Sep 2016 23:09:51 -0700 Subject: [PATCH 129/306] Fix two comments since Actor is not used anymore. ## What changes were proposed in this pull request? Fix two comments since Actor is not used anymore. Author: Ding Fei Closes #15251 from danix800/comment-fixing. 
--- .../scala/org/apache/spark/deploy/worker/WorkerWatcher.scala | 3 ++- .../test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index af29de3b0896e..23efcab6caad1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -21,7 +21,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc._ /** - * Actor which connects to a worker process and terminates the JVM if the connection is severed. + * Endpoint which connects to a worker process and terminates the JVM if the + * connection is severed. * Provides fate sharing between a worker and its associated child processes. */ private[spark] class WorkerWatcher( diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index c6aebc19fd12d..bb24c6ce4d33c 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -253,7 +253,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.stop(masterTracker.trackerEndpoint) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) From 85b0a157543201895557d66306b38b3ca52f2151 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Sep 2016 14:18:32 +0800 Subject: [PATCH 130/306] [SPARK-15962][SQL] Introduce implementation with a dense format for UnsafeArrayData ## What changes were proposed in this pull request? This PR introduces more compact representation for ```UnsafeArrayData```. ```UnsafeArrayData``` needs to accept ```null``` value in each entry of an array. In the current version, it has three parts ``` [numElements] [offsets] [values] ``` `Offsets` has the number of `numElements`, and represents `null` if its value is negative. It may increase memory footprint, and introduces an indirection for accessing each of `values`. This PR uses bitvectors to represent nullability for each element like `UnsafeRow`, and eliminates an indirection for accessing each element. The new ```UnsafeArrayData``` has four parts. ``` [numElements][null bits][values or offset&length][variable length portion] ``` In the `null bits` region, we store 1 bit per element, represents whether an element is null. Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries. In the `values or offset&length` region, we store the content of elements. For fields that hold fixed-length primitive types, such as long, double, or int, we store the value directly in the field. For fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the base address of the array) that points to the beginning of the variable-length field and length (they are combined into a long). Each is word-aligned. For `variable length portion`, each is aligned to 8-byte boundaries. 
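As a rough, editorial sketch of how the new dense layout sizes up (the object and method names below are invented for illustration; the header formula mirrors `calculateHeaderPortionInBytes` added in this patch, and the values region is modeled only for fixed-width element types):

```scala
object DenseUnsafeArrayLayoutSketch {
  // Header: an 8-byte element count plus a null bitset of 1 bit per element,
  // stored in 8-byte words (mirrors calculateHeaderPortionInBytes in this patch).
  def headerSizeInBytes(numElements: Int): Long =
    8L + ((numElements + 63) / 64) * 8L

  // Approximate total size for a fixed-width element type (e.g. 4 bytes for int),
  // with the values region rounded up to an 8-byte boundary.
  def fixedWidthSizeInBytes(numElements: Int, elementSizeInBytes: Int): Long = {
    val values = numElements.toLong * elementSizeInBytes
    headerSizeInBytes(numElements) + ((values + 7) / 8) * 8
  }

  def main(args: Array[String]): Unit = {
    // 1000 ints: header = 8 + 16 * 8 = 136 bytes, values = 4000 bytes, total = 4136 bytes.
    println(fixedWidthSizeInBytes(1000, elementSizeInBytes = 4))
  }
}
```

Because the per-element offsets region of the old format is gone for fixed-width types, an element is read directly at `elementOffset + ordinal * elementSize` with no extra indirection.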
The new format can reduce memory footprint and improve performance of accessing each element. An example of memory foot comparison: 1024x1024 elements integer array Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024 + 1024x1024 = 2M bytes Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024/8 + 1024x1024 = 1.25M bytes In summary, we got 1.0-2.6x performance improvements over the code before applying this PR. Here are performance results of [benchmark programs](https://github.com/kiszk/spark/blob/04d2e4b6dbdc4eff43ce18b3c9b776e0129257c7/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala): **Read UnsafeArrayData**: 1.7x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 430 / 436 390.0 2.6 1.0X Double 456 / 485 367.8 2.7 0.9X With SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 252 / 260 666.1 1.5 1.0X Double 281 / 292 597.7 1.7 0.9X ```` **Write UnsafeArrayData**: 1.0x and 1.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 203 / 273 103.4 9.7 1.0X Double 239 / 356 87.9 11.4 0.8X With SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 196 / 249 107.0 9.3 1.0X Double 227 / 367 92.3 10.8 0.9X ```` **Get primitive array from UnsafeArrayData**: 2.6x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 207 / 217 304.2 3.3 1.0X Double 257 / 363 245.2 4.1 0.8X With SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 151 / 198 415.8 2.4 1.0X Double 214 / 394 293.6 3.4 0.7X ```` **Create UnsafeArrayData from primitive array**: 1.7x and 2.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 340 / 385 185.1 5.4 1.0X Double 479 / 705 131.3 7.6 0.7X With SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------ Int 206 / 211 306.0 3.3 1.0X Double 232 / 406 271.6 3.7 0.9X ```` 1.7x and 1.4x performance improvements in [```UDTSerializationBenchmark```](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala) over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ serialize 442 / 533 0.0 441927.1 1.0X deserialize 217 / 274 0.0 217087.6 2.0X With SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ serialize 265 / 318 0.0 265138.5 1.0X deserialize 155 / 197 0.0 154611.4 1.7X ```` ## How was this patch tested? Added unit tests into ```UnsafeArraySuite``` Author: Kazuaki Ishizaki Closes #13680 from kiszk/SPARK-15962. --- .../org/apache/spark/unsafe/Platform.java | 4 + .../linalg/UDTSerializationBenchmark.scala | 13 +- .../catalyst/expressions/UnsafeArrayData.java | 269 ++++++++++-------- .../catalyst/expressions/UnsafeMapData.java | 13 +- .../codegen/UnsafeArrayWriter.java | 193 +++++++++---- .../codegen/GenerateUnsafeProjection.scala | 31 +- .../expressions/UnsafeRowConverterSuite.scala | 23 +- .../sql/catalyst/util/UnsafeArraySuite.scala | 195 +++++++++++-- .../sql/execution/columnar/ColumnType.scala | 4 +- .../benchmark/UnsafeArrayDataBenchmark.scala | 232 +++++++++++++++ .../execution/columnar/ColumnTypeSuite.scala | 4 +- 11 files changed, 750 insertions(+), 231 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index c892b9cdaf49c..671b8c7475943 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -29,6 +29,8 @@ public final class Platform { private static final Unsafe _UNSAFE; + public static final int BOOLEAN_ARRAY_OFFSET; + public static final int BYTE_ARRAY_OFFSET; public static final int SHORT_ARRAY_OFFSET; @@ -235,6 +237,7 @@ public static void throwException(Throwable t) { _UNSAFE = unsafe; if (_UNSAFE != null) { + BOOLEAN_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(boolean[].class); BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); @@ -242,6 +245,7 @@ public static void throwException(Throwable t) { FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { + BOOLEAN_ARRAY_OFFSET = 0; BYTE_ARRAY_OFFSET = 0; SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 8b439e6b7a017..5973479dfb5ed 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -57,13 +57,12 @@ object UDTSerializationBenchmark { } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - serialize 380 / 392 0.0 379730.0 1.0X - deserialize 138 / 142 0.0 137816.6 2.8X + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + serialize 265 / 318 0.0 265138.5 1.0X + deserialize 155 / 197 0.0 154611.4 1.7X */ benchmark.run() } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 6302660548ec1..86523c1474015 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -32,23 +33,31 @@ /** * An Unsafe implementation of Array which is backed by raw memory instead of Java objects. * - * Each tuple has three parts: [numElements] [offsets] [values] + * Each array has four parts: + * [numElements][null bits][values or offset&length][variable length portion] * - * The `numElements` is 4 bytes storing the number of elements of this array. + * The `numElements` is 8 bytes storing the number of elements of this array. * - * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the - * base address of the array) of this element in `values` region. We can get the length of this - * element by subtracting next offset. - * Note that offset can by negative which means this element is null. + * In the `null bits` region, we store 1 bit per element, represents whether an element is null + * Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries. * - * In the `values` region, we store the content of elements. As we can get length info, so elements - * can be variable-length. + * In the `values or offset&length` region, we store the content of elements. For fields that hold + * fixed-length primitive types, such as long, double, or int, we store the value directly + * in the field. The whole fixed-length portion (even for byte) is aligned to 8-byte boundaries. + * For fields with non-primitive or variable-length values, we store a relative offset + * (w.r.t. the base address of the array) that points to the beginning of the variable-length field + * and length (they are combined into a long). For variable length portion, each is aligned + * to 8-byte boundaries. * * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. */ -// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. 
+ public final class UnsafeArrayData extends ArrayData { + public static int calculateHeaderPortionInBytes(int numFields) { + return 8 + ((numFields + 63)/ 64) * 8; + } + private Object baseObject; private long baseOffset; @@ -56,24 +65,19 @@ public final class UnsafeArrayData extends ArrayData { private int numElements; // The size of this array's backing data, in bytes. - // The 4-bytes header of `numElements` is also included. + // The 8-bytes header of `numElements` is also included. private int sizeInBytes; - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } + /** The position to start storing array elements, */ + private long elementOffset; - private int getElementOffset(int ordinal) { - return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L); + private long getElementOffset(int ordinal, int elementSize) { + return elementOffset + ordinal * elementSize; } - private int getElementSize(int offset, int ordinal) { - if (ordinal == numElements - 1) { - return sizeInBytes - offset; - } else { - return Math.abs(getElementOffset(ordinal + 1)) - offset; - } - } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } private void assertIndexIsValid(int ordinal) { assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0"; @@ -102,20 +106,22 @@ public UnsafeArrayData() { } * @param sizeInBytes the size of this array's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { - // Read the number of elements from the first 4 bytes. - final int numElements = Platform.getInt(baseObject, baseOffset); + // Read the number of elements from the first 8 bytes. 
+ final long numElements = Platform.getLong(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; + assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; - this.numElements = numElements; + this.numElements = (int)numElements; this.baseObject = baseObject; this.baseOffset = baseOffset; this.sizeInBytes = sizeInBytes; + this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements); } @Override public boolean isNullAt(int ordinal) { assertIndexIsValid(ordinal); - return getElementOffset(ordinal) < 0; + return BitSetMethods.isSet(baseObject, baseOffset + 8, ordinal); } @Override @@ -165,68 +171,50 @@ public Object get(int ordinal, DataType dataType) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return false; - return Platform.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, getElementOffset(ordinal, 1)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, getElementOffset(ordinal, 1)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, getElementOffset(ordinal, 2)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, getElementOffset(ordinal, 4)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, getElementOffset(ordinal, 8)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, getElementOffset(ordinal, 4)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return 0; - return Platform.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, getElementOffset(ordinal, 8)); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - + if (isNullAt(ordinal)) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = Platform.getLong(baseObject, baseOffset + offset); - return Decimal.apply(value, precision, scale); + return Decimal.apply(getLong(ordinal), precision, scale); } else { final byte[] bytes = getBinary(ordinal); final BigInteger bigInteger = new BigInteger(bytes); @@ -237,19 +225,19 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { @Override public UTF8String getUTF8String(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - 
if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override public byte[] getBinary(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final byte[] bytes = new byte[size]; Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; @@ -257,9 +245,9 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); final int months = (int) Platform.getLong(baseObject, baseOffset + offset); final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); @@ -267,10 +255,10 @@ public CalendarInterval getInterval(int ordinal) { @Override public UnsafeRow getStruct(int ordinal, int numFields) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeRow row = new UnsafeRow(numFields); row.pointTo(baseObject, baseOffset + offset, size); return row; @@ -278,10 +266,10 @@ public UnsafeRow getStruct(int ordinal, int numFields) { @Override public UnsafeArrayData getArray(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeArrayData array = new UnsafeArrayData(); array.pointTo(baseObject, baseOffset + offset, size); return array; @@ -289,10 +277,10 @@ public UnsafeArrayData getArray(int ordinal) { @Override public UnsafeMapData getMap(int ordinal) { - assertIndexIsValid(ordinal); - final int offset = getElementOffset(ordinal); - if (offset < 0) return null; - final int size = getElementSize(offset, ordinal); + if (isNullAt(ordinal)) return null; + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) offsetAndSize; final UnsafeMapData map = new UnsafeMapData(); map.pointTo(baseObject, baseOffset + offset, size); return map; @@ -341,63 +329,108 @@ public UnsafeArrayData copy() { return arrayCopy; } - public static UnsafeArrayData fromPrimitiveArray(int[] arr) { - if (arr.length > (Integer.MAX_VALUE - 4) / 8) { - throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + - "it's too big."); - } + @Override + public boolean[] toBooleanArray() { + boolean[] 
values = new boolean[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.BOOLEAN_ARRAY_OFFSET, numElements); + return values; + } - final int offsetRegionSize = 4 * arr.length; - final int valueRegionSize = 4 * arr.length; - final int totalSize = 4 + offsetRegionSize + valueRegionSize; - final byte[] data = new byte[totalSize]; + @Override + public byte[] toByteArray() { + byte[] values = new byte[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.BYTE_ARRAY_OFFSET, numElements); + return values; + } - Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); + @Override + public short[] toShortArray() { + short[] values = new short[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + return values; + } - int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4; - int valueOffset = 4 + offsetRegionSize; - for (int i = 0; i < arr.length; i++) { - Platform.putInt(data, offsetPosition, valueOffset); - offsetPosition += 4; - valueOffset += 4; - } + @Override + public int[] toIntArray() { + int[] values = new int[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + return values; + } - Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data, - Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize); + @Override + public long[] toLongArray() { + long[] values = new long[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + return values; + } - UnsafeArrayData result = new UnsafeArrayData(); - result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); - return result; + @Override + public float[] toFloatArray() { + float[] values = new float[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + return values; } - public static UnsafeArrayData fromPrimitiveArray(double[] arr) { - if (arr.length > (Integer.MAX_VALUE - 4) / 12) { + @Override + public double[] toDoubleArray() { + double[] values = new double[numElements]; + Platform.copyMemory( + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + return values; + } + + private static UnsafeArrayData fromPrimitiveArray( + Object arr, int offset, int length, int elementSize) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + "it's too big."); } - final int offsetRegionSize = 4 * arr.length; - final int valueRegionSize = 8 * arr.length; - final int totalSize = 4 + offsetRegionSize + valueRegionSize; - final byte[] data = new byte[totalSize]; + final long[] data = new long[(int)totalSizeInLongs]; - Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length); - - int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4; - int valueOffset = 4 + offsetRegionSize; - for (int i = 0; i < arr.length; i++) { - Platform.putInt(data, offsetPosition, valueOffset); - offsetPosition += 4; - valueOffset += 8; - } - - Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data, - Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize); + Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, 
length); + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); UnsafeArrayData result = new UnsafeArrayData(); - result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize); + result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } - // TODO: add more specialized methods. + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { + return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); + } + + public static UnsafeArrayData fromPrimitiveArray(byte[] arr) { + return fromPrimitiveArray(arr, Platform.BYTE_ARRAY_OFFSET, arr.length, 1); + } + + public static UnsafeArrayData fromPrimitiveArray(short[] arr) { + return fromPrimitiveArray(arr, Platform.SHORT_ARRAY_OFFSET, arr.length, 2); + } + + public static UnsafeArrayData fromPrimitiveArray(int[] arr) { + return fromPrimitiveArray(arr, Platform.INT_ARRAY_OFFSET, arr.length, 4); + } + + public static UnsafeArrayData fromPrimitiveArray(long[] arr) { + return fromPrimitiveArray(arr, Platform.LONG_ARRAY_OFFSET, arr.length, 8); + } + + public static UnsafeArrayData fromPrimitiveArray(float[] arr) { + return fromPrimitiveArray(arr, Platform.FLOAT_ARRAY_OFFSET, arr.length, 4); + } + + public static UnsafeArrayData fromPrimitiveArray(double[] arr) { + return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 0700148becaba..35029f5a50e3e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -25,7 +25,7 @@ /** * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. * - * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head + * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head * to indicate the number of bytes of the unsafe key array. * [unsafe key array numBytes] [unsafe key array] [unsafe value array] */ @@ -65,14 +65,15 @@ public UnsafeMapData() { * @param sizeInBytes the size of this map's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { - // Read the numBytes of key array from the first 4 bytes. - final int keyArraySize = Platform.getInt(baseObject, baseOffset); - final int valueArraySize = sizeInBytes - keyArraySize - 4; + // Read the numBytes of key array from the first 8 bytes. 
+ final long keyArraySize = Platform.getLong(baseObject, baseOffset); assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; + assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; + final int valueArraySize = sizeInBytes - (int)keyArraySize - 8; assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; - keys.pointTo(baseObject, baseOffset + 4, keyArraySize); - values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize); + keys.pointTo(baseObject, baseOffset + 8, (int)keyArraySize); + values.pointTo(baseObject, baseOffset + 8 + keyArraySize, valueArraySize); assert keys.numElements() == values.numElements(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7dd932d1981b7..afea4676893ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -19,9 +19,13 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; + /** * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. @@ -33,134 +37,213 @@ public class UnsafeArrayWriter { // The offset of the global buffer where we start to write this array. private int startingOffset; - public void initialize(BufferHolder holder, int numElements, int fixedElementSize) { - // We need 4 bytes to store numElements and 4 bytes each element to store offset. - final int fixedSize = 4 + 4 * numElements; + // The number of elements in this array + private int numElements; + + private int headerInBytes; + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numElements : "index (" + index + ") should < " + numElements; + } + + public void initialize(BufferHolder holder, int numElements, int elementSize) { + // We need 8 bytes to store numElements in header + this.numElements = numElements; + this.headerInBytes = calculateHeaderPortionInBytes(numElements); this.holder = holder; this.startingOffset = holder.cursor; - holder.grow(fixedSize); - Platform.putInt(holder.buffer, holder.cursor, numElements); - holder.cursor += fixedSize; + // Grows the global buffer ahead for header and fixed size data. 
+ int fixedPartInBytes = + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements); + holder.grow(headerInBytes + fixedPartInBytes); + + // Write numElements and clear out null bits to header + Platform.putLong(holder.buffer, startingOffset, numElements); + for (int i = 8; i < headerInBytes; i += 8) { + Platform.putLong(holder.buffer, startingOffset + i, 0L); + } + + // fill 0 into reminder part of 8-bytes alignment in unsafe array + for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { + Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + } + holder.cursor += (headerInBytes + fixedPartInBytes); + } + + private void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + } + } + + private long getElementOffset(int ordinal, int elementSize) { + return startingOffset + headerInBytes + ordinal * elementSize; + } + + public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + assertIndexIsValid(ordinal); + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; - // Grows the global buffer ahead for fixed size data. - holder.grow(fixedElementSize * numElements); + write(ordinal, offsetAndSize); } - private long getElementOffset(int ordinal) { - return startingOffset + 4 + 4 * ordinal; + private void setNullBit(int ordinal) { + assertIndexIsValid(ordinal); + BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullAt(int ordinal) { - final int relativeOffset = holder.cursor - startingOffset; - // Writes negative offset value to represent null element. - Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset); + public void setNullBoolean(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); } - public void setOffset(int ordinal) { - final int relativeOffset = holder.cursor - startingOffset; - Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset); + public void setNullByte(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } + public void setNullShort(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + } + + public void setNullInt(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0); + } + + public void setNullLong(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + } + + public void setNullFloat(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); + } + + public void setNullDouble(int ordinal) { + setNullBit(ordinal); + // put zero into the corresponding field when set null + Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); + } + + public void setNull(int ordinal) { setNullLong(ordinal); } + public void write(int ordinal, boolean value) { - 
Platform.putBoolean(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 1; + assertIndexIsValid(ordinal); + Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); } public void write(int ordinal, byte value) { - Platform.putByte(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 1; + assertIndexIsValid(ordinal); + Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); } public void write(int ordinal, short value) { - Platform.putShort(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 2; + assertIndexIsValid(ordinal); + Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); } public void write(int ordinal, int value) { - Platform.putInt(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 4; + assertIndexIsValid(ordinal); + Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 8; + assertIndexIsValid(ordinal); + Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); } public void write(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 4; + assertIndexIsValid(ordinal); + Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); } public void write(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - Platform.putDouble(holder.buffer, holder.cursor, value); - setOffset(ordinal); - holder.cursor += 8; + assertIndexIsValid(ordinal); + Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType + assertIndexIsValid(ordinal); if (input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { - Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong()); - setOffset(ordinal); - holder.cursor += 8; + write(ordinal, input.toUnscaledLong()); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; - holder.grow(bytes.length); + final int numBytes = bytes.length; + assert numBytes <= 16; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); - setOffset(ordinal); - holder.cursor += bytes.length; + bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); + setOffsetAndSize(ordinal, holder.cursor, numBytes); + + // move the cursor forward with 8-bytes boundary + holder.cursor += roundedSize; } } else { - setNullAt(ordinal); + setNull(ordinal); } } public void write(int ordinal, UTF8String input) { final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(numBytes); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. 
input.writeToMemory(holder.buffer, holder.cursor); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, numBytes); // move the cursor forward. - holder.cursor += numBytes; + holder.cursor += roundedSize; } public void write(int ordinal, byte[] input) { + final int numBytes = input.length; + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + // grow the global buffer before writing data. - holder.grow(input.length); + holder.grow(roundedSize); + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length); + input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, numBytes); // move the cursor forward. - holder.cursor += input.length; + holder.cursor += roundedSize; } public void write(int ordinal, CalendarInterval input) { @@ -171,7 +254,7 @@ public void write(int ordinal, CalendarInterval input) { Platform.putLong(holder.buffer, holder.cursor, input.months); Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - setOffset(ordinal); + setOffsetAndSize(ordinal, holder.cursor, 16); // move the cursor forward. holder.cursor += 16; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a6087..75bb6936b49e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -124,7 +124,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => @@ -134,7 +133,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); - $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => @@ -189,29 +187,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val jt = ctx.javaType(et) - val fixedElementSize = et match { + val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 case _ if ctx.isPrimitiveType(jt) => et.defaultSize - case _ => 0 + case _ => 8 // we need 8 bytes to store offset and length } + val tmpCursor = ctx.freshName("tmpCursor") val writeElement = et match { case t: StructType => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case a @ ArrayType(et, _) => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, element, et, bufferHolder)} + $arrayWriter.setOffsetAndSize($index, 
$tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" - $arrayWriter.setOffset($index); + final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} + $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => @@ -222,16 +224,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } + val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" s""" if ($input instanceof UnsafeArrayData) { ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} } else { final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); + $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { if ($input.isNullAt($index)) { - $arrayWriter.setNullAt($index); + $arrayWriter.setNull$primitiveTypeName($index); } else { final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement @@ -261,16 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final ArrayData $keys = $input.keyArray(); final ArrayData $values = $input.valueArray(); - // preserve 4 bytes to write the key array numBytes later. - $bufferHolder.grow(4); - $bufferHolder.cursor += 4; + // preserve 8 bytes to write the key array numBytes later. + $bufferHolder.grow(8); + $bufferHolder.cursor += 8; // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} - // Write the numBytes of key array into the first 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + // Write the numBytes of key array into the first 8 bytes. 
+ Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 1265908182b3a..90790dda753f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = { assert(array.numElements == values.length) - assert(array.getSizeInBytes == 4 + (4 + 4) * values.length) + assert(array.getSizeInBytes == + 8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length)) values.zipWithIndex.foreach { case (value, index) => assert(array.getInt(index) == value) } @@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { testArrayInt(map.keyArray, keys) testArrayInt(map.valueArray, values) - assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) + assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } test("basic conversion with array type") { @@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedArray = unsafeArray2.getArray(0) testArrayInt(nestedArray, Seq(3, 4)) - assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes) + assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes) val array1Size = roundedSize(unsafeArray1.getSizeInBytes) val array2Size = roundedSize(unsafeArray2.getSizeInBytes) @@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val nestedMap = valueArray.getMap(0) testMapInt(nestedMap, Seq(5, 6), Seq(7, 8)) - assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes)) } - assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) val map1Size = roundedSize(unsafeMap1.getSizeInBytes) val map2Size = roundedSize(unsafeMap2.getSizeInBytes) @@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(innerStruct.getLong(0) == 2L) } - assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) @@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(innerStruct.getSizeInBytes == 8 + 8) assert(innerStruct.getLong(0) == 4L) - assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes) } - assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + field1.getSizeInBytes + 
roundedSize(field2.getSizeInBytes)) @@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = field1.getMap(0) testMapInt(innerMap, Seq(1), Seq(2)) - assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes) + assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes)) val field2 = unsafeRow.getMap(1) assert(field2.numElements == 1) @@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerArray = valueArray.getArray(0) testArrayInt(innerArray, Seq(4)) - assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes)) + assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes) } - assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) + assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes) assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index 1685276ff1201..f0e247bf46c44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -18,27 +18,190 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class UnsafeArraySuite extends SparkFunSuite { - test("from primitive int array") { - val array = Array(1, 10, 100) - val unsafe = UnsafeArrayData.fromPrimitiveArray(array) - assert(unsafe.numElements == 3) - assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3) - assert(unsafe.getInt(0) == 1) - assert(unsafe.getInt(1) == 10) - assert(unsafe.getInt(2) == 100) + val booleanArray = Array(false, true) + val shortArray = Array(1.toShort, 10.toShort, 100.toShort) + val intArray = Array(1, 10, 100) + val longArray = Array(1.toLong, 10.toLong, 100.toLong) + val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat) + val doubleArray = Array(1.1, 2.2, 3.3) + val stringArray = Array("1", "10", "100") + val dateArray = Array( + DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1")).get, + DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26")).get) + val timestampArray = Array( + DateTimeUtils.stringToTimestamp(UTF8String.fromString("1970-1-1 00:00:00")).get, + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2016-7-26 00:00:00")).get) + val decimalArray4_1 = Array( + BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR), + BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR)) + val decimalArray20_20 = Array( + BigDecimal("1.2345678901234567890123456").setScale(21, BigDecimal.RoundingMode.FLOOR), + BigDecimal("2.3456789012345678901234567").setScale(21, BigDecimal.RoundingMode.FLOOR)) + + val calenderintervalArray = Array(new CalendarInterval(3, 321), new CalendarInterval(1, 123)) + + val intMultiDimArray = Array(Array(1), Array(2, 20), Array(3, 30, 300)) + val doubleMultiDimArray = Array( + Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3)) + + 
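The size expectations in the suites above and below all follow the same arithmetic: an 8-byte numElements word, a null-tracking bit region rounded up to whole 8-byte words, then either the fixed-size element values or one 8-byte offset-and-size word per variable-length element, with the whole region padded to an 8-byte boundary. A minimal, dependency-free sketch of that arithmetic (the helper names here are illustrative only, not APIs from this patch):

```scala
object UnsafeArrayLayoutSketch {
  // 8 bytes for numElements plus one 8-byte word per 64 elements of null-tracking bits,
  // mirroring the tests' 8 + ceil(n / 64.0) * 8 expectation.
  def headerInBytes(numElements: Int): Long = 8L + ((numElements + 63) / 64) * 8L

  // Fixed-size and variable-length regions are padded up to the next 8-byte word.
  def roundToWord(numBytes: Long): Long = ((numBytes + 7) / 8) * 8

  // Variable-length elements store a single long: relative offset in the upper 32 bits,
  // size in bytes in the lower 32 bits.
  def unpackOffsetAndSize(offsetAndSize: Long): (Int, Int) =
    ((offsetAndSize >> 32).toInt, offsetAndSize.toInt)

  def main(args: Array[String]): Unit = {
    // An int array of 3 elements: 8 (numElements) + 8 (null bits) + roundToWord(3 * 4) = 32,
    // which is what the "from primitive array" test below expects.
    println(headerInBytes(3) + roundToWord(3 * 4))   // 32
    println(unpackOffsetAndSize((40L << 32) | 16L))  // (40,16)
  }
}
```
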
test("read array") { + val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind(). + toRow(booleanArray).getArray(0) + assert(unsafeBoolean.isInstanceOf[UnsafeArrayData]) + assert(unsafeBoolean.numElements == booleanArray.length) + booleanArray.zipWithIndex.map { case (e, i) => + assert(unsafeBoolean.getBoolean(i) == e) + } + + val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind(). + toRow(shortArray).getArray(0) + assert(unsafeShort.isInstanceOf[UnsafeArrayData]) + assert(unsafeShort.numElements == shortArray.length) + shortArray.zipWithIndex.map { case (e, i) => + assert(unsafeShort.getShort(i) == e) + } + + val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind(). + toRow(intArray).getArray(0) + assert(unsafeInt.isInstanceOf[UnsafeArrayData]) + assert(unsafeInt.numElements == intArray.length) + intArray.zipWithIndex.map { case (e, i) => + assert(unsafeInt.getInt(i) == e) + } + + val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind(). + toRow(longArray).getArray(0) + assert(unsafeLong.isInstanceOf[UnsafeArrayData]) + assert(unsafeLong.numElements == longArray.length) + longArray.zipWithIndex.map { case (e, i) => + assert(unsafeLong.getLong(i) == e) + } + + val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind(). + toRow(floatArray).getArray(0) + assert(unsafeFloat.isInstanceOf[UnsafeArrayData]) + assert(unsafeFloat.numElements == floatArray.length) + floatArray.zipWithIndex.map { case (e, i) => + assert(unsafeFloat.getFloat(i) == e) + } + + val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind(). + toRow(doubleArray).getArray(0) + assert(unsafeDouble.isInstanceOf[UnsafeArrayData]) + assert(unsafeDouble.numElements == doubleArray.length) + doubleArray.zipWithIndex.map { case (e, i) => + assert(unsafeDouble.getDouble(i) == e) + } + + val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind(). + toRow(stringArray).getArray(0) + assert(unsafeString.isInstanceOf[UnsafeArrayData]) + assert(unsafeString.numElements == stringArray.length) + stringArray.zipWithIndex.map { case (e, i) => + assert(unsafeString.getUTF8String(i).toString().equals(e)) + } + + val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind(). + toRow(dateArray).getArray(0) + assert(unsafeDate.isInstanceOf[UnsafeArrayData]) + assert(unsafeDate.numElements == dateArray.length) + dateArray.zipWithIndex.map { case (e, i) => + assert(unsafeDate.get(i, DateType) == e) + } + + val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind(). 
+ toRow(timestampArray).getArray(0) + assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData]) + assert(unsafeTimestamp.numElements == timestampArray.length) + timestampArray.zipWithIndex.map { case (e, i) => + assert(unsafeTimestamp.get(i, TimestampType) == e) + } + + Seq(decimalArray4_1, decimalArray20_20).map { decimalArray => + val decimal = decimalArray(0) + val schema = new StructType().add( + "array", ArrayType(DecimalType(decimal.precision, decimal.scale))) + val encoder = RowEncoder(schema).resolveAndBind() + val externalRow = Row(decimalArray) + val ir = encoder.toRow(externalRow) + + val unsafeDecimal = ir.getArray(0) + assert(unsafeDecimal.isInstanceOf[UnsafeArrayData]) + assert(unsafeDecimal.numElements == decimalArray.length) + decimalArray.zipWithIndex.map { case (e, i) => + assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e) + } + } + + val schema = new StructType().add("array", ArrayType(CalendarIntervalType)) + val encoder = RowEncoder(schema).resolveAndBind() + val externalRow = Row(calenderintervalArray) + val ir = encoder.toRow(externalRow) + val unsafeCalendar = ir.getArray(0) + assert(unsafeCalendar.isInstanceOf[UnsafeArrayData]) + assert(unsafeCalendar.numElements == calenderintervalArray.length) + calenderintervalArray.zipWithIndex.map { case (e, i) => + assert(unsafeCalendar.getInterval(i) == e) + } + + val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind(). + toRow(intMultiDimArray).getArray(0) + assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData]) + assert(unsafeMultiDimInt.numElements == intMultiDimArray.length) + intMultiDimArray.zipWithIndex.map { case (a, j) => + val u = unsafeMultiDimInt.getArray(j) + assert(u.isInstanceOf[UnsafeArrayData]) + assert(u.numElements == a.length) + a.zipWithIndex.map { case (e, i) => + assert(u.getInt(i) == e) + } + } + + val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind(). 
+ toRow(doubleMultiDimArray).getArray(0) + assert(unsafeDouble.isInstanceOf[UnsafeArrayData]) + assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length) + doubleMultiDimArray.zipWithIndex.map { case (a, j) => + val u = unsafeMultiDimDouble.getArray(j) + assert(u.isInstanceOf[UnsafeArrayData]) + assert(u.numElements == a.length) + a.zipWithIndex.map { case (e, i) => + assert(u.getDouble(i) == e) + } + } } - test("from primitive double array") { - val array = Array(1.1, 2.2, 3.3) - val unsafe = UnsafeArrayData.fromPrimitiveArray(array) - assert(unsafe.numElements == 3) - assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3) - assert(unsafe.getDouble(0) == 1.1) - assert(unsafe.getDouble(1) == 2.2) - assert(unsafe.getDouble(2) == 3.3) + test("from primitive array") { + val unsafeInt = UnsafeArrayData.fromPrimitiveArray(intArray) + assert(unsafeInt.numElements == 3) + assert(unsafeInt.getSizeInBytes == + ((8 + scala.math.ceil(3/64.toDouble) * 8 + 4 * 3 + 7).toInt / 8) * 8) + intArray.zipWithIndex.map { case (e, i) => + assert(unsafeInt.getInt(i) == e) + } + + val unsafeDouble = UnsafeArrayData.fromPrimitiveArray(doubleArray) + assert(unsafeDouble.numElements == 3) + assert(unsafeDouble.getSizeInBytes == + ((8 + scala.math.ceil(3/64.toDouble) * 8 + 8 * 3 + 7).toInt / 8) * 8) + doubleArray.zipWithIndex.map { case (e, i) => + assert(unsafeDouble.getDouble(i) == e) + } + } + + test("to primitive array") { + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray)) + + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index f9d606e37ea89..fa9619eb07fec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeArray = getField(row, ordinal) - 4 + unsafeArray.getSizeInBytes + 8 + unsafeArray.getSizeInBytes } override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { @@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 4 + unsafeMap.getSizeInBytes + 8 + unsafeMap.getSizeInBytes } override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala new file mode 100644 index 0000000000000..6c7779b5790d0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.util.Random + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter} +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.UnsafeArrayDataBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class UnsafeArrayDataBenchmark extends BenchmarkBase { + + def calculateHeaderPortionInBytes(count: Int) : Int = { + /* 4 + 4 * count // Use this expression for SPARK-15962 */ + UnsafeArrayData.calculateHeaderPortionInBytes(count) + } + + def readUnsafeArray(iters: Int): Unit = { + val count = 1024 * 1024 * 16 + val rand = new Random(42) + + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0) + val readIntArray = { i: Int => + var n = 0 + while (n < iters) { + val len = intUnsafeArray.numElements + var sum = 0 + var i = 0 + while (i < len) { + sum += intUnsafeArray.getInt(i) + i += 1 + } + n += 1 + } + } + + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0) + val readDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + val len = doubleUnsafeArray.numElements + var sum = 0.0 + var i = 0 + while (i < len) { + sum += doubleUnsafeArray.getDouble(i) + i += 1 + } + n += 1 + } + } + + val benchmark = new Benchmark("Read UnsafeArrayData", count * iters) + benchmark.addCase("Int")(readIntArray) + benchmark.addCase("Double")(readDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 252 / 260 666.1 1.5 1.0X + Double 281 / 292 597.7 1.7 0.9X + */ + } + + def writeUnsafeArray(iters: Int): Unit = { + val count = 1024 * 1024 * 2 + val rand = new Random(42) + + var intTotalLength: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + val writeIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements() + n += 1 + } + intTotalLength = len + } + + var doubleTotalLength: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val 
doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val writeDoubleArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements() + n += 1 + } + doubleTotalLength = len + } + + val benchmark = new Benchmark("Write UnsafeArrayData", count * iters) + benchmark.addCase("Int")(writeIntArray) + benchmark.addCase("Double")(writeDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 196 / 249 107.0 9.3 1.0X + Double 227 / 367 92.3 10.8 0.9X + */ + } + + def getPrimitiveArray(iters: Int): Unit = { + val count = 1024 * 1024 * 12 + val rand = new Random(42) + + var intTotalLength: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind() + val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0) + val readIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += intUnsafeArray.toIntArray.length + n += 1 + } + intTotalLength = len + } + + var doubleTotalLength: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind() + val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0) + val readDoubleArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += doubleUnsafeArray.toDoubleArray.length + n += 1 + } + doubleTotalLength = len + } + + val benchmark = new Benchmark("Get primitive array from UnsafeArrayData", count * iters) + benchmark.addCase("Int")(readIntArray) + benchmark.addCase("Double")(readDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 151 / 198 415.8 2.4 1.0X + Double 214 / 394 293.6 3.4 0.7X + */ + } + + def putPrimitiveArray(iters: Int): Unit = { + val count = 1024 * 1024 * 12 + val rand = new Random(42) + + var intTotalLen: Int = 0 + val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt } + val createIntArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += UnsafeArrayData.fromPrimitiveArray(intPrimitiveArray).numElements + n += 1 + } + intTotalLen = len + } + + var doubleTotalLen: Int = 0 + val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble } + val createDoubleArray = { i: Int => + var len = 0 + var n = 0 + while (n < iters) { + len += UnsafeArrayData.fromPrimitiveArray(doublePrimitiveArray).numElements + n += 1 + } + doubleTotalLen = len + } + + val benchmark = new Benchmark("Create UnsafeArrayData from primitive array", count * iters) + benchmark.addCase("Int")(createIntArray) + benchmark.addCase("Double")(createDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + 
Int 206 / 211 306.0 3.3 1.0X + Double 232 / 406 271.6 3.7 0.9X + */ + } + + ignore("Benchmark UnsafeArrayData") { + readUnsafeArray(10) + writeUnsafeArray(10) + getPrimitiveArray(5) + putPrimitiveArray(5) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 052f4cbaebc8e..0b93c633b2d93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) - checkActualSize(ARRAY_TYPE, Array[Any](1), 16) - checkActualSize(MAP_TYPE, Map(1 -> "a"), 29) + checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8)) checkActualSize(STRUCT_TYPE, Row("hello"), 28) } From 7f16affa262b059580ed2775a7b05a767aa72315 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 27 Sep 2016 00:00:21 -0700 Subject: [PATCH 131/306] [SPARK-17138][ML][MLIB] Add Python API for multinomial logistic regression ## What changes were proposed in this pull request? Add Python API for multinomial logistic regression. - add `family` param in python api. - expose `coefficientMatrix` and `interceptVector` for `LogisticRegressionModel` - add python-side testcase for multinomial logistic regression - update python doc. ## How was this patch tested? existing and added doc tests. Author: WeichenXu Closes #14852 from WeichenXu123/add_MLOR_python. --- python/pyspark/ml/classification.py | 90 ++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 20 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4c01fd5c4ffb..505e7bffd1763 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -67,21 +67,34 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): """ Logistic regression. - Currently, this class only supports binary classification. + This class supports multinomial logistic (softmax) and binomial logistic regression. >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ + >>> bdf = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") - >>> model = lr.fit(df) - >>> model.coefficients + >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + >>> blorModel = blor.fit(bdf) + >>> blorModel.coefficients DenseVector([5.5...]) - >>> model.intercept + >>> blorModel.intercept -2.68... + >>> mdf = sc.parallelize([ + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), + ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() + >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", + ... 
family="multinomial") + >>> mlorModel = mlor.fit(mdf) + >>> print(mlorModel.coefficientMatrix) + DenseMatrix([[-2.3...], + [ 0.2...], + [ 2.1... ]]) + >>> mlorModel.interceptVector + DenseVector([2.0..., 0.8..., -2.8...]) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() - >>> result = model.transform(test0).head() + >>> result = blorModel.transform(test0).head() >>> result.prediction 0.0 >>> result.probability @@ -89,23 +102,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> result.rawPrediction DenseVector([8.22..., -8.22...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction + >>> blorModel.transform(test1).head().prediction 1.0 - >>> lr.setParams("vector") + >>> blor.setParams("vector") Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. >>> lr_path = temp_path + "/lr" - >>> lr.save(lr_path) + >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) >>> lr2.getMaxIter() 5 >>> model_path = temp_path + "/lr_model" - >>> model.save(model_path) + >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> blorModel.coefficients[0] == model2.coefficients[0] True - >>> model.intercept == model2.intercept + >>> blorModel.intercept == model2.intercept True .. versionadded:: 1.3.0 @@ -117,24 +130,29 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti "e.g. if threshold is p, then thresholds must be equal to [1-p, p].", typeConverter=TypeConverters.toFloat) + family = Param(Params._dummy(), "family", + "The name of family which is a description of the label distribution to " + + "be used in the model. Supported options: auto, binomial, multinomial", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, - aggregationDepth=2): + aggregationDepth=2, family="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ - aggregationDepth=2) + aggregationDepth=2, family="auto") If the threshold and thresholds Params are both set, they must be equivalent. 
""" super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -145,13 +163,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction", standardization=True, weightCol=None, - aggregationDepth=2): + aggregationDepth=2, family="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ - aggregationDepth=2) + aggregationDepth=2, family="auto") Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -232,6 +250,20 @@ def _checkThresholdConsistency(self): raise ValueError("Logistic Regression getThreshold found inconsistent values for" + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) + @since("2.1.0") + def setFamily(self, value): + """ + Sets the value of :py:attr:`family`. + """ + return self._set(family=value) + + @since("2.1.0") + def getFamily(self): + """ + Gets the value of :py:attr:`family` or its default value. + """ + return self.getOrDefault(self.family) + class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ @@ -244,7 +276,8 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable @since("2.0.0") def coefficients(self): """ - Model coefficients. + Model coefficients of binomial logistic regression. + An exception is thrown in the case of multinomial logistic regression. """ return self._call_java("coefficients") @@ -252,10 +285,27 @@ def coefficients(self): @since("1.4.0") def intercept(self): """ - Model intercept. + Model intercept of binomial logistic regression. + An exception is thrown in the case of multinomial logistic regression. """ return self._call_java("intercept") + @property + @since("2.1.0") + def coefficientMatrix(self): + """ + Model coefficients. + """ + return self._call_java("coefficientMatrix") + + @property + @since("2.1.0") + def interceptVector(self): + """ + Model intercept. + """ + return self._call_java("interceptVector") + @property @since("2.0.0") def summary(self): From 6a68c5d7b4eb07e4ed6b702dd1536cd08d9bba7d Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 27 Sep 2016 08:10:38 -0500 Subject: [PATCH 132/306] [SPARK-16757] Set up Spark caller context to HDFS and YARN ## What changes were proposed in this pull request? 1. Pass `jobId` to Task. 2. Invoke Hadoop APIs. * A new function `setCallerContext` is added in `Utils`. `setCallerContext` function invokes APIs of `org.apache.hadoop.ipc.CallerContext` to set up spark caller contexts, which will be written into `hdfs-audit.log` and Yarn RM audit log. * For HDFS: Spark sets up its caller context by invoking`org.apache.hadoop.ipc.CallerContext` in `Task` and Yarn `Client` and `ApplicationMaster`. 
* For Yarn: Spark sets up its caller context by invoking `org.apache.hadoop.ipc.CallerContext` in Yarn `Client`. ## How was this patch tested? Manual Tests against some Spark applications in Yarn client mode and Yarn cluster mode. Need to check if spark caller contexts are written into HDFS hdfs-audit.log and Yarn RM audit log successfully. For example, run SparkKmeans in Yarn client mode: ``` ./bin/spark-submit --verbose --executor-cores 3 --num-executors 1 --master yarn --deploy-mode client --class org.apache.spark.examples.SparkKMeans examples/target/original-spark-examples_2.11-2.1.0-SNAPSHOT.jar hdfs://localhost:9000/lr_big.txt 2 5 ``` **Before**: There will be no Spark caller context in records of `hdfs-audit.log` and Yarn RM audit log. **After**: Spark caller contexts will be written in records of `hdfs-audit.log` and Yarn RM audit log. These are records in `hdfs-audit.log`: ``` 2016-09-20 11:54:24,116 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_CLIENT_AppId_application_1474394339641_0005 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_2_AttemptNum_0 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_1_AttemptNum_0 2016-09-20 11:54:28,164 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0005_JobId_0_StageId_0_AttemptId_0_TaskId_0_AttemptNum_0 ``` ``` 2016-09-20 11:59:33,868 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=mkdirs src=/private/tmp/hadoop-wyang/nm-local-dir/usercache/wyang/appcache/application_1474394339641_0006/container_1474394339641_0006_01_000001/spark-warehouse dst=null perm=wyang:supergroup:rwxr-xr-x proto=rpc callerContext=SPARK_APPLICATION_MASTER_AppId_application_1474394339641_0006_AttemptId_1 2016-09-20 11:59:37,214 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_1_AttemptNum_0 2016-09-20 11:59:37,215 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_2_AttemptNum_0 2016-09-20 11:59:37,215 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_0_AttemptNum_0 2016-09-20 11:59:42,391 INFO FSNamesystem.audit: allowed=true ugi=wyang (auth:SIMPLE) ip=/127.0.0.1 cmd=open src=/lr_big.txt dst=null perm=null proto=rpc callerContext=SPARK_TASK_AppId_application_1474394339641_0006_AttemptId_1_JobId_0_StageId_0_AttemptId_0_TaskId_3_AttemptNum_0 ``` This is a record in Yarn RM log: ``` 2016-09-20 11:59:24,050 INFO 
org.apache.hadoop.yarn.server.resourcemanager.RMAuditLogger: USER=wyang IP=127.0.0.1 OPERATION=Submit Application Request TARGET=ClientRMService RESULT=SUCCESS APPID=application_1474394339641_0006 CALLERCONTEXT=SPARK_CLIENT_AppId_application_1474394339641_0006 ``` Author: Weiqing Yang Closes #14659 from Sherry302/callercontextSubmit. --- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../apache/spark/scheduler/ResultTask.scala | 15 ++++- .../spark/scheduler/ShuffleMapTask.scala | 13 +++- .../org/apache/spark/scheduler/Task.scala | 17 ++++- .../scala/org/apache/spark/util/Utils.scala | 62 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 12 ++++ .../spark/deploy/yarn/ApplicationMaster.scala | 7 +++ .../org/apache/spark/deploy/yarn/Client.scala | 4 +- 8 files changed, 126 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dd47c1dbbec06..5ea0b48f6e4c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1015,7 +1015,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.latestInfo.taskMetrics, properties) + taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId), + Option(sc.applicationId), sc.applicationAttemptId) } case stage: ResultStage => @@ -1024,7 +1025,8 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics) + taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics, + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 609f10aee940d..1e7c63af2e797 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -43,7 +43,12 @@ import org.apache.spark.rdd.RDD * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. 
- */ + * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -52,8 +57,12 @@ private[spark] class ResultTask[T, U]( locs: Seq[TaskLocation], val outputId: Int, localProperties: Properties, - metrics: TaskMetrics) - extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties) + metrics: TaskMetrics, + jobId: Option[Int] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None) + extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, + appId, appAttemptId) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 448fe02084e0d..66d6790e168f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -44,6 +44,11 @@ import org.apache.spark.shuffle.ShuffleWriter * @param locs preferred task execution locations for locality scheduling * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to */ private[spark] class ShuffleMapTask( stageId: Int, @@ -52,8 +57,12 @@ private[spark] class ShuffleMapTask( partition: Partition, @transient private var locs: Seq[TaskLocation], metrics: TaskMetrics, - localProperties: Properties) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties) + localProperties: Properties, + jobId: Option[Int] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, + appId, appAttemptId) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 48daa344f3c88..9385e3c31e1e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -29,7 +29,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util._ /** * A unit of execution. We have two kinds of Task's in Spark: @@ -47,6 +47,11 @@ import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOu * @param partitionId index of the number in the RDD * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. 
+ * + * The parameters below are optional: + * @param jobId id of the job this task belongs to + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to */ private[spark] abstract class Task[T]( val stageId: Int, @@ -54,7 +59,10 @@ private[spark] abstract class Task[T]( val partitionId: Int, // The default value is only used in tests. val metrics: TaskMetrics = TaskMetrics.registered, - @transient var localProperties: Properties = new Properties) extends Serializable { + @transient var localProperties: Properties = new Properties, + val jobId: Option[Int] = None, + val appId: Option[String] = None, + val appAttemptId: Option[String] = None) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -79,9 +87,14 @@ private[spark] abstract class Task[T]( metrics) TaskContext.setTaskContext(context) taskThread = Thread.currentThread() + if (_killed) { kill(interruptThread = false) } + + new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), + Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + try { runTask(context) } catch { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e09666c6103c6..caa768cfbdc6c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2440,6 +2440,68 @@ private[spark] object Utils extends Logging { } } +/** + * An utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be + * constructed by parameters passed in. + * When Spark applications run on Yarn and HDFS, its caller contexts will be written into Yarn RM + * audit log and hdfs-audit.log. That can help users to better diagnose and understand how + * specific applications impacting parts of the Hadoop system and potential problems they may be + * creating (e.g. overloading NN). As HDFS mentioned in HDFS-9184, for a given HDFS operation, it's + * very helpful to track which upper level job issues it. 
+ * + * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) + * + * The parameters below are optional: + * @param appId id of the app this task belongs to + * @param appAttemptId attempt id of the app this task belongs to + * @param jobId id of the job this task belongs to + * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to + * @param taskId task id + * @param taskAttemptNumber task attempt id + */ +private[spark] class CallerContext( + from: String, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" + val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" + val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" + val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" + val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" + val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" + val taskAttemptNumberStr = + if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" + + val context = "SPARK_" + from + appIdStr + appAttemptIdStr + + jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + + /** + * Set up the caller context [[context]] by invoking Hadoop CallerContext API of + * [[org.apache.hadoop.ipc.CallerContext]], which was added in hadoop 2.8. + */ + def setCurrentContext(): Boolean = { + var succeed = false + try { + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + val Builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") + val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) + val hdfsContext = Builder.getMethod("build").invoke(builderInst) + callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) + succeed = true + } catch { + case NonFatal(e) => logInfo("Fail to set Spark caller context", e) + } + succeed + } +} + /** * A utility class to redirect the child process's stdout or stderr. 
*/ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4715fd29375d6..bc28b2d9cb831 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -788,6 +788,18 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { .set("spark.executor.instances", "1")) === 3) } + test("Set Spark CallerContext") { + val context = "test" + try { + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + assert(new CallerContext(context).setCurrentContext()) + assert(s"SPARK_$context" === + callerContext.getMethod("getCurrent").invoke(null).toString) + } catch { + case e: ClassNotFoundException => + assert(!new CallerContext(context).setCurrentContext()) + } + } test("encodeFileNameToURIRawPath") { assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ad50ea789a913..aabae140af8b1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -184,6 +184,8 @@ private[spark] class ApplicationMaster( try { val appAttemptId = client.getAttemptId() + var attemptID: Option[String] = None + if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box @@ -196,8 +198,13 @@ private[spark] class ApplicationMaster( // Set this internal configuration if it is running on cluster mode, this // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + + attemptID = Option(appAttemptId.getAttemptId.toString) } + new CallerContext("APPMASTER", + Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext() + logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 2398f0aea316a..ea4e1160b7672 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,7 +54,7 @@ import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallerContext, Utils} private[spark] class Client( val args: ClientArguments, @@ -161,6 +161,8 @@ private[spark] class Client( reportLauncherState(SparkAppHandle.State.SUBMITTED) launcherBackend.setAppId(appId.toString) + new CallerContext("CLIENT", Option(appId.toString)).setCurrentContext() + // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) From 5de1737b02710e36f6804d2ae243d1aeb30a0b32 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Sep 2016 00:39:47 +0800 Subject: [PATCH 133/306] [SPARK-16777][SQL] Do not use deprecated listType API in ParquetSchemaConverter ## What changes were proposed in this pull request? 
This PR removes a build warning, shown below.
```scala [WARNING] .../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala:448: method listType in object ConversionPatterns is deprecated: see corresponding Javadoc for more information. [WARNING] ConversionPatterns.listType( [WARNING] ^ [WARNING] .../spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala:464: method listType in object ConversionPatterns is deprecated: see corresponding Javadoc for more information. [WARNING] ConversionPatterns.listType( [WARNING] ^ ```
This PR does not switch to `listOfElements` (the recommended replacement for `listType`) because the new method checks whether the element name in Parquet's `LIST` is `element` in the Parquet schema and throws an exception if it is not. However, it seems Spark prior to 1.4.x writes `ArrayType` with Parquet's `LIST` but with `array` as its element name. Therefore, this PR avoids both `listOfElements` and `listType` and instead uses the existing schema builder to construct the same `GroupType`.
## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #14399 from HyukjinKwon/SPARK-16777. --- .../parquet/ParquetSchemaConverter.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c81a65f4973e3..b4f36ce3752c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -445,14 +445,20 @@ private[parquet] class ParquetSchemaConverter( // repeated array; // } // } - ConversionPatterns.listType( - repetition, - field.name, - Types + + // This should not use `listOfElements` here because this new method checks if the + // element name is `element` in the `GroupType` and throws an exception if not. + // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with + // `array` as its element name as below. Therefore, we build manually + // the correct group type here via the builder. (See SPARK-16777) + Types + .buildGroup(repetition).as(LIST) + .addField(Types .buildGroup(REPEATED) - // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + // "array" is the name chosen by parquet-hive (1.7.0 and prior version) .addField(convertField(StructField("array", elementType, nullable))) .named("bag")) + .named(field.name) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is @@ -461,11 +467,13 @@ private[parquet] class ParquetSchemaConverter( // group (LIST) { // repeated element; // } - ConversionPatterns.listType( - repetition, - field.name, + + // Here too, we should not use `listOfElements`. (See SPARK-16777) + Types + .buildGroup(repetition).as(LIST) // "array" is the name chosen by parquet-avro (1.7.0 and prior version) - convertField(StructField("array", elementType, nullable), REPEATED)) + .addField(convertField(StructField("array", elementType, nullable), REPEATED)) + .named(field.name) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`.
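For reference, the `GroupType` that the converter now assembles by hand corresponds roughly to the following standalone use of Parquet's schema builder. This is an illustrative sketch only, not part of the patch; the column name `f` and the `INT32` element type are made up here.

```scala
import org.apache.parquet.schema.{OriginalType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.Type.Repetition

// Legacy parquet-hive style layout for a nullable array<int> column "f":
// an outer LIST group wrapping a repeated "bag" group whose single field
// is named "array", mirroring the structure built in the patch above.
val legacyList = Types
  .buildGroup(Repetition.OPTIONAL).as(OriginalType.LIST)
  .addField(Types
    .buildGroup(Repetition.REPEATED)
    .addField(Types.optional(PrimitiveTypeName.INT32).named("array"))
    .named("bag"))
  .named("f")

// Prints the schema in Parquet's message-type syntax.
println(legacyList)
```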
From 2cac3b2d4a4a4f3d0d45af4defc23bb0ba53484b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Sep 2016 00:50:12 +0800 Subject: [PATCH 134/306] [SPARK-16516][SQL] Support for pushing down filters for decimal and timestamp types in ORC
## What changes were proposed in this pull request?
It seems ORC supports all the types in [`PredicateLeaf.Type`](https://github.com/apache/hive/blob/e085b7e9bd059d91aaf013df0db4d71dca90ec6f/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/PredicateLeaf.java#L50-L56), which includes the timestamp and decimal types. In more detail, the types listed in [`SearchArgumentImpl.boxLiteral()`](https://github.com/apache/hive/blob/branch-1.2/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L1068-L1093) can be used as filter values.
FYI, the initial `case` clause for supported types was introduced in https://github.com/apache/spark/commit/65d71bd9fbfe6fe1b741c80fed72d6ae3d22b028 and has not been changed over time. At that time, the Hive version was 0.13, which supported filter push-down for only some types (see [SearchArgumentImpl.java#L945-L965](https://github.com/apache/hive/blob/branch-0.13/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L945-L965) at 0.13). However, the version has since been upgraded to 1.2.x, which supports more types (see [SearchArgumentImpl.java#L1068-L1093](https://github.com/apache/hive/blob/branch-1.2/ql/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java#L1068-L1093) at 1.2.0).
## How was this patch tested? Unit tests in `OrcFilterSuite` and `OrcQuerySuite`. Author: hyukjinkwon Closes #14172 from HyukjinKwon/SPARK-16516. --- .../spark/sql/hive/orc/OrcFilters.scala | 1 + .../spark/sql/hive/orc/OrcFilterSuite.scala | 62 ++++++++++++++++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 35 +++++++++++ 3 files changed, 89 insertions(+), 9 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 6ab824455929d..d9efd0cb457cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -84,6 +84,7 @@ private[orc] object OrcFilters extends Logging { // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method.
case ByteType | ShortType | FloatType | DoubleType => true case IntegerType | LongType | StringType | BooleanType => true + case TimestampType | _: DecimalType => true case _ => false } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 471192a369f4a..222c24927a763 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -229,6 +229,59 @@ class OrcFilterSuite extends QueryTest with OrcTest { } } + test("filter pushdown - decimal") { + withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - timestamp") { + val timeString = "2015-08-20 14:57:00" + val timestamps = (1 to 4).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + new Timestamp(milliseconds) + } + withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => // Because 
`ExpressionTree` is not accessible at Hive 1.2.x, this should be checked @@ -277,19 +330,10 @@ class OrcFilterSuite extends QueryTest with OrcTest { withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => checkNoFilterPredicate('_1.isNull) } - // DecimalType - withOrcDataFrame((1 to 4).map(i => Tuple1(BigDecimal.valueOf(i)))) { implicit df => - checkNoFilterPredicate('_1 <= BigDecimal.valueOf(4)) - } // BinaryType withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkNoFilterPredicate('_1 <=> 1.b) } - // TimestampType - val stringTimestamp = "2015-08-20 15:57:00" - withOrcDataFrame(Seq(Tuple1(Timestamp.valueOf(stringTimestamp)))) { implicit df => - checkNoFilterPredicate('_1 <=> Timestamp.valueOf(stringTimestamp)) - } // DateType val stringDate = "2015-01-01" withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b13878d578603..b2ee49c441ef2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets +import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll @@ -500,6 +501,40 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("Support for pushing down filters for decimal types") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val data = (0 until 10).map(i => Tuple1(BigDecimal.valueOf(i))) + withTempPath { file => + // It needs to repartition data so that we can have several ORC files + // in order to skip stripes in ORC. + createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + val df = spark.read.orc(file.getCanonicalPath).where("a == 2") + val actual = stripSparkFilter(df).count() + + assert(actual < 10) + } + } + } + + test("Support for pushing down filters for timestamp types") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val timeString = "2015-08-20 14:57:00" + val data = (0 until 10).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + Tuple1(new Timestamp(milliseconds)) + } + withTempPath { file => + // It needs to repartition data so that we can have several ORC files + // in order to skip stripes in ORC. + createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'") + val actual = stripSparkFilter(df).count() + + assert(actual < 10) + } + } + } + test("column nullability and comment - write and then read") { val schema = (new StructType) .add("cl1", IntegerType, nullable = false, comment = "test") From 120723f934dc386a46a043d2833bfcee60d14e74 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 27 Sep 2016 10:20:30 -0700 Subject: [PATCH 135/306] [SPARK-17682][SQL] Mark children as final for unary, binary, leaf expressions and plan nodes ## What changes were proposed in this pull request? This patch marks the children method as final in unary, binary, and leaf expressions and plan nodes (both logical plan and physical plan), as brought up in http://apache-spark-developers-list.1001551.n3.nabble.com/Should-LeafExpression-have-children-final-override-like-Nondeterministic-td19104.html ## How was this patch tested? 
This is a simple modifier change and has no impact on test coverage. Author: Reynold Xin Closes #15256 from rxin/SPARK-17682. --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 6 +++--- .../apache/spark/sql/catalyst/expressions/generators.scala | 4 ---- .../apache/spark/sql/catalyst/plans/logical/Command.scala | 1 - .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 6 +++--- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7abbbe257d830..fa1a2ad56ccb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -295,7 +295,7 @@ trait Nondeterministic extends Expression { */ abstract class LeafExpression extends Expression { - def children: Seq[Expression] = Nil + override final def children: Seq[Expression] = Nil } @@ -307,7 +307,7 @@ abstract class UnaryExpression extends Expression { def child: Expression - override def children: Seq[Expression] = child :: Nil + override final def children: Seq[Expression] = child :: Nil override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -394,7 +394,7 @@ abstract class BinaryExpression extends Expression { def left: Expression def right: Expression - override def children: Seq[Expression] = Seq(left, right) + override final def children: Seq[Expression] = Seq(left, right) override def foldable: Boolean = left.foldable && right.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9d5c856a23e2a..f74208ff66db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -152,8 +152,6 @@ case class Stack(children: Seq[Expression]) abstract class ExplodeBase(child: Expression, position: Boolean) extends UnaryExpression with Generator with CodegenFallback with Serializable { - override def children: Seq[Expression] = child :: Nil - override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { TypeCheckResult.TypeCheckSuccess @@ -257,8 +255,6 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]") case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { - override def children: Seq[Expression] = child :: Nil - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(et, _) if et.isInstanceOf[StructType] => TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 64f57835c8898..38f47081b6f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -25,6 
+25,5 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * eagerly executed. */ trait Command extends LeafNode { - final override def children: Seq[LogicalPlan] = Seq.empty override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d7799151d93b..09725473a384d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -276,7 +276,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * A logical plan node with no children. */ abstract class LeafNode extends LogicalPlan { - override def children: Seq[LogicalPlan] = Nil + override final def children: Seq[LogicalPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -286,7 +286,7 @@ abstract class LeafNode extends LogicalPlan { abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan - override def children: Seq[LogicalPlan] = child :: Nil + override final def children: Seq[LogicalPlan] = child :: Nil /** * Generates an additional set of aliased constraints by replacing the original constraint @@ -330,5 +330,5 @@ abstract class BinaryNode extends LogicalPlan { def left: LogicalPlan def right: LogicalPlan - override def children: Seq[LogicalPlan] = Seq(left, right) + override final def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6aeefa6eddafe..48d6ef6dcd44a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -380,7 +380,7 @@ object SparkPlan { } trait LeafExecNode extends SparkPlan { - override def children: Seq[SparkPlan] = Nil + override final def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -394,7 +394,7 @@ object UnaryExecNode { trait UnaryExecNode extends SparkPlan { def child: SparkPlan - override def children: Seq[SparkPlan] = child :: Nil + override final def children: Seq[SparkPlan] = child :: Nil override def outputPartitioning: Partitioning = child.outputPartitioning } @@ -403,5 +403,5 @@ trait BinaryExecNode extends SparkPlan { def left: SparkPlan def right: SparkPlan - override def children: Seq[SparkPlan] = Seq(left, right) + override final def children: Seq[SparkPlan] = Seq(left, right) } From 2ab24a7bf6687ec238306772c4c7ddef6ac93e99 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Sep 2016 10:52:26 -0700 Subject: [PATCH 136/306] [SPARK-17660][SQL] DESC FORMATTED for VIEW Lacks View Definition ### What changes were proposed in this pull request? Before this PR, `DESC FORMATTED` does not have a section for the view definition. We should add it for permanent views, like what Hive does. 
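For illustration, output like the table below can be reproduced with a short `spark-shell` session along the following lines. This is a minimal sketch assuming a Hive-enabled `SparkSession`; the table name matches the view text shown in the output, and `view1` follows the test case added in this patch.

```scala
// Create a permanent view backed by a small table, then inspect it.
spark.sql("CREATE TABLE tbl(a int)")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
// The "# View Information" section appears at the end of the output.
spark.sql("DESC FORMATTED view1").show(100, false)
```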
``` +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ |a |int |null | | | | | |# Detailed Table Information| | | |Database: |default | | |Owner: |xiaoli | | |Create Time: |Sat Sep 24 21:46:19 PDT 2016 | | |Last Access Time: |Wed Dec 31 16:00:00 PST 1969 | | |Location: | | | |Table Type: |VIEW | | |Table Parameters: | | | | transient_lastDdlTime |1474778779 | | | | | | |# Storage Information | | | |SerDe Library: |org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe | | |InputFormat: |org.apache.hadoop.mapred.SequenceFileInputFormat | | |OutputFormat: |org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat | | |Compressed: |No | | |Storage Desc Parameters: | | | | serialization.format |1 | | | | | | |# View Information | | | |View Original Text: |SELECT * FROM tbl | | |View Expanded Text: |SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM (SELECT `a` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0) AS tbl| | +----------------------------+-------------------------------------------------------------------------------------------------------------------------------------+-------+ ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #15234 from gatorsmile/descFormattedView. --- .../spark/sql/execution/command/tables.scala | 9 +++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 0f61629317c81..6a91c997bac63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -462,6 +462,8 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } describeStorageInfo(table, buffer) + + if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) } private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { @@ -479,6 +481,13 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } + private def describeViewInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# View Information", "", "") + append(buffer, "View Original Text:", metadata.viewOriginalText.getOrElse(""), "") + append(buffer, "View Expanded Text:", metadata.viewText.getOrElse(""), "") + } + private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { metadata.bucketSpec match { case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c927e5d802c90..751e976c7b908 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -506,6 +506,25 @@ class HiveDDLSuite } } + test("desc formatted table for permanent view") { + withTable("tbl") { + 
withView("view1") { + sql("CREATE TABLE tbl(a int)") + sql("CREATE VIEW view1 AS SELECT * FROM tbl") + assert(sql("DESC FORMATTED view1").collect().containsSlice( + Seq( + Row("# View Information", "", ""), + Row("View Original Text:", "SELECT * FROM tbl", ""), + Row("View Expanded Text:", + "SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM " + + "(SELECT `a` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0) AS tbl", + "") + ) + )) + } + } + } + test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" From 67c73052b877a8709ae6fa22b844a45f114b1f7e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 27 Sep 2016 12:37:19 -0700 Subject: [PATCH 137/306] [SPARK-17677][SQL] Break WindowExec.scala into multiple files ## What changes were proposed in this pull request? As of Spark 2.0, all the window function execution code are in WindowExec.scala. This file is pretty large (over 1k loc) and has a lot of different abstractions in them. This patch creates a new package sql.execution.window, moves WindowExec.scala in it, and breaks WindowExec.scala into multiple, more maintainable pieces: - AggregateProcessor.scala - BoundOrdering.scala - RowBuffer.scala - WindowExec - WindowFunctionFrame.scala ## How was this patch tested? This patch mostly moves code around, and should not change any existing test coverage. Author: Reynold Xin Closes #15252 from rxin/SPARK-17677. --- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../spark/sql/execution/WindowExec.scala | 1013 ----------------- .../execution/window/AggregateProcessor.scala | 159 +++ .../sql/execution/window/BoundOrdering.scala | 58 + .../sql/execution/window/RowBuffer.scala | 115 ++ .../sql/execution/window/WindowExec.scala | 412 +++++++ .../window/WindowFunctionFrame.scala | 367 ++++++ 7 files changed, 1112 insertions(+), 1015 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3441ccf53b45b..7cfae5ce283bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -387,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil case 
logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala deleted file mode 100644 index 9d006d21d9440..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ /dev/null @@ -1,1013 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - -/** - * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) - * partition. The aggregates are calculated for each row in the group. Special processing - * instructions, frames, are used to calculate these aggregates. Frames are processed in the order - * specified in the window specification (the ORDER BY ... clause). There are four different frame - * types: - * - Entire partition: The frame is the entire partition, i.e. - * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all - * rows as inputs and be evaluated once. - * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... - * Every time we move to a new row to process, we add some rows to the frame. We do not remove - * rows from this frame. - * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. - * Every time we move to a new row to process, we remove some rows from the frame. We do not add - * rows to this frame. - * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame - * and we add some rows to the frame. Examples are: - * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. - * - Offset frame: The frame consist of one row, which is an offset number of rows away from the - * current row. 
Only [[OffsetWindowFunction]]s can be processed in an offset frame. - * - * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame - * boundary can be either Row or Range based: - * - Row Based: A row based boundary is based on the position of the row within the partition. - * An offset indicates the number of rows above or below the current row, the frame for the - * current row starts or ends. For instance, given a row based sliding frame with a lower bound - * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from - * index 4 to index 6. - * - Range based: A range based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, for - * instance if the current order by expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. An exception can be made when the offset is 0, - * because no value modification is needed, in this case multiple and non-numeric ORDER BY - * expression are allowed. - * - * This is quite an expensive operator because every row for a single group must be in the same - * partition and partitions must be sorted according to the grouping and sort order. The operator - * requires the planner to take care of the partitioning and sorting. - * - * The operator is semi-blocking. The window functions and aggregates are calculated one group at - * a time, the result will only be made available after the processing for the entire group has - * finished. The operator is able to process different frame configurations at the same time. This - * is done by delegating the actual frame processing (i.e. calculation of the window functions) to - * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: - * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair - * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. - */ -case class WindowExec( - windowExpression: Seq[NamedExpression], - partitionSpec: Seq[Expression], - orderSpec: Seq[SortOrder], - child: SparkPlan) - extends UnaryExecNode { - - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. - * @return a bound ordering object. 
- */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. - val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frames' - * WindowExpressions and factory function for the WindowFrameFunction. - */ - private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) - case f => sys.error(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. 
- def processor = AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) - - // Create the factory - val factory = key match { - // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => - target: MutableRow => - new OffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunctions. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled), - offset) - - // Growing Frame. - case ("AGGREGATE", frameType, None, Some(high)) => - target: MutableRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, high)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => - target: MutableRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, low)) - } - - // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => - target: MutableRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. - case ("AGGREGATE", frameType, None, None) => - target: MutableRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Frame Expression - Factory pair. - (expressions, factory) - } - } - - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - private[this] def createResultProjection( - expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map{ case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - - protected override def doExecute(): RDD[InternalRow] = { - // Unwrap the expressions and factories from the map. - val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) - val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray - - // Start processing. - child.execute().mapPartitions { stream => - new Iterator[InternalRow] { - - // Get all relevant projections. - val result = createResultProjection(expressions) - val grouping = UnsafeProjection.create(partitionSpec, child.output) - - // Manage the stream and the grouping. - var nextRow: UnsafeRow = null - var nextGroup: UnsafeRow = null - var nextRowAvailable: Boolean = false - private[this] def fetchNextRow() { - nextRowAvailable = stream.hasNext - if (nextRowAvailable) { - nextRow = stream.next().asInstanceOf[UnsafeRow] - nextGroup = grouping(nextRow) - } else { - nextRow = null - nextGroup = null - } - } - fetchNextRow() - - // Manage the current partition. 
- val rows = ArrayBuffer.empty[UnsafeRow] - val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null - val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) - val frames = factories.map(_(windowFunctionResult)) - val numFrames = frames.length - private[this] def fetchNextPartition() { - // Collect all the rows in the current partition. - // Before we start to fetch new input rows, make a copy of nextGroup. - val currentGroup = nextGroup.copy() - - // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } - - while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } - fetchNextRow() - } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } - - // Setup the frames. - var i = 0 - while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) - i += 1 - } - - // Setup iteration - rowIndex = 0 - rowsSize = rowBuffer.size() - } - - // Iteration - var rowIndex = 0 - var rowsSize = 0L - - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - - val join = new JoinedRow - override final def next(): InternalRow = { - // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { - fetchNextPartition() - } - - if (rowIndex < rowsSize) { - // Get the results for the window frames. - var i = 0 - val current = rowBuffer.next() - while (i < numFrames) { - frames(i).write(rowIndex, current) - i += 1 - } - - // 'Merge' the input row with the window function result - join(current, windowFunctionResult) - rowIndex += 1 - - // Return the projection. - result(join) - } else throw new NoSuchElementException - } - } - } - } -} - -/** - * Function for comparing boundary values. - */ -private[execution] abstract class BoundOrdering { - def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int -} - -/** - * Compare the input index to the bound of the output index. - */ -private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { - override def compare( - inputRow: InternalRow, - inputIndex: Int, - outputRow: InternalRow, - outputIndex: Int): Int = - inputIndex - (outputIndex + offset) -} - -/** - * Compare the value of the input index to the value bound of the output index. 
- */ -private[execution] final case class RangeBoundOrdering( - ordering: Ordering[InternalRow], - current: Projection, - bound: Projection) extends BoundOrdering { - override def compare( - inputRow: InternalRow, - inputIndex: Int, - outputRow: InternalRow, - outputIndex: Int): Int = - ordering.compare(current(inputRow), bound(outputRow)) -} - -/** - * The interface of row buffer for a partition - */ -private[execution] abstract class RowBuffer { - - /** Number of rows. */ - def size(): Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited) - */ -private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. */ - def size(): Int = buffer.length - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter - */ -private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - def size(): Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} - -/** - * A window function calculates the results of a number of window functions for a window frame. - * Before use a frame must be prepared by passing it all the rows in the current partition. After - * preparation the update method can be called to fill the output rows. - */ -private[execution] abstract class WindowFunctionFrame { - /** - * Prepare the frame for calculating the results for a partition. - * - * @param rows to calculate the frame results for. - */ - def prepare(rows: RowBuffer): Unit - - /** - * Write the current results to the target row. - */ - def write(index: Int, current: InternalRow): Unit -} - -/** - * The offset window frame calculates frames containing LEAD/LAG statements. - * - * @param target to write results to. - * @param ordinal the ordinal is the starting offset at which the results of the window frame get - * written into the (shared) target row. The result of the frame expression with - * index 'i' will be written to the 'ordinal' + 'i' position in the target row. - * @param expressions to shift a number of rows. - * @param inputSchema required for creating a projection. - * @param newMutableProjection function used to create the projection. 
- * @param offset by which rows get moved within a partition. - */ -private[execution] final class OffsetWindowFunctionFrame( - target: MutableRow, - ordinal: Int, - expressions: Array[OffsetWindowFunction], - inputSchema: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, - offset: Int) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** Index of the input row currently used for output. */ - private[this] var inputIndex = 0 - - /** - * Create the projection used when the offset row exists. - * Please note that this project always respect null input values (like PostgreSQL). - */ - private[this] val projection = { - // Collect the expressions and bind them. - val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - BindReferences.bindReference(e.input, inputAttrs) - } - - // Create the projection. - newMutableProjection(boundExpressions, Nil).target(target) - } - - /** Create the projection used when the offset row DOES NOT exists. */ - private[this] val fillDefaultValue = { - // Collect the expressions and bind them. - val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - if (e.default == null || e.default.foldable && e.default.eval() == null) { - // The default value is null. - Literal.create(null, e.dataType) - } else { - // The default value is an expression. - BindReferences.bindReference(e.default, inputAttrs) - } - } - - // Create the projection. - newMutableProjection(boundExpressions, Nil).target(target) - } - - override def prepare(rows: RowBuffer): Unit = { - input = rows - // drain the first few rows if offset is larger than zero - inputIndex = 0 - while (inputIndex < offset) { - input.next() - inputIndex += 1 - } - inputIndex = offset - } - - override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() - projection(r) - } else { - // Use default values since the offset row does not exist. - fillDefaultValue(current) - } - inputIndex += 1 - } -} - -/** - * The sliding window frame calculates frames with the following SQL form: - * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param lbound comparator used to identify the lower bound of an output row. - * @param ubound comparator used to identify the upper bound of an output row. - */ -private[execution] final class SlidingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** The next row from `input`. */ - private[this] var nextRow: InternalRow = null - - /** The rows within current sliding window. */ - private[this] val buffer = new util.ArrayDeque[InternalRow]() - - /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. - */ - private[this] var inputHighIndex = 0 - - /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. - */ - private[this] var inputLowIndex = 0 - - /** Prepare the frame for calculating a new partition. 
Reset all variables. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - nextRow = rows.next() - inputHighIndex = 0 - inputLowIndex = 0 - buffer.clear() - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Add all rows to the buffer for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - buffer.add(nextRow.copy()) - nextRow = input.next() - inputHighIndex += 1 - bufferUpdated = true - } - - // Drop all rows from the buffer for which the input row value is smaller than - // the output row lower bound. - while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { - buffer.remove() - inputLowIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.initialize(input.size) - val iter = buffer.iterator() - while (iter.hasNext) { - processor.update(iter.next()) - } - processor.evaluate(target) - } - } -} - -/** - * The unbounded window frame calculates frames with the following SQL forms: - * ... (No Frame Definition) - * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - * - * Its results are the same for each and every row in the partition. This class can be seen as a - * special case of a sliding window, but is optimized for the unbound case. - * - * @param target to write results to. - * @param processor to calculate the row values with. - */ -private[execution] final class UnboundedWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor) extends WindowFunctionFrame { - - /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size() - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 - } - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate - // for each row. - processor.evaluate(target) - } -} - -/** - * The UnboundPreceding window frame calculates frames with the following SQL form: - * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * - * There is only an upper bound. Very common use cases are for instance running sums or counts - * (row_number). Technically this is a special case of a sliding window. However a sliding window - * has to maintain a buffer, and it must do a full evaluation everytime the buffer changes. This - * is not the case when there is no lower bound, given the additive nature of most aggregates - * streaming updates and partial evaluation suffice and no buffering is needed. - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param ubound comparator used to identify the upper bound of an output row. - */ -private[execution] final class UnboundedPrecedingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - ubound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** The next row from `input`. 
*/ - private[this] var nextRow: InternalRow = null - - /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. - */ - private[this] var inputIndex = 0 - - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - nextRow = rows.next() - inputIndex = 0 - processor.initialize(input.size) - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Add all rows to the aggregates for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { - processor.update(nextRow) - nextRow = input.next() - inputIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.evaluate(target) - } - } -} - -/** - * The UnboundFollowing window frame calculates frames with the following SQL form: - * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING - * - * There is only an upper bound. This is a slightly modified version of the sliding window. The - * sliding window operator has to check if both upper and the lower bound change when a new row - * gets processed, where as the unbounded following only has to check the lower bound. - * - * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a - * buffer and must do full recalculation after each row. Reverse iteration would be possible, if - * the commutativity of the used window functions can be guaranteed. - * - * @param target to write results to. - * @param processor to calculate the row values with. - * @param lbound comparator used to identify the lower bound of an output row. - */ -private[execution] final class UnboundedFollowingWindowFunctionFrame( - target: MutableRow, - processor: AggregateProcessor, - lbound: BoundOrdering) extends WindowFunctionFrame { - - /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null - - /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. - */ - private[this] var inputIndex = 0 - - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { - input = rows - inputIndex = 0 - } - - /** Write the frame columns for the current row to the given target row. */ - override def write(index: Int, current: InternalRow): Unit = { - var bufferUpdated = index == 0 - - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than - // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() - while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() - inputIndex += 1 - bufferUpdated = true - } - - // Only recalculate and update when the buffer changes. - if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { - processor.update(nextRow) - nextRow = tmp.next() - } - processor.evaluate(target) - } - } -} - -/** - * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a - * single frame. 
The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, - * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying - * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. - * - * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions - * require the size of the partition processed, this value is exposed to them when the processor is - * constructed. - * - * Processing of distinct aggregates is currently not supported. - * - * The implementation is split into an object which takes care of construction, and a the actual - * processor class. - */ -private[execution] object AggregateProcessor { - def apply( - functions: Array[Expression], - ordinal: Int, - inputAttributes: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection): - AggregateProcessor = { - val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] - val initialValues = mutable.Buffer.empty[Expression] - val updateExpressions = mutable.Buffer.empty[Expression] - val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) - val imperatives = mutable.Buffer.empty[ImperativeAggregate] - - // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then - // serialized to executor side. These functions all reference a global singleton window - // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect - // the singleton instance created on driver side instead of using executor side - // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. - val partitionSize: Option[AttributeReference] = { - val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) - aggs.headOption.map(_.n) - } - - // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to - // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. - partitionSize.foreach { n => - aggBufferAttributes += n - initialValues += NoOp - updateExpressions += NoOp - } - - // Add an AggregateFunction to the AggregateProcessor. - functions.foreach { - case agg: DeclarativeAggregate => - aggBufferAttributes ++= agg.aggBufferAttributes - initialValues ++= agg.initialValues - updateExpressions ++= agg.updateExpressions - evaluateExpressions += agg.evaluateExpression - case agg: ImperativeAggregate => - val offset = aggBufferAttributes.size - val imperative = BindReferences.bindReference(agg - .withNewInputAggBufferOffset(offset) - .withNewMutableAggBufferOffset(offset), - inputAttributes) - imperatives += imperative - aggBufferAttributes ++= imperative.aggBufferAttributes - val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) - initialValues ++= noOps - updateExpressions ++= noOps - evaluateExpressions += imperative - case other => - sys.error(s"Unsupported Aggregate Function: $other") - } - - // Create the projections. 
- val initialProjection = newMutableProjection( - initialValues, - partitionSize.toSeq) - val updateProjection = newMutableProjection( - updateExpressions, - aggBufferAttributes ++ inputAttributes) - val evaluateProjection = newMutableProjection( - evaluateExpressions, - aggBufferAttributes) - - // Create the processor - new AggregateProcessor( - aggBufferAttributes.toArray, - initialProjection, - updateProjection, - evaluateProjection, - imperatives.toArray, - partitionSize.isDefined) - } -} - -/** - * This class manages the processing of a number of aggregate functions. See the documentation of - * the object for more information. - */ -private[execution] final class AggregateProcessor( - private[this] val bufferSchema: Array[AttributeReference], - private[this] val initialProjection: MutableProjection, - private[this] val updateProjection: MutableProjection, - private[this] val evaluateProjection: MutableProjection, - private[this] val imperatives: Array[ImperativeAggregate], - private[this] val trackPartitionSize: Boolean) { - - private[this] val join = new JoinedRow - private[this] val numImperatives = imperatives.length - private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) - initialProjection.target(buffer) - updateProjection.target(buffer) - - /** Create the initial state. */ - def initialize(size: Int): Unit = { - // Some initialization expressions are dependent on the partition size so we have to - // initialize the size before initializing all other fields, and we have to pass the buffer to - // the initialization projection. - if (trackPartitionSize) { - buffer.setInt(0, size) - } - initialProjection(buffer) - var i = 0 - while (i < numImperatives) { - imperatives(i).initialize(buffer) - i += 1 - } - } - - /** Update the buffer. */ - def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) - var i = 0 - while (i < numImperatives) { - imperatives(i).update(buffer, input) - i += 1 - } - } - - /** Evaluate buffer. */ - def evaluate(target: MutableRow): Unit = - evaluateProjection.target(target)(buffer) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala new file mode 100644 index 0000000000000..d3a46d020dbbf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ + + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, + * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition processed, this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and a the actual + * processor class. + */ +private[window] object AggregateProcessor { + def apply( + functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) + : AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then + // serialized to executor side. These functions all reference a global singleton window + // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect + // the singleton instance created on driver side instead of using executor side + // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. + val partitionSize: Option[AttributeReference] = { + val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) + aggs.headOption.map(_.n) + } + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + partitionSize.foreach { n => + aggBufferAttributes += n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. + functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections. 
+ val initialProj = newMutableProjection(initialValues, partitionSize.toSeq) + val updateProj = newMutableProjection(updateExpressions, aggBufferAttributes ++ inputAttributes) + val evalProj = newMutableProjection(evaluateExpressions, aggBufferAttributes) + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProj, + updateProj, + evalProj, + imperatives.toArray, + partitionSize.isDefined) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[window] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } + } + + /** Evaluate buffer. */ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala new file mode 100644 index 0000000000000..d6a801954c1ac --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/BoundOrdering.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Projection + + +/** + * Function for comparing boundary values. 
+ */ +private[window] abstract class BoundOrdering { + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int +} + +/** + * Compare the input index to the bound of the output index. + */ +private[window] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + inputIndex - (outputIndex + offset) +} + +/** + * Compare the value of the input index to the value bound of the output index. + */ +private[window] final case class RangeBoundOrdering( + ordering: Ordering[InternalRow], + current: Projection, + bound: Projection) + extends BoundOrdering { + + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + ordering.compare(current(inputRow), bound(outputRow)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala new file mode 100644 index 0000000000000..ee36c84251519 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + + +/** + * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the + * row buffer is used to materialize a partition of rows since we need to repeatedly scan these + * rows in window function processing. + */ +private[window] abstract class RowBuffer { + + /** Number of rows. */ + def size: Int + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow + + /** Skip the next `n` rows. */ + def skip(n: Int): Unit + + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer +} + +/** + * A row buffer based on ArrayBuffer (the number of rows is limited). + */ +private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { + + private[this] var cursor: Int = -1 + + /** Number of rows. */ + override def size: Int = buffer.length + + /** Return next row in the buffer, null if no more left. */ + override def next(): InternalRow = { + cursor += 1 + if (cursor < buffer.length) { + buffer(cursor) + } else { + null + } + } + + /** Skip the next `n` rows. 
*/ + override def skip(n: Int): Unit = { + cursor += n + } + + /** Return a new RowBuffer that has the same rows. */ + override def copy(): RowBuffer = { + new ArrayRowBuffer(buffer) + } +} + +/** + * An external buffer of rows based on UnsafeExternalSorter. + */ +private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) + extends RowBuffer { + + private[this] val iter: UnsafeSorterIterator = sorter.getIterator + + private[this] val currentRow = new UnsafeRow(numFields) + + /** Number of rows. */ + override def size: Int = iter.getNumRecords() + + /** Return next row in the buffer, null if no more left. */ + override def next(): InternalRow = { + if (iter.hasNext) { + iter.loadNext() + currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + currentRow + } else { + null + } + } + + /** Skip the next `n` rows. */ + override def skip(n: Int): Unit = { + var i = 0 + while (i < n && iter.hasNext) { + iter.loadNext() + i += 1 + } + } + + /** Return a new RowBuffer that has the same rows. */ + override def copy(): RowBuffer = { + new ExternalRowBuffer(sorter, numFields) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala new file mode 100644 index 0000000000000..7a6a30f120386 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +/** + * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) + * partition. The aggregates are calculated for each row in the group. Special processing + * instructions, frames, are used to calculate these aggregates. Frames are processed in the order + * specified in the window specification (the ORDER BY ... clause). There are four different frame + * types: + * - Entire partition: The frame is the entire partition, i.e. + * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all + * rows as inputs and be evaluated once. 
+ * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... + * Every time we move to a new row to process, we add some rows to the frame. We do not remove + * rows from this frame. + * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. + * Every time we move to a new row to process, we remove some rows from the frame. We do not add + * rows to this frame. + * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame + * and we add some rows to the frame. Examples are: + * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. + * + * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame + * boundary can be either Row or Range based: + * - Row Based: A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * - Range based: A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. + * + * This is quite an expensive operator because every row for a single group must be in the same + * partition and partitions must be sorted according to the grouping and sort order. The operator + * requires the planner to take care of the partitioning and sorting. + * + * The operator is semi-blocking. The window functions and aggregates are calculated one group at + * a time, the result will only be made available after the processing for the entire group has + * finished. The operator is able to process different frame configurations at the same time. This + * is done by delegating the actual frame processing (i.e. calculation of the window functions) to + * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: + * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair + * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. + */ +case class WindowExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) + extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! 
Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frameType to evaluate. This can either be Row or Range based. + * @param offset with respect to the row. + * @return a bound ordering object. + */ + private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { + frameType match { + case RangeFrame => + val (exprs, current, bound) = if (offset == 0) { + // Use the entire order expression when the offset is 0. + val exprs = orderSpec.map(_.child) + val buildProjection = () => newMutableProjection(exprs, child.output) + (orderSpec, buildProjection(), buildProjection()) + } else if (orderSpec.size == 1) { + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output) + (sortExpr :: Nil, current, bound) + } else { + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val sortExprs = exprs.zipWithIndex.map { case (e, i) => + SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) + } + val ordering = newOrdering(sortExprs, Nil) + RangeBoundOrdering(ordering, current, bound) + case RowFrame => RowBoundOrdering(offset) + } + } + + /** + * Collection containing an entry for each window frame to process. Each entry contains a frames' + * WindowExpressions and factory function for the WindowFrameFunction. + */ + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es += e + fns += fn + } + + // Collect all valid window functions and group them by their frame. 
+ windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + // OFFSET frame functions are guaranteed be OffsetWindowFunctions. + functions.map(_.asInstanceOf[OffsetWindowFunction]), + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map{ case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + + // Start processing. 
+ child.execute().mapPartitions { stream => + new Iterator[InternalRow] { + + // Get all relevant projections. + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, child.output) + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val rows = ArrayBuffer.empty[UnsafeRow] + val inputFields = child.output.length + var sorter: UnsafeExternalSorter = null + var rowBuffer: RowBuffer = null + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) + val numFrames = frames.length + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + if (sorter != null) { + // the last sorter of this task will be cleaned up via task completion listener + sorter.cleanupResources() + sorter = null + } else { + rows.clear() + } + + while (nextRowAvailable && nextGroup == currentGroup) { + if (sorter == null) { + rows += nextRow.copy() + + if (rows.length >= 4096) { + // We will not sort the rows, so prefixComparator and recordComparator are null. + sorter = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + false) + rows.foreach { r => + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) + } + rows.clear() + } + } else { + sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, + nextRow.getSizeInBytes, 0, false) + } + fetchNextRow() + } + if (sorter != null) { + rowBuffer = new ExternalRowBuffer(sorter, inputFields) + } else { + rowBuffer = new ArrayRowBuffer(rows) + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(rowBuffer.copy()) + i += 1 + } + + // Setup iteration + rowIndex = 0 + rowsSize = rowBuffer.size + } + + // Iteration + var rowIndex = 0 + var rowsSize = 0L + + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + + val join = new JoinedRow + override final def next(): InternalRow = { + // Load the next partition if we need to. + if (rowIndex >= rowsSize && nextRowAvailable) { + fetchNextPartition() + } + + if (rowIndex < rowsSize) { + // Get the results for the window frames. + var i = 0 + val current = rowBuffer.next() + while (i < numFrames) { + frames(i).write(rowIndex, current) + i += 1 + } + + // 'Merge' the input row with the window function result + join(current, windowFunctionResult) + rowIndex += 1 + + // Return the projection. 
+ result(join) + } else throw new NoSuchElementException + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala new file mode 100644 index 0000000000000..2ab9faab7a59b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import java.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp + + +/** + * A window function calculates the results of a number of window functions for a window frame. + * Before use a frame must be prepared by passing it all the rows in the current partition. After + * preparation the update method can be called to fill the output rows. + */ +private[window] abstract class WindowFunctionFrame { + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. + */ + def prepare(rows: RowBuffer): Unit + + /** + * Write the current results to the target row. + */ + def write(index: Int, current: InternalRow): Unit +} + +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param ordinal the ordinal is the starting offset at which the results of the window frame get + * written into the (shared) target row. The result of the frame expression with + * index 'i' will be written to the 'ordinal' + 'i' position in the target row. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[window] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[OffsetWindowFunction], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + offset: Int) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 + + /** + * Create the projection used when the offset row exists. + * Please note that this project always respect null input values (like PostgreSQL). + */ + private[this] val projection = { + // Collect the expressions and bind them. 
+ val inputAttrs = inputSchema.map(_.withNullability(true)) + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => + BindReferences.bindReference(e.input, inputAttrs) + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + /** Create the projection used when the offset row DOES NOT exists. */ + private[this] val fillDefaultValue = { + // Collect the expressions and bind them. + val inputAttrs = inputSchema.map(_.withNullability(true)) + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // The default value is null. + Literal.create(null, e.dataType) + } else { + // The default value is an expression. + BindReferences.bindReference(e.default, inputAttrs) + } + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + override def prepare(rows: RowBuffer): Unit = { + input = rows + // drain the first few rows if offset is larger than zero + inputIndex = 0 + while (inputIndex < offset) { + input.next() + inputIndex += 1 + } + inputIndex = offset + } + + override def write(index: Int, current: InternalRow): Unit = { + if (inputIndex >= 0 && inputIndex < input.size) { + val r = input.next() + projection(r) + } else { + // Use default values since the offset row does not exist. + fillDefaultValue(current) + } + inputIndex += 1 + } +} + +/** + * The sliding window frame calculates frames with the following SQL form: + * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING + * + * @param target to write results to. + * @param processor to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. + * @param ubound comparator used to identify the upper bound of an output row. + */ +private[window] final class SlidingWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering, + ubound: BoundOrdering) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** The rows within current sliding window. */ + private[this] val buffer = new util.ArrayDeque[InternalRow]() + + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputHighIndex = 0 + + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ + private[this] var inputLowIndex = 0 + + /** Prepare the frame for calculating a new partition. Reset all variables. */ + override def prepare(rows: RowBuffer): Unit = { + input = rows + nextRow = rows.next() + inputHighIndex = 0 + inputLowIndex = 0 + buffer.clear() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + buffer.add(nextRow.copy()) + nextRow = input.next() + inputHighIndex += 1 + bufferUpdated = true + } + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. 
+ while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + buffer.remove() + inputLowIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + processor.initialize(input.size) + val iter = buffer.iterator() + while (iter.hasNext) { + processor.update(iter.next()) + } + processor.evaluate(target) + } + } +} + +/** + * The unbounded window frame calculates frames with the following SQL forms: + * ... (No Frame Definition) + * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + * + * Its results are the same for each and every row in the partition. This class can be seen as a + * special case of a sliding window, but is optimized for the unbound case. + * + * @param target to write results to. + * @param processor to calculate the row values with. + */ +private[window] final class UnboundedWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor) + extends WindowFunctionFrame { + + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ + override def prepare(rows: RowBuffer): Unit = { + val size = rows.size + processor.initialize(size) + var i = 0 + while (i < size) { + processor.update(rows.next()) + i += 1 + } + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) + } +} + +/** + * The UnboundPreceding window frame calculates frames with the following SQL form: + * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * + * There is only an upper bound. Very common use cases are for instance running sums or counts + * (row_number). Technically this is a special case of a sliding window. However a sliding window + * has to maintain a buffer, and it must do a full evaluation everytime the buffer changes. This + * is not the case when there is no lower bound, given the additive nature of most aggregates + * streaming updates and partial evaluation suffice and no buffering is needed. + * + * @param target to write results to. + * @param processor to calculate the row values with. + * @param ubound comparator used to identify the upper bound of an output row. + */ +private[window] final class UnboundedPrecedingWindowFunctionFrame( + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) + extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: RowBuffer): Unit = { + input = rows + nextRow = rows.next() + inputIndex = 0 + processor.initialize(input.size) + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Add all rows to the aggregates for which the input row value is equal to or less than + // the output row upper bound. 
+ while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
+ processor.update(nextRow)
+ nextRow = input.next()
+ inputIndex += 1
+ bufferUpdated = true
+ }
+
+ // Only recalculate and update when the buffer changes.
+ if (bufferUpdated) {
+ processor.evaluate(target)
+ }
+ }
+}
+
+/**
+ * The UnboundFollowing window frame calculates frames with the following SQL form:
+ * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
+ *
+ * There is only a lower bound. This is a slightly modified version of the sliding window. The
+ * sliding window operator has to check if both the upper and the lower bound change when a new row
+ * gets processed, whereas the unbounded following only has to check the lower bound.
+ *
+ * This is a very expensive operator to use, O(n * (n - 1) / 2), because we need to maintain a
+ * buffer and must do full recalculation after each row. Reverse iteration would be possible if
+ * the commutativity of the used window functions can be guaranteed.
+ *
+ * @param target to write results to.
+ * @param processor to calculate the row values with.
+ * @param lbound comparator used to identify the lower bound of an output row.
+ */
+private[window] final class UnboundedFollowingWindowFunctionFrame(
+ target: MutableRow,
+ processor: AggregateProcessor,
+ lbound: BoundOrdering)
+ extends WindowFunctionFrame {
+
+ /** Rows of the partition currently being processed. */
+ private[this] var input: RowBuffer = null
+
+ /**
+ * Index of the first input row with a value equal to or greater than the lower bound of the
+ * current output row.
+ */
+ private[this] var inputIndex = 0
+
+ /** Prepare the frame for calculating a new partition. */
+ override def prepare(rows: RowBuffer): Unit = {
+ input = rows
+ inputIndex = 0
+ }
+
+ /** Write the frame columns for the current row to the given target row. */
+ override def write(index: Int, current: InternalRow): Unit = {
+ var bufferUpdated = index == 0
+
+ // Duplicate the input to have a new iterator
+ val tmp = input.copy()
+
+ // Drop all rows from the buffer for which the input row value is smaller than
+ // the output row lower bound.
+ tmp.skip(inputIndex)
+ var nextRow = tmp.next()
+ while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
+ nextRow = tmp.next()
+ inputIndex += 1
+ bufferUpdated = true
+ }
+
+ // Only recalculate and update when the buffer changes.
+ if (bufferUpdated) {
+ processor.initialize(input.size)
+ while (nextRow != null) {
+ processor.update(nextRow)
+ nextRow = tmp.next()
+ }
+ processor.evaluate(target)
+ }
+ }
+}
From 2f84a686604b298537bfd4d087b41594d2aa7ec6 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 27 Sep 2016 14:14:27 -0700
Subject: [PATCH 138/306] [SPARK-17618] Guard against invalid comparisons between UnsafeRow and other formats

This patch ports changes from #15185 to Spark 2.x. That patch fixed a correctness bug in Spark 1.6.x which was caused by an invalid `equals()` comparison between an `UnsafeRow` and another row of a different format. Spark 2.x is not affected by that specific correctness bug, but it can still reap the error-prevention benefits of that patch's changes, which modify `UnsafeRow.equals()` to throw an IllegalArgumentException if it is called with an object that is not an `UnsafeRow`.

Author: Josh Rosen

Closes #15265 from JoshRosen/SPARK-17618-master.
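As a rough illustration of the new contract (a hedged sketch against Spark's internal catalyst classes, not code from this patch; building the `UnsafeRow` through a projection is just one convenient way to obtain one):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// The same logical row in two different physical formats.
val generic: InternalRow = new GenericInternalRow(Array[Any](1, 2L))
val unsafe: UnsafeRow = UnsafeProjection.create(Array[DataType](IntegerType, LongType))(generic)

unsafe.equals(unsafe.copy()) // true: byte-wise comparison of two UnsafeRows
unsafe.equals("not a row")   // false: not an InternalRow at all
unsafe.equals(generic)       // now throws IllegalArgumentException instead of silently returning false
```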
--- .../apache/spark/sql/catalyst/expressions/UnsafeRow.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index dd2f39eb816f2..9027652d57f14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -577,8 +578,12 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); + } else if (!(other instanceof InternalRow)) { + return false; + } else { + throw new IllegalArgumentException( + "Cannot compare UnsafeRow to " + other.getClass().getName()); } - return false; } /** From e7bce9e1876de6ee975ccc89351db58119674aef Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Sep 2016 16:00:39 -0700 Subject: [PATCH 139/306] [SPARK-17056][CORE] Fix a wrong assert regarding unroll memory in MemoryStore ## What changes were proposed in this pull request? There is an assert in MemoryStore's putIteratorAsValues method which is used to check if unroll memory is not released too much. This assert looks wrong. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #14642 from viirya/fix-unroll-memory. --- .../scala/org/apache/spark/storage/memory/MemoryStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 205d469f48144..095d32407f345 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -273,7 +273,7 @@ private[spark] class MemoryStore( blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { - assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, + assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, "released too much unroll memory") Left(new PartiallyUnrolledIterator( this, From b03b4adf6d8f4c6d92575c0947540cb474bf7de1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 27 Sep 2016 17:52:57 -0700 Subject: [PATCH 140/306] [SPARK-17666] Ensure that RecordReaders are closed by data source file scans ## What changes were proposed in this pull request? This patch addresses a potential cause of resource leaks in data source file scans. As reported in [SPARK-17666](https://issues.apache.org/jira/browse/SPARK-17666), tasks which do not fully-consume their input may cause file handles / network connections (e.g. S3 connections) to be leaked. Spark's `NewHadoopRDD` uses a TaskContext callback to [close its record readers](https://github.com/apache/spark/blame/master/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala#L208), but the new data source file scans will only close record readers once their iterators are fully-consumed. 
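The cleanup hook that guarantees release in that situation is a task-completion callback, roughly as follows (a hedged sketch; `openReader` is a hypothetical factory standing in for the actual record-reader construction):

```scala
import java.io.Closeable
import org.apache.spark.TaskContext

// Sketch: tie the reader's lifetime to the task instead of to full consumption of the iterator.
def linesWithCleanup(openReader: () => Iterator[String] with Closeable): Iterator[String] = {
  val reader = openReader()
  // Close the reader when the task completes, even if the iterator is abandoned early
  // (for example by a LIMIT that stops consuming records).
  Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close()))
  reader
}
```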
This patch modifies `RecordReaderIterator` and `HadoopFileLinesReader` to add `close()` methods and modifies all six implementations of `FileFormat.buildReader()` to register TaskContext task completion callbacks to guarantee that cleanup is eventually performed. ## How was this patch tested? Tested manually for now. Author: Josh Rosen Closes #15245 from JoshRosen/SPARK-17666-close-recordreader. --- .../ml/source/libsvm/LibSVMRelation.scala | 7 +++++-- .../datasources/HadoopFileLinesReader.scala | 6 +++++- .../datasources/RecordReaderIterator.scala | 21 +++++++++++++++++-- .../datasources/csv/CSVFileFormat.scala | 5 ++++- .../datasources/json/JsonFileFormat.scala | 5 ++++- .../parquet/ParquetFileFormat.scala | 3 ++- .../datasources/text/TextFileFormat.scala | 2 ++ .../spark/sql/hive/orc/OrcFileFormat.scala | 6 +++++- 8 files changed, 46 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5c79c6905801c..8577803743c8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLUtils @@ -159,8 +160,10 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val points = - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + + val points = linesReader .map(_.toString.trim) .filterNot(line => line.isEmpty || line.startsWith("#")) .map { line => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 18f9b55895a64..83cf26c63a175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable import java.net.URI import org.apache.hadoop.conf.Configuration @@ -30,7 +31,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. 
*/ -class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] { +class HadoopFileLinesReader( + file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -48,4 +50,6 @@ class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends override def hasNext: Boolean = iterator.hasNext override def next(): Text = iterator.next() + + override def close(): Unit = iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala index f03ae94d55838..938af25a96844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable + import org.apache.hadoop.mapreduce.RecordReader import org.apache.spark.sql.catalyst.InternalRow @@ -27,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass * column batches by pretending they are rows. */ -class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] { +class RecordReaderIterator[T]( + private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { private[this] var havePair = false private[this] var finished = false @@ -38,7 +41,7 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the // resources early. - rowReader.close() + close() } havePair = !finished } @@ -52,4 +55,18 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] havePair = false rowReader.getCurrentValue } + + override def close(): Unit = { + if (rowReader != null) { + try { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues + // when reading compressed input. 
+ rowReader.close() + } finally { + rowReader = null + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9a118fe5a273d..9610746a81ef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -112,7 +113,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val lineIterator = { val conf = broadcastedHadoopConf.value.value - new HadoopFileLinesReader(file, conf).map { line => + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => new String(line.getBytes, 0, line.getLength, csvOptions.charset) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 7421314df7aa5..6882a6cdcac26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -104,7 +105,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val lines = linesReader.map(_.toString) val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions) lines.flatMap(parser.parse) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e7c3545630fea..4a308ff1a32f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -37,7 +37,7 @@ import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql._ 
import org.apache.spark.sql.catalyst.InternalRow @@ -388,6 +388,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(parquetReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index a0c3fd53fb53b..a875b01ec2d7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -100,6 +101,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 03b508e11aa76..15b72d8d2179f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, Re import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.spark.TaskContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -146,12 +147,15 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) } + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), - new RecordReaderIterator[OrcStruct](orcRecordReader)) + recordsIterator) } } } From 4a83395681e0bca356363a6cfb25c952f235560d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 27 Sep 2016 21:19:59 -0700 Subject: [PATCH 141/306] [SPARK-17499][SPARKR][FOLLOWUP] Check null first for layers in spark.mlp to avoid warnings in test results ## What changes were proposed in this pull request? 
Some tests in `test_mllib.r` are as below: ```r expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") ``` The problem is, `is.na` is internally called via `na.omit` in `spark.mlp` which causes warnings as below: ``` Warnings ----------------------------------------------------------------------- 1. spark.mlp (test_mllib.R#400) - is.na() applied to non-(list or vector) of type 'NULL' 2. spark.mlp (test_mllib.R#401) - is.na() applied to non-(list or vector) of type 'NULL' ``` ## How was this patch tested? Manually tested. Also, Jenkins tests and AppVeyor. Author: hyukjinkwon Closes #15232 from HyukjinKwon/remove-warnnings. --- R/pkg/R/mllib.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 971c16658fe9a..b901307f8f409 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -696,6 +696,9 @@ setMethod("predict", signature(object = "KMeansModel"), setMethod("spark.mlp", signature(data = "SparkDataFrame"), function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1E-6, stepSize = 0.03, seed = NULL) { + if (is.null(layers)) { + stop ("layers must be a integer vector with length > 1.") + } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { stop ("layers must be a integer vector with length > 1.") From b2a7eedcddf0e682ff46afd1b764d0b81ccdf395 Mon Sep 17 00:00:00 2001 From: Shuai Lin Date: Wed, 28 Sep 2016 06:12:48 -0400 Subject: [PATCH 142/306] [SPARK-17017][ML][MLLIB][ML][DOC] Updated the ml/mllib feature selection docs for ChiSqSelector ## What changes were proposed in this pull request? A follow up for #14597 to update feature selection docs about ChiSqSelector. ## How was this patch tested? Generated html docs. It can be previewed at: * ml: http://sparkdocs.lins05.pw/spark-17017/ml-features.html#chisqselector * mllib: http://sparkdocs.lins05.pw/spark-17017/mllib-feature-extraction.html#chisqselector Author: Shuai Lin Closes #15236 from lins05/spark-17017-update-docs-for-chisq-selector-fpr. --- docs/ml-features.md | 14 ++++++++++---- docs/mllib-feature-extraction.md | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a39b31c8f7ffc..a7f710fa52e64 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1331,10 +1331,16 @@ for more details on the API. ## ChiSqSelector `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with -categorical features. ChiSqSelector orders features based on a -[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) -from the class, and then filters (selects) the top features which the class label depends on the -most. This is akin to yielding the features with the most predictive power. +categorical features. ChiSqSelector uses the +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which +features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: + +* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. +* `FPR` chooses all features whose false positive rate meets some threshold. 
+ +By default, the selection method is `KBest`, the default number of top features is 50. User can use +`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. **Examples** diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 353d391249973..87e1e027e945b 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -225,10 +225,16 @@ features for use in model construction. It reduces the size of the feature space both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements -Chi-Squared feature selection. It operates on labeled data with categorical features. -`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, -and then filters (selects) the top features which the class label depends on the most. -This is akin to yielding the features with the most predictive power. +Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which +features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: + +* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. +* `FPR` chooses all features whose false positive rate meets some threshold. + +By default, the selection method is `KBest`, the default number of top features is 50. User can use +`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. The number of features to select can be tuned using a held-out validation set. From 2190037757a81d3172f75227f7891d968e1f0d90 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Sep 2016 06:19:04 -0400 Subject: [PATCH 143/306] [MINOR][PYSPARK][DOCS] Fix examples in PySpark documentation ## What changes were proposed in this pull request? This PR proposes to fix wrongly indented examples in PySpark documentation ``` - >>> json_sdf = spark.readStream.format("json")\ - .schema(sdf_schema)\ - .load(tempfile.mkdtemp()) + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) ``` ``` - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ ``` ``` - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] ``` ``` - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] ``` ``` - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) ``` ## How was this patch tested? Manually tested. 
**Before** ![2016-09-26 8 36 02](https://cloud.githubusercontent.com/assets/6477701/18834471/05c7a478-8431-11e6-94bb-09aa37b12ddb.png) ![2016-09-26 9 22 16](https://cloud.githubusercontent.com/assets/6477701/18834472/06c8735c-8431-11e6-8775-78631eab0411.png) 2016-09-27 2 29 27 2016-09-27 2 29 58 2016-09-27 2 30 05 **After** ![2016-09-26 9 29 47](https://cloud.githubusercontent.com/assets/6477701/18834467/0367f9da-8431-11e6-86d9-a490d3297339.png) ![2016-09-26 9 30 24](https://cloud.githubusercontent.com/assets/6477701/18834463/f870fae0-8430-11e6-9482-01fc47898492.png) 2016-09-27 2 28 19 2016-09-27 3 50 59 2016-09-27 3 51 03 Author: hyukjinkwon Closes #15242 from HyukjinKwon/minor-example-pyspark. --- python/pyspark/mllib/util.py | 8 ++++---- python/pyspark/rdd.py | 4 ++-- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/streaming.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 48867a08dbfad..ed6fd4bca4c54 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -140,8 +140,8 @@ def saveAsLibSVMFile(data, dir): >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) @@ -166,8 +166,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 0508235c1c9ee..5fb10f86f4692 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -754,8 +754,8 @@ def foreachPartition(self, f): Applies a function to each partition of this RDD. >>> def f(iterator): - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0f7d8fba3bd54..0ac481a8a8b56 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -61,7 +61,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. 
versionadded:: 1.3 diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cbd827950bbb4..4e438fd5bee22 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -315,9 +315,9 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options - >>> json_sdf = spark.readStream.format("json")\ - .schema(sdf_schema)\ - .load(tempfile.mkdtemp()) + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema From 46d1203bf2d01b219c4efc7e0e77a844c0c664da Mon Sep 17 00:00:00 2001 From: w00228970 Date: Wed, 28 Sep 2016 12:02:59 -0700 Subject: [PATCH 144/306] [SPARK-17644][CORE] Do not add failedStages when abortStage for fetch failure ## What changes were proposed in this pull request? | Time |Thread 1 , Job1 | Thread 2 , Job2 | |:-------------:|:-------------:|:-----:| | 1 | abort stage due to FetchFailed | | | 2 | failedStages += failedStage | | | 3 | | task failed due to FetchFailed | | 4 | | can not post ResubmitFailedStages because failedStages is not empty | Then job2 of thread2 never resubmit the failed stage and hang. We should not add the failedStages when abortStage for fetch failure ## How was this patch tested? added unit test Author: w00228970 Author: wangfei Closes #15213 from scwf/dag-resubmit. --- .../apache/spark/scheduler/DAGScheduler.scala | 24 ++++---- .../spark/scheduler/DAGSchedulerSuite.scala | 58 ++++++++++++++++++- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5ea0b48f6e4c4..f2517401cb76b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1263,18 +1263,20 @@ class DAGScheduler( s"has failed the maximum allowable number of " + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + s"Most recent failure reason: ${failureMessage}", None) - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } else { + if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. 
+ // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage } - failedStages += failedStage - failedStages += mapStage // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6787b302614e6..bec95d13d193a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -31,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -2105,6 +2106,61 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC)) } + test("SPARK-17644: After one stage is aborted for too many failed attempts, subsequent stages" + + "still behave correctly on fetch failures") { + // Runs a job that always encounters a fetch failure, so should eventually be aborted + def runJobWithPersistentFetchFailure: Unit = { + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + case (x, _) => x + }.count() + } + + // Runs a job that encounters a single fetch failure but succeeds on the second attempt + def runJobWithTemporaryFetchFailure: Unit = { + object FailThisAttempt { + val _fail = new AtomicBoolean(true) + } + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + } + } + + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + // Run a second job that will fail due to a fetch failure. + // This job will hang without the fix for SPARK-17644. 
+ failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + failAfter(10.seconds) { + try { + runJobWithTemporaryFetchFailure + } catch { + case e: Throwable => fail("A job with one fetch failure should eventually succeed") + } + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From a6cfa3f38bcf6ba154d5ed2a53748fbc90c8872a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 28 Sep 2016 13:22:45 -0700 Subject: [PATCH 145/306] [SPARK-17673][SQL] Incorrect exchange reuse with RowDataSourceScan ## What changes were proposed in this pull request? It seems the equality check for reuse of `RowDataSourceScanExec` nodes doesn't respect the output schema. This can cause self-joins or unions over the same underlying data source to return incorrect results if they select different fields. ## How was this patch tested? New unit test passes after the fix. Author: Eric Liang Closes #15273 from ericl/spark-17673. --- .../sql/execution/datasources/DataSourceStrategy.scala | 4 ++++ .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 63f01c5bb9e3c..693b4c4d0e5e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -340,6 +340,8 @@ object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + // These metadata values make scan plans uniquely identifiable for equality checking. 
+ // TODO(SPARK-17701) using strings for equality checking is brittle val metadata: Map[String, String] = { val pairs = ArrayBuffer.empty[(String, String)] @@ -350,6 +352,8 @@ object DataSourceStrategy extends Strategy with Logging { } pairs += ("PushedFilters" -> markedFilters.mkString("[", ", ", "]")) } + pairs += ("ReadSchema" -> + StructType.fromAttributes(projects.map(_.toAttribute)).catalogString) pairs.toMap } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 10f15ca280689..c94cb3b69dfbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -791,4 +791,12 @@ class JDBCSuite extends SparkFunSuite val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") + val df1 = df.groupBy("a").agg("c" -> "min") + val df2 = df.groupBy("a").agg("d" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } } From 557d6e32272dee4eaa0f426cc3e2f82ea361c3da Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 28 Sep 2016 16:20:49 -0700 Subject: [PATCH 146/306] [SPARK-17713][SQL] Move row-datasource related tests out of JDBCSuite ## What changes were proposed in this pull request? As a followup for https://github.com/apache/spark/pull/15273 we should move non-JDBC specific tests out of that suite. ## How was this patch tested? Ran the test. Author: Eric Liang Closes #15287 from ericl/spark-17713. --- .../RowDataSourceStrategySuite.scala | 72 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 8 --- 2 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala new file mode 100644 index 0000000000000..d9afa4635318f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/RowDataSourceStrategySuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.sql.DriverManager +import java.util.Properties + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + + val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" + var conn: java.sql.Connection = null + + before { + Utils.classForName("org.h2.Driver") + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) + conn.prepareStatement("create schema test").executeUpdate() + conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate() + conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE inttypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + + after { + conn.close() + } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes") + val df1 = df.groupBy("a").agg("b" -> "min") + val df2 = df.groupBy("a").agg("c" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index c94cb3b69dfbe..10f15ca280689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -791,12 +791,4 @@ class JDBCSuite extends SparkFunSuite val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } - - test("SPARK-17673: Exchange reuse respects differences in output schema") { - val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") - val df1 = df.groupBy("a").agg("c" -> "min") - val df2 = df.groupBy("a").agg("d" -> "min") - val res = df1.union(df2) - assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused - } } From 7d09232028967978d9db314ec041a762599f636b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 28 Sep 2016 16:25:10 -0700 Subject: [PATCH 147/306] [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values. ## What changes were proposed in this pull request? We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, this is different from the original Hive implementation. This PR fixes this by adding a null check to the `Collect.update` method. ## How was this patch tested? Added a regression test to `DataFrameAggregateSuite`. Author: Herman van Hovell Closes #15208 from hvanhovell/SPARK-17641. 
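For example, with this change the null entry below is simply skipped (an illustrative snippet, assuming a `SparkSession` named `spark` with its implicits imported):

```scala
import org.apache.spark.sql.functions.{collect_list, collect_set}
import spark.implicits._

val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")

// Nulls are no longer collected, matching Hive's collect_list/collect_set semantics.
df.select(collect_list($"a")).head() // roughly Row(Seq("1", "1"))
df.select(collect_set($"a")).head()  // roughly Row(Seq("1"))
```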
--- .../sql/catalyst/expressions/aggregate/collect.scala | 7 ++++++- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 896ff61b23093..78a388d20630b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate { } override def update(b: MutableRow, input: InternalRow): Unit = { - buffer += child.eval(input) + // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. + // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator + val value = child.eval(input) + if (value != null) { + buffer += value + } } override def merge(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0e172bee4f661..7aa4f0026f275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(error.message.contains("collect_set() cannot have map type data")) } + test("SPARK-17641: collect functions should not collect null values") { + val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq("1", "1"), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq("1"), Seq(2, 4))) + ) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), From 7dfad4b132bc46263ef788ced4a935862f5c8756 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Wed, 28 Sep 2016 20:20:03 -0500 Subject: [PATCH 148/306] [SPARK-17710][HOTFIX] Fix ClassCircularityError in ReplSuite tests in Maven build: use 'Class.forName' instead of 'Utils.classForName' ## What changes were proposed in this pull request? Fix ClassCircularityError in ReplSuite tests when Spark is built by Maven build. ## How was this patch tested? (1) ``` build/mvn -DskipTests -Phadoop-2.3 -Pyarn -Phive -Phive-thriftserver -Pkinesis-asl -Pmesos clean package ``` Then test: ``` build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.repl.ReplSuite test ``` ReplSuite tests passed (2) Manual Tests against some Spark applications in Yarn client mode and Yarn cluster mode. Need to check if spark caller contexts are written into HDFS hdfs-audit.log and Yarn RM audit log successfully. Author: Weiqing Yang Closes #15286 from Sherry302/SPARK-16757. 
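For context, the lookup in question is the reflective access to Hadoop's caller-context API; after this change it looks roughly like the following (a simplified sketch of the code in the diff below, with error handling trimmed):

```scala
// Look the classes up by name so Spark still runs against Hadoop versions that do not ship them;
// using java.lang.Class.forName directly avoids routing the lookup through Utils.classForName.
def trySetCallerContext(context: String): Boolean =
  try {
    // scalastyle:off classforname
    val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
    val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
    // scalastyle:on classforname
    val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
    val hdfsContext = builder.getMethod("build").invoke(builderInst)
    callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
    true
  } catch {
    case _: Exception => false
  }
```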
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index caa768cfbdc6c..f3493bd96b1ee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2489,8 +2489,10 @@ private[spark] class CallerContext( def setCurrentContext(): Boolean = { var succeed = false try { - val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") - val Builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:off classforname + val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") + val Builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) val hdfsContext = Builder.getMethod("build").invoke(builderInst) callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) From 37eb9184f1e9f1c07142c66936671f4711ef407d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 28 Sep 2016 19:03:05 -0700 Subject: [PATCH 149/306] [SPARK-17712][SQL] Fix invalid pushdown of data-independent filters beneath aggregates ## What changes were proposed in this pull request? This patch fixes a minor correctness issue impacting the pushdown of filters beneath aggregates. Specifically, if a filter condition references no grouping or aggregate columns (e.g. `WHERE false`) then it would be incorrectly pushed beneath an aggregate. Intuitively, the only case where you can push a filter beneath an aggregate is when that filter is deterministic and is defined over the grouping columns / expressions, since in that case the filter is acting to exclude entire groups from the query (like a `HAVING` clause). The existing code would only push deterministic filters beneath aggregates when all of the filter's references were grouping columns, but this logic missed the case where a filter has no references. For example, `WHERE false` is deterministic but is independent of the actual data. This patch fixes this minor bug by adding a new check to ensure that we don't push filters beneath aggregates when those filters don't reference any columns. ## How was this patch tested? New regression test in FilterPushdownSuite. Author: Josh Rosen Closes #15289 from JoshRosen/SPARK-17712. 
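For example, over a global aggregate the difference is visible in the query result (an illustrative snippet, assuming a `SparkSession` named `spark`; the table and alias names are arbitrary):

```scala
spark.range(10).selectExpr("id AS a").createOrReplaceTempView("t")

// The outer predicate references no grouping or aggregate columns, so it cannot act as a
// HAVING-like filter that excludes groups. The correct answer is zero rows; if `false` were
// pushed beneath the COUNT, the query would instead return a single row with cnt = 0.
spark.sql("SELECT * FROM (SELECT COUNT(*) AS cnt FROM t) agg WHERE false").show()
```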
--- .../sql/catalyst/optimizer/Optimizer.scala | 2 +-
 .../optimizer/FilterPushdownSuite.scala | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0df16b7a56c56..4952ba3b2b99d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -710,7 +710,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
 val (pushDown, rest) = candidates.partition { cond =>
 val replaced = replaceAlias(cond, aliasMap)
- replaced.references.subsetOf(aggregate.child.outputSet)
+ cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
 }
 val stayUp = rest ++ containingNonDeterministic
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 55836f96f7e0e..019f132d94cb2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -687,6 +687,23 @@ class FilterPushdownSuite extends PlanTest {
 comparePlans(optimized, correctAnswer)
 }
+
+ test("SPARK-17712: aggregate: don't push down filters that are data-independent") {
+ val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty)
+ .select('a, 'b)
+ .groupBy('a)(count('a))
+ .where(false)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = testRelation
+ .select('a, 'b)
+ .groupBy('a)(count('a))
+ .where(false)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
 test("broadcast hint") {
 val originalQuery = BroadcastHint(testRelation)
 .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
From a19a1bb59411177caaf99581e89098826b7d0c7b Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 29 Sep 2016 00:54:26 -0700
Subject: [PATCH 150/306] [SPARK-16356][FOLLOW-UP][ML] Enforce ML test of exception for local/distributed Dataset.

## What changes were proposed in this pull request?
#14035 added ```testImplicits``` to ML unit tests and promoted ```toDF()```, but left one minor issue at ```VectorIndexerSuite```. If we create the DataFrame by ```Seq(...).toDF()```, it will throw a different error/exception than ```sc.parallelize(Seq(...)).toDF()``` for one of the test cases.
After an in-depth study, I found it was caused by the different behavior of local and distributed Datasets when the UDF fails at ```assert```. If the data is a local Dataset, it throws ```AssertionError``` directly; if the data is a distributed Dataset, it throws ```SparkException```, which wraps the ```AssertionError```. I think we should make this test cover both cases.

## How was this patch tested?
Unit test.

Author: Yanbo Liang

Closes #15261 from yanboliang/spark-16356.
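The same difference can be reproduced with a plain UDF (a rough sketch, assuming a ScalaTest suite with a `SparkSession` named `spark` in scope; whether the raw error surfaces directly depends on the projection over the local relation being evaluated on the driver):

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.functions.udf
import spark.implicits._

val failing = udf((i: Int) => { assert(i >= 0, "negative input"); i })

// Local data: the projection over a LocalRelation can be evaluated on the driver,
// so the raw AssertionError surfaces directly.
intercept[AssertionError] {
  Seq(-1).toDF("i").select(failing($"i")).collect()
}

// Distributed data: the failure happens inside a task and is reported to the driver
// as a SparkException wrapping the AssertionError.
intercept[SparkException] {
  Seq(-1).toDF("i").repartition(2).select(failing($"i")).collect()
}
```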
--- .../spark/ml/feature/VectorIndexerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 4da1b133e8cd5..b28ce2ab45b45 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -88,9 +88,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext densePoints1 = densePoints1Seq.map(FeatureData).toDF() sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() - // TODO: If we directly use `toDF` without parallelize, the test in - // "Throws error when given RDDs with different size vectors" is failed for an unknown reason. - densePoints2 = sc.parallelize(densePoints2Seq, 2).map(FeatureData).toDF() + densePoints2 = densePoints2Seq.map(FeatureData).toDF() sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() badPoints = badPointsSeq.map(FeatureData).toDF() } @@ -121,10 +119,17 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work - intercept[SparkException] { + // If the data is local Dataset, it throws AssertionError directly. + intercept[AssertionError] { model.transform(densePoints2).collect() logInfo("Did not throw error when fit, transform were called on vectors of different lengths") } + // If the data is distributed Dataset, it throws SparkException + // which is the wrapper of AssertionError. + intercept[SparkException] { + model.transform(densePoints2.repartition(2)).collect() + logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + } intercept[SparkException] { vectorIndexer.fit(badPoints) logInfo("Did not throw error when fitting vectors of different lengths in same RDD.") From f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Sep 2016 04:30:42 -0700 Subject: [PATCH 151/306] [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Several performance improvement for ```ChiSqSelector```: 1, Keep ```selectedFeatures``` ordered ascendent. ```ChiSqSelectorModel.transform``` need ```selectedFeatures``` ordered to make prediction. We should sort it when training model rather than making prediction, since users usually train model once and use the model to do prediction multiple times. 2, When training ```fpr``` type ```ChiSqSelectorModel```, it's not necessary to sort the ChiSq test result by statistic. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang Closes #15277 from yanboliang/spark-17704. --- .../spark/mllib/feature/ChiSqSelector.scala | 45 ++++++++++++------- project/MimaExcludes.scala | 3 -- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 0f7c6e8bc04bb..706ce78f260a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). 
+ * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + val len = array.length + while (i < len) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + /** * Applies transformation on a vector. * @@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter + * @param filterIndices indices of features to filter, must be ordered asc */ private def compress(features: Vector, filterIndices: Array[Int]): Vector = { - val orderedIndices = filterIndices.sorted features match { case SparseVector(size, indices, values) => - val newSize = orderedIndices.length + val newSize = filterIndices.length val newValues = new ArrayBuilder.ofDouble val newIndices = new ArrayBuilder.ofInt var i = 0 var j = 0 var indicesIdx = 0 var filterIndicesIdx = 0 - while (i < indices.length && j < orderedIndices.length) { + while (i < indices.length && j < filterIndices.length) { indicesIdx = indices(i) - filterIndicesIdx = orderedIndices(j) + filterIndicesIdx = filterIndices(j) if (indicesIdx == filterIndicesIdx) { newIndices += j newValues += values(i) @@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( Vectors.sparse(newSize, newIndices.result(), newValues.result()) case DenseVector(values) => val values = features.toArray - Vectors.dense(orderedIndices.map(i => values(i))) + Vectors.dense(filterIndices.map(i => values(i))) case other => throw new UnsupportedOperationException( s"Only sparse and dense vectors are supported but got ${other.getClass}.") @@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data) - .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelector.KBest => chiSqTestResult - .take(numTopFeatures) - case ChiSqSelector.Percentile => chiSqTestResult - .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelector.FPR => chiSqTestResult - .filter{ case (res, _) => res.pValue < alpha } + case ChiSqSelector.KBest => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + case ChiSqSelector.Percentile => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take((chiSqTestResult.length * percentile).toInt) + case ChiSqSelector.FPR => + chiSqTestResult.zipWithIndex + .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices } + val indices = features.map { case (_, indices) => indices }.sorted new ChiSqSelectorModel(indices) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8024fbd21bbfc..4db3edb733a56 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -817,9 +817,6 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") - ) ++ Seq( - // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") From b35b0dbbfa3dc1bdf5e2fa1e9677d06635142b22 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 29 Sep 2016 08:24:34 -0400 Subject: [PATCH 152/306] [SPARK-17614][SQL] sparkSession.read() .jdbc(***) use the sql syntax "where 1=0" that Cassandra does not support ## What changes were proposed in this pull request? Use dialect's table-exists query rather than hard-coded WHERE 1=0 query ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15196 from srowen/SPARK-17614. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 6 ++---- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index a7da29f9252b3..f10615ebe4bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -58,11 +58,11 @@ object JDBCRDD extends Logging { val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { - val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") + val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { val rs = statement.executeQuery() try { - return JdbcUtils.getSchema(rs, dialect) + JdbcUtils.getSchema(rs, dialect) } finally { rs.close() } @@ -72,8 +72,6 @@ object JDBCRDD extends Logging { } finally { conn.close() } - - throw new RuntimeException("This line is unreachable.") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 3a6d5b7f1ced6..8dd4b8f662713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.types._ /** @@ -99,6 +99,19 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * The SQL query that should be used to discover the schema of a table. It only needs to + * ensure that the result set has the same schema as the table, such as by calling + * "SELECT * ...". Dialects can override this method to return a query that works best in a + * particular database. + * @param table The name of the table. + * @return The SQL query to use for discovering the schema. + */ + @Since("2.1.0") + def getSchemaQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. 
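The hunk above only changes where the hard-coded query lives; a sketch of how a third-party dialect could override the new hook follows. The object name, the `jdbc:cassandra` URL prefix, and the `LIMIT 0` query are illustrative assumptions rather than anything from this patch, and `getSchemaQuery` is only available from Spark 2.1 onward per its `@Since` tag:

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect for a store that cannot parse "WHERE 1=0".
object ExampleCassandraDialect extends JdbcDialect {

  // Claim JDBC URLs produced by the (assumed) driver.
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:cassandra")

  // Return an empty result set carrying the table's schema, without "WHERE 1=0".
  override def getSchemaQuery(table: String): String = s"SELECT * FROM $table LIMIT 0"
}

// After registration, the schema-resolution code in the JDBCRDD hunk above
// calls dialect.getSchemaQuery(table) instead of the hard-coded query.
JdbcDialects.registerDialect(ExampleCassandraDialect)
```

Because the default implementation returns the same `WHERE 1=0` literal as before, existing dialects keep their behavior while backends that reject that syntax get an escape hatch.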
From b2e9731ca494c0c60d571499f68bb8306a3c9fe5 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 29 Sep 2016 08:26:03 -0400 Subject: [PATCH 153/306] [MINOR][DOCS] Fix the doc. of spark-streaming with kinesis ## What changes were proposed in this pull request? This PR just fixes the documentation of `spark-kinesis-integration`. Since `SPARK-17418` prevented all the kinesis artifacts (including the kinesis example code) from being published, `bin/run-example streaming.KinesisWordCountASL` and `bin/run-example streaming.JavaKinesisWordCountASL` do not work. Instead, the example commands now fetch the kinesis jar via `--packages`. Author: Takeshi YAMAMURO Closes #15260 from maropu/DocFixKinesis. --- docs/streaming-kinesis-integration.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 96198ddf537b6..6be0b548bc62b 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -166,10 +166,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download Spark source and follow the [instructions](building-spark.html) to build Spark with profile *-Pkinesis-asl*. - - mvn -Pkinesis-asl -DskipTests clean package - +- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. @@ -180,12 +177,12 @@ To run the example,
      - bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
      - bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
      From 958200497affb40f05e321c2b0e252d365ae02f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Hiram=20Soltren?= Date: Thu, 29 Sep 2016 10:18:56 -0700 Subject: [PATCH 154/306] [DOCS] Reorganize explanation of Accumulators and Broadcast Variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The discussion of the interaction of Accumulators and Broadcast Variables should logically follow the discussion on Checkpointing. As currently written, this section discusses Checkpointing before it is formally introduced. To remedy this: - Rename this section to "Accumulators, Broadcast Variables, and Checkpoints", and - Move this section after "Checkpointing". ## How was this patch tested? Testing: ran $ SKIP_API=1 jekyll build , and verified changes in a Web browser pointed at docs/_site/index.html. Author: José Hiram Soltren Closes #15281 from jsoltren/doc-changes. --- docs/streaming-programming-guide.md | 328 ++++++++++++++-------------- 1 file changed, 164 insertions(+), 164 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 43f1cf3e31871..0b0315b366501 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1368,170 +1368,6 @@ Note that the connections in the pool should be lazily created on demand and tim *** -## Accumulators and Broadcast Variables - -[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. - -
      -
      -{% highlight scala %} - -object WordBlacklist { - - @volatile private var instance: Broadcast[Seq[String]] = null - - def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { - if (instance == null) { - synchronized { - if (instance == null) { - val wordBlacklist = Seq("a", "b", "c") - instance = sc.broadcast(wordBlacklist) - } - } - } - instance - } -} - -object DroppedWordsCounter { - - @volatile private var instance: LongAccumulator = null - - def getInstance(sc: SparkContext): LongAccumulator = { - if (instance == null) { - synchronized { - if (instance == null) { - instance = sc.longAccumulator("WordsInBlacklistCounter") - } - } - } - instance - } -} - -wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => - // Get or register the blacklist Broadcast - val blacklist = WordBlacklist.getInstance(rdd.sparkContext) - // Get or register the droppedWordsCounter Accumulator - val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) - // Use blacklist to drop words and use droppedWordsCounter to count them - val counts = rdd.filter { case (word, count) => - if (blacklist.value.contains(word)) { - droppedWordsCounter.add(count) - false - } else { - true - } - }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts -}) - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). -
      -
-{% highlight java %} - -class JavaWordBlacklist { - - private static volatile Broadcast<List<String>> instance = null; - - public static Broadcast<List<String>> getInstance(JavaSparkContext jsc) { - if (instance == null) { - synchronized (JavaWordBlacklist.class) { - if (instance == null) { - List<String> wordBlacklist = Arrays.asList("a", "b", "c"); - instance = jsc.broadcast(wordBlacklist); - } - } - } - return instance; - } -} - -class JavaDroppedWordsCounter { - - private static volatile LongAccumulator instance = null; - - public static LongAccumulator getInstance(JavaSparkContext jsc) { - if (instance == null) { - synchronized (JavaDroppedWordsCounter.class) { - if (instance == null) { - instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); - } - } - } - return instance; - } -} - -wordCounts.foreachRDD(new Function2<JavaPairRDD<String, Integer>, Time, Void>() { - @Override - public Void call(JavaPairRDD<String, Integer> rdd, Time time) throws IOException { - // Get or register the blacklist Broadcast - final Broadcast<List<String>> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); - // Get or register the droppedWordsCounter Accumulator - final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); - // Use blacklist to drop words and use droppedWordsCounter to count them - String counts = rdd.filter(new Function<Tuple2<String, Integer>, Boolean>() { - @Override - public Boolean call(Tuple2<String, Integer> wordCount) throws Exception { - if (blacklist.value().contains(wordCount._1())) { - droppedWordsCounter.add(wordCount._2()); - return false; - } else { - return true; - } - } - }).collect().toString(); - String output = "Counts at time " + time + " " + counts; - } -} - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). -
      -
      -{% highlight python %} -def getWordBlacklist(sparkContext): - if ("wordBlacklist" not in globals()): - globals()["wordBlacklist"] = sparkContext.broadcast(["a", "b", "c"]) - return globals()["wordBlacklist"] - -def getDroppedWordsCounter(sparkContext): - if ("droppedWordsCounter" not in globals()): - globals()["droppedWordsCounter"] = sparkContext.accumulator(0) - return globals()["droppedWordsCounter"] - -def echo(time, rdd): - # Get or register the blacklist Broadcast - blacklist = getWordBlacklist(rdd.context) - # Get or register the droppedWordsCounter Accumulator - droppedWordsCounter = getDroppedWordsCounter(rdd.context) - - # Use blacklist to drop words and use droppedWordsCounter to count them - def filterFunc(wordCount): - if wordCount[0] in blacklist.value: - droppedWordsCounter.add(wordCount[1]) - False - else: - True - - counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) - -wordCounts.foreachRDD(echo) - -{% endhighlight %} - -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py). - -
      -
      - -*** - ## DataFrame and SQL Operations You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. @@ -1877,6 +1713,170 @@ batch interval that is at least 10 seconds. It can be set by using *** +## Accumulators, Broadcast Variables, and Checkpoints + +[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. + +
      +
      +{% highlight scala %} + +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +object DroppedWordsCounter { + + @volatile private var instance: LongAccumulator = null + + def getInstance(sc: SparkContext): LongAccumulator = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.longAccumulator("WordsInBlacklistCounter") + } + } + } + instance + } +} + +wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter.add(count) + false + } else { + true + } + }.collect().mkString("[", ", ", "]") + val output = "Counts at time " + time + " " + counts +}) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +
      +
+{% highlight java %} + +class JavaWordBlacklist { + + private static volatile Broadcast<List<String>> instance = null; + + public static Broadcast<List<String>> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List<String> wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +class JavaDroppedWordsCounter { + + private static volatile LongAccumulator instance = null; + + public static LongAccumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +wordCounts.foreachRDD(new Function2<JavaPairRDD<String, Integer>, Time, Void>() { + @Override + public Void call(JavaPairRDD<String, Integer> rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast<List<String>> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function<Tuple2<String, Integer>, Boolean>() { + @Override + public Boolean call(Tuple2<String, Integer> wordCount) throws Exception { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + } +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +
      +
      +{% highlight python %} +def getWordBlacklist(sparkContext): + if ("wordBlacklist" not in globals()): + globals()["wordBlacklist"] = sparkContext.broadcast(["a", "b", "c"]) + return globals()["wordBlacklist"] + +def getDroppedWordsCounter(sparkContext): + if ("droppedWordsCounter" not in globals()): + globals()["droppedWordsCounter"] = sparkContext.accumulator(0) + return globals()["droppedWordsCounter"] + +def echo(time, rdd): + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) + +wordCounts.foreachRDD(echo) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py). + +
      +
+ +*** + +## Deploying Applications + This section discusses the steps to deploy a Spark Streaming application. From 7f779e7439127efa0e3611f7745e1c8423845198 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 29 Sep 2016 15:36:40 -0400 Subject: [PATCH 155/306] [SPARK-17648][CORE] TaskScheduler really needs offers to be an IndexedSeq ## What changes were proposed in this pull request? The Seq[WorkerOffer] is accessed by index, so it really should be an IndexedSeq, otherwise an O(n) operation becomes O(n^2). In practice this hasn't been an issue because where these offers are generated, the call to `.toSeq` just happens to create an IndexedSeq anyway. I got bitten by this in performance tests I was doing, and it's better for the types to be more precise so that e.g. a change in Scala doesn't destroy performance. ## How was this patch tested? Unit tests via Jenkins. Author: Imran Rashid Closes #15221 from squito/SPARK-17648. --- .../spark/scheduler/TaskSchedulerImpl.scala | 4 +-- .../CoarseGrainedSchedulerBackend.scala | 4 +-- .../local/LocalSchedulerBackend.scala | 2 +- .../scheduler/SchedulerIntegrationSuite.scala | 7 ++-- .../scheduler/TaskSchedulerImplSuite.scala | 32 +++++++++---------- .../MesosFineGrainedSchedulerBackend.scala | 2 +- ...esosFineGrainedSchedulerBackendSuite.scala | 2 +- 7 files changed, 26 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 52a7186cbf45c..0ad4730fe20a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -252,7 +252,7 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { var launchedTask = false for (i <- 0 until shuffledOffers.size) { val execId = shuffledOffers(i).executorId @@ -286,7 +286,7 @@ private[spark] class TaskSchedulerImpl( * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so * that tasks are balanced across the cluster.
*/ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { + def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { // Mark each slave as alive and remember its hostname // Also track if new executor is added var newExecAvail = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index edc3c199376ef..2d0986316601f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -216,7 +216,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq + }.toIndexedSeq launchTasks(scheduler.resourceOffers(workOffers)) } @@ -233,7 +233,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Filter out executors under killing if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) - val workOffers = Seq( + val workOffers = IndexedSeq( new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) launchTasks(scheduler.resourceOffers(workOffers)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index e386052814039..7a73e8ed8a38f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,7 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 14f52a6be9d1f..5cd548bbc72d9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -366,13 +366,13 @@ private[spark] abstract class MockBackend( */ def executorIdToExecutor: Map[String, ExecutorTaskStatus] - private def generateOffers(): Seq[WorkerOffer] = { + private def generateOffers(): IndexedSeq[WorkerOffer] = { executorIdToExecutor.values.filter { exec => exec.freeCores > 0 }.map { exec => WorkerOffer(executorId = exec.executorId, host = exec.host, cores = exec.freeCores) - }.toSeq + }.toIndexedSeq } /** @@ -381,8 +381,7 @@ private[spark] abstract class MockBackend( * scheduling. 
*/ override def reviveOffers(): Unit = { - val offers: Seq[WorkerOffer] = generateOffers() - val newTaskDescriptions = taskScheduler.resourceOffers(offers).flatten + val newTaskDescriptions = taskScheduler.resourceOffers(generateOffers()).flatten // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual // tests from introducing a race if they need it val newTasks = taskScheduler.synchronized { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 100b15740ca92..61787b54f824f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B test("Scheduler does not always schedule tasks on the same workers") { val taskScheduler = setupScheduler() val numFreeCores = 1 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores), new WorkerOffer("executor1", "host1", numFreeCores)) // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always // get scheduled on the same executor. While there is a chance this test will fail @@ -112,7 +112,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskCpus = 2 val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) // Give zero core offers. Should not generate any tasks - val zeroCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", 0), + val zeroCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 0), new WorkerOffer("executor1", "host1", 0)) val taskSet = FakeTask.createTaskSet(1) taskScheduler.submitTasks(taskSet) @@ -121,7 +121,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // No tasks should run as we only have 1 core free. val numFreeCores = 1 - val singleCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + val singleCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten @@ -129,7 +129,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // Now change the offers to have 2 cores in one executor and verify if it // is chosen. 
- val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), + val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten @@ -144,7 +144,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val numFreeCores = 1 val taskSet = new TaskSet( Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), + val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten @@ -184,7 +184,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() val numFreeCores = 1 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores)) val attempt1 = FakeTask.createTaskSet(10) // submit attempt 1, offer some resources, some tasks get scheduled @@ -216,7 +216,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() val numFreeCores = 10 - val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores)) val attempt1 = FakeTask.createTaskSet(10) // submit attempt 1, offer some resources, some tasks get scheduled @@ -254,8 +254,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B test("tasks are not re-scheduled while executor loss reason is pending") { val taskScheduler = setupScheduler() - val e0Offers = Seq(new WorkerOffer("executor0", "host0", 1)) - val e1Offers = Seq(new WorkerOffer("executor1", "host0", 1)) + val e0Offers = IndexedSeq(new WorkerOffer("executor0", "host0", 1)) + val e1Offers = IndexedSeq(new WorkerOffer("executor1", "host0", 1)) val attempt1 = FakeTask.createTaskSet(1) // submit attempt 1, offer resources, task gets scheduled @@ -296,7 +296,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(taskSet) val tsm = taskScheduler.taskSetManagerForAttempt(taskSet.stageId, taskSet.stageAttemptId).get - val firstTaskAttempts = taskScheduler.resourceOffers(Seq( + val firstTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( new WorkerOffer("executor0", "host0", 1), new WorkerOffer("executor1", "host1", 1) )).flatten @@ -313,7 +313,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // on that executor, and make sure that the other task (not the failed one) is assigned there taskScheduler.executorLost("executor1", SlaveLost("oops")) val nextTaskAttempts = - taskScheduler.resourceOffers(Seq(new WorkerOffer("executor0", "host0", 1))).flatten + taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))).flatten // Note: Its OK if some future change makes this already realize the taskset has become // unschedulable at this point (though in the current implementation, we're sure it will not) assert(nextTaskAttempts.size === 1) @@ -323,7 +323,7 @@ class TaskSchedulerImplSuite extends 
SparkFunSuite with LocalSparkContext with B // now we should definitely realize that our task set is unschedulable, because the only // task left can't be scheduled on any executors due to the blacklist - taskScheduler.resourceOffers(Seq(new WorkerOffer("executor0", "host0", 1))) + taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))) sc.listenerBus.waitUntilEmpty(100000) assert(tsm.isZombie) assert(failedTaskSet) @@ -348,7 +348,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(taskSet) val tsm = taskScheduler.taskSetManagerForAttempt(taskSet.stageId, taskSet.stageAttemptId).get - val offers = Seq( + val offers = IndexedSeq( // each offer has more than enough free cores for the entire task set, so when combined // with the locality preferences, we schedule all tasks on one executor new WorkerOffer("executor0", "host0", 4), @@ -380,7 +380,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2"))}: _* )) - val taskDescs = taskScheduler.resourceOffers(Seq( + val taskDescs = taskScheduler.resourceOffers(IndexedSeq( new WorkerOffer("executor0", "host0", 1), new WorkerOffer("executor1", "host1", 1) )).flatten @@ -396,7 +396,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // when executor2 is added, we should realize that we can run process-local tasks. // And we should know its alive on the host. val secondTaskDescs = taskScheduler.resourceOffers( - Seq(new WorkerOffer("executor2", "host0", 1))).flatten + IndexedSeq(new WorkerOffer("executor2", "host0", 1))).flatten assert(secondTaskDescs.size === 1) assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.PROCESS_LOCAL, TaskLocality.NODE_LOCAL, TaskLocality.ANY)) @@ -406,7 +406,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // And even if we don't have anything left to schedule, another resource offer on yet another // executor should also update the set of live executors val thirdTaskDescs = taskScheduler.resourceOffers( - Seq(new WorkerOffer("executor3", "host1", 1))).flatten + IndexedSeq(new WorkerOffer("executor3", "host1", 1))).flatten assert(thirdTaskDescs.size === 0) assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index eb3b235949501..09a252f3c74ac 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -286,7 +286,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( o.getSlaveId.getValue, o.getHostname, cpus) - } + }.toIndexedSeq val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 7a706ab256f82..1d7a86f4b0904 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala 
+++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -283,7 +283,7 @@ class MesosFineGrainedSchedulerBackendSuite mesosOffers2.add(createOffer(1, minMem, minCpu)) reset(taskScheduler) reset(driver) - when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.resourceOffers(any(classOf[IndexedSeq[WorkerOffer]]))).thenReturn(Seq(Seq())) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) From cb87b3ced9453b5717fa8e8637b97a2f3f25fdd7 Mon Sep 17 00:00:00 2001 From: Gang Wu Date: Thu, 29 Sep 2016 15:51:05 -0400 Subject: [PATCH 156/306] [SPARK-17672] Spark 2.0 history server web Ui takes too long for a single application Added a new API getApplicationInfo(appId: String) in class ApplicationHistoryProvider and class SparkUI to get app info. In this change, FsHistoryProvider can directly fetch one app info in O(1) time complexity compared to O(n) before the change which used an Iterator.find() interface. Both ApplicationCache and OneApplicationResource classes adopt this new api. manual tests Author: Gang Wu Closes #15247 from wgtmac/SPARK-17671. --- .../spark/deploy/history/ApplicationHistoryProvider.scala | 5 +++++ .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 4 ++++ .../org/apache/spark/deploy/history/HistoryServer.scala | 4 ++++ .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../apache/spark/status/api/v1/OneApplicationResource.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++++ 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 44661edfff90b..ba42b4862aa90 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -109,4 +109,9 @@ private[history] abstract class ApplicationHistoryProvider { @throws(classOf[SparkException]) def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + /** + * @return the [[ApplicationHistoryInfo]] for the appId if it exists. 
+ */ + def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6874aa5f938ac..d494ff0659bd2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -224,6 +224,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values + override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { + applications.get(appId) + } + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index c178917d8da3b..735aa43cfc994 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -182,6 +182,10 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + provider.getApplicationInfo(appId).map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + } + override def writeEventLogs( appId: String, attemptId: Option[String], diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index de927117e1f63..17bc04303fa8b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -222,6 +222,7 @@ private[spark] object ApiRootResource { private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + def getApplicationInfo(appId: String): Option[ApplicationInfo] /** * Write the event logs for the given app to the [[ZipOutputStream]] instance. 
If attemptId is diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index d7e6a8b589953..18c3e2f407360 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -24,7 +24,7 @@ private[v1] class OneApplicationResource(uiRoot: UIRoot) { @GET def getApp(@PathParam("appId") appId: String): ApplicationInfo = { - val apps = uiRoot.getApplicationInfoList.find { _.id == appId } + val apps = uiRoot.getApplicationInfo(appId) apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 39155ff2649ec..ef71db89798f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -126,6 +126,10 @@ private[spark] class SparkUI private ( )) )) } + + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + getApplicationInfoList.find(_.id == appId) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) From 027dea8f294504bc5cd8bfedde546d171cb78657 Mon Sep 17 00:00:00 2001 From: Brian Cho Date: Thu, 29 Sep 2016 15:59:17 -0400 Subject: [PATCH 157/306] [SPARK-17715][SCHEDULER] Make task launch logs DEBUG ## What changes were proposed in this pull request? Ramp down the task launch logs from INFO to DEBUG. Task launches can happen orders of magnitude more than executor registration so it makes the logs easier to handle if they are different log levels. For larger jobs, there can be 100,000s of task launches which makes the driver log huge. ## How was this patch tested? No tests, as this is a trivial change. Author: Brian Cho Closes #15290 from dafrista/ramp-down-task-logging. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2d0986316601f..0dae0e614e17d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -265,7 +265,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + s"${executorData.executorHost}.") executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) From fe33121a53384811a8e094ab6c05dc85b7c7ca87 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 29 Sep 2016 13:01:10 -0700 Subject: [PATCH 158/306] [SPARK-17699] Support for parsing JSON string columns Spark SQL has great support for reading text files that contain JSON data. However, in many cases the JSON data is just one column amongst others. This is particularly true when reading from sources such as Kafka. 
This PR adds a new functions `from_json` that converts a string column into a nested `StructType` with a user specified schema. Example usage: ```scala val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("a", IntegerType) df.select(from_json($"value", schema) as 'json) // => [json: ] ``` This PR adds support for java, scala and python. I leveraged our existing JSON parsing support by moving it into catalyst (so that we could define expressions using it). I left SQL out for now, because I'm not sure how users would specify a schema. Author: Michael Armbrust Closes #15274 from marmbrus/jsonParser. --- python/pyspark/sql/functions.py | 23 ++++++++ .../expressions/jsonExpressions.scala | 31 +++++++++- .../sql/catalyst}/json/JSONOptions.scala | 6 +- .../sql/catalyst}/json/JacksonParser.scala | 13 +++-- .../sql/catalyst}/json/JacksonUtils.scala | 4 +- .../catalyst/util}/CompressionCodecs.scala | 6 +- .../spark/sql/catalyst/util}/ParseModes.scala | 4 +- .../expressions/JsonExpressionsSuite.scala | 26 +++++++++ .../apache/spark/sql/DataFrameReader.scala | 5 +- .../datasources/csv/CSVFileFormat.scala | 1 + .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/json/InferSchema.scala | 3 +- .../datasources/json/JacksonGenerator.scala | 3 +- .../datasources/json/JsonFileFormat.scala | 2 + .../datasources/text/TextFileFormat.scala | 1 + .../org/apache/spark/sql/functions.scala | 58 +++++++++++++++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 29 ++++++++++ .../json/JsonParsingOptionsSuite.scala | 1 + .../datasources/json/JsonSuite.scala | 3 +- 19 files changed, 198 insertions(+), 23 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JSONOptions.scala (95%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonParser.scala (97%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonUtils.scala (92%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst/util}/CompressionCodecs.scala (93%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst/util}/ParseModes.scala (94%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 89b3c07c0740f..45d6bf944b702 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1706,6 +1706,29 @@ def json_tuple(col, *fields): return Column(jc) +@since(2.1) +def from_json(col, schema, options={}): + """ + Parses a column containing a JSON string into a [[StructType]] with the + specified schema. Returns `null`, in the case of an unparseable string. + + :param col: string column in json format + :param schema: a StructType to use when parsing the json column + :param options: options to control parsing. 
accepts the same options as the json datasource + + >>> from pyspark.sql.types import * + >>> data = [(1, '''{"a": 1}''')] + >>> schema = StructType([StructField("a", IntegerType())]) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index c14a2fb122618..65dbd6a4e3f1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -23,10 +23,12 @@ import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException} +import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -467,3 +469,28 @@ case class JsonTuple(children: Seq[Expression]) } } +/** + * Converts an json input string to a [[StructType]] with the specified schema. + */ +case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression) + extends Expression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val parser = + new JacksonParser( + schema, + "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`. + new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE))) + + override def dataType: DataType = schema + override def children: Seq[Expression] = child :: Nil + + override def eval(input: InternalRow): Any = { + try parser.parse(child.eval(input).toString).head catch { + case _: SparkSQLJsonProcessingException => null + } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 02d211d04265e..aec18922ea6c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -15,16 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes} /** - * Options for the JSON data source. + * Options for parsing JSON data into Spark SQL rows. * * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 5ce1bf7432159..f80e6373d2f89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream @@ -28,19 +28,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.ParseModes.{DROP_MALFORMED_MODE, PERMISSIVE_MODE} -import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) +private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) +/** + * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. 
+ */ class JacksonParser( schema: StructType, columnNameOfCorruptRecord: String, options: JSONOptions) extends Logging { + import JacksonUtils._ + import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -65,7 +68,7 @@ class JacksonParser( private def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present if (options.failFast) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: $record") + throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: $record") } if (options.dropMalformed) { if (!isWarningPrintedForMalformedRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 005546f37dda0..c4d9abb2c07e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -15,11 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} -private object JacksonUtils { +object JacksonUtils { /** * Advance the parser until a null or a specific token is found */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 41cff07472d1e..435fba9d8851c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec} +import org.apache.hadoop.io.compress._ import org.apache.spark.util.Utils -private[datasources] object CompressionCodecs { +object CompressionCodecs { private val shortCompressionCodecNames = Map( "none" -> null, "uncompressed" -> null, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala index 468228053c964..0e466962b4678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util -private[datasources] object ParseModes { +object ParseModes { val PERMISSIVE_MODE = "PERMISSIVE" val DROP_MALFORMED_MODE = "DROPMALFORMED" val FAIL_FAST_MODE = "FAILFAST" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7b754091f4714..84623934d95d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -317,4 +319,28 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } + + test("from_json") { + val jsonData = """{"a": 1}""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStruct(schema, Map.empty, Literal(jsonData)), + InternalRow.fromSeq(1 :: Nil) + ) + } + + test("from_json - invalid data") { + val jsonData = """{"a" 1}""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStruct(schema, Map.empty, Literal(jsonData)), + null + ) + + // Other modes should still return `null`. 
+ checkEvaluation( + JsonToStruct(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData)), + null + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b10d2c86ac5ef..b84fb2fb95914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,14 +21,15 @@ import java.util.Properties import scala.collection.JavaConverters._ -import org.apache.spark.Partition import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.datasources.json.InferSchema import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9610746a81ef7..4e662a52a7bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -29,6 +29,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e7dcc22272192..014614eb997a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util.{CompressionCodecs, ParseModes} private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 91c58d059d287..dc8bd817f2906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -23,7 +23,8 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil +import 
org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 270e7fbd3c137..5b55b701862b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -21,8 +21,9 @@ import java.io.Writer import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 6882a6cdcac26..9fe38ccc9fdc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -32,6 +32,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index a875b01ec2d7a..9f96667311015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47bf41a2da813..3bc1c5b90031d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try @@ -2818,6 +2819,63 @@ object functions { JsonTuple(json.expr +: fields.map(Literal.apply)) } + /** + * (Scala-specific) Parses a column 
containing a JSON string into a [[StructType]] with the + * specified schema. Returns `null`, in the case of an unparseable string. + * + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + JsonToStruct(schema, options, e.expr) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a [[StructType]] with the + * specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = + from_json(e, schema, options.asScala.toMap) + + /** + * Parses a column containing a JSON string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: StructType): Column = + from_json(e, schema, Map.empty[String, String]) + + /** + * Parses a column containing a JSON string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string as a json string + * + * @group collection_funcs + * @since 2.1.0 + */ + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = + from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) + /** * Returns length of array or map. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 1391c9d57ff7c..518d6e92b2ff7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.functions.from_json import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -94,4 +96,31 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(expr, expected) } + + test("json_parser") { + val df = Seq("""{"a": 1}""").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Row(1)) :: Nil) + } + + test("json_parser missing columns") { + val df = Seq("""{"a": 1}""").toDS() + val schema = new StructType().add("b", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Row(null)) :: Nil) + } + + test("json_parser invalid json") { + val df = Seq("""{"a" 1}""").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(null) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index c31dffedbdf67..0b72da5f3759c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.test.SharedSQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3d533c14e18e7..456052f79afcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -26,9 +26,10 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD +import org.apache.spark.SparkException import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType From 566d7f28275f90f7b9bed6a75e90989ad0c59931 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Sep 2016 14:30:23 -0700 Subject: [PATCH 159/306] [SPARK-17653][SQL] Remove unnecessary distincts in multiple unions ## What changes were proposed in this pull request? Currently for `Union [Distinct]`, a `Distinct` operator is necessary to be on the top of `Union`. 
Once there are adjacent `Union [Distinct]`, there will be multiple `Distinct` in the query plan. E.g., For a query like:

select 1 a union select 2 b union select 3 c

Before this patch, its physical plan looks like:

*HashAggregate(keys=[a#13], functions=[])
+- Exchange hashpartitioning(a#13, 200)
   +- *HashAggregate(keys=[a#13], functions=[])
      +- Union
         :- *HashAggregate(keys=[a#13], functions=[])
         :  +- Exchange hashpartitioning(a#13, 200)
         :     +- *HashAggregate(keys=[a#13], functions=[])
         :        +- Union
         :           :- *Project [1 AS a#13]
         :           :  +- Scan OneRowRelation[]
         :           +- *Project [2 AS b#14]
         :              +- Scan OneRowRelation[]
         +- *Project [3 AS c#15]
            +- Scan OneRowRelation[]

Only the top distinct should be necessary. After this patch, the physical plan looks like:

*HashAggregate(keys=[a#221], functions=[], output=[a#221])
+- Exchange hashpartitioning(a#221, 5)
   +- *HashAggregate(keys=[a#221], functions=[], output=[a#221])
      +- Union
         :- *Project [1 AS a#221]
         :  +- Scan OneRowRelation[]
         :- *Project [2 AS b#222]
         :  +- Scan OneRowRelation[]
         +- *Project [3 AS c#223]
            +- Scan OneRowRelation[]

## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #15238 from viirya/remove-extra-distinct-union. --- .../sql/catalyst/optimizer/Optimizer.scala | 24 ++++++- .../sql/catalyst/planning/patterns.scala | 27 -------- .../optimizer/SetOperationSuite.scala | 68 +++++++++++++++++++ 3 files changed, 89 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4952ba3b2b99d..9df8ce1fa3b28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.java.function.FilterFunction @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe * Combines all adjacent [[Union]] operators into a single [[Union]].
*/ object CombineUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Unions(children) => Union(children) + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case u: Union => flattenUnion(u, false) + case Distinct(u: Union) => Distinct(flattenUnion(u, true)) + } + + private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = { + val stack = mutable.Stack[LogicalPlan](union) + val flattened = mutable.ArrayBuffer.empty[LogicalPlan] + while (stack.nonEmpty) { + stack.pop() match { + case Distinct(Union(children)) if flattenDistinct => + stack.pushAll(children.reverse) + case Union(children) => + stack.pushAll(children.reverse) + case child => + flattened += child + } + } + Union(flattened) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 41cabb8cb3390..bdae56881bf46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } } - -/** - * A pattern that collects all adjacent unions and returns their children as a Seq. - */ -object Unions { - def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) - case _ => None - } - - // Doing a depth-first tree traversal to combine all the union children. - @tailrec - private def collectUnionChildren( - plans: mutable.Stack[LogicalPlan], - children: Seq[LogicalPlan]): Seq[LogicalPlan] = { - if (plans.isEmpty) children - else { - plans.pop match { - case Union(grandchildren) => - grandchildren.reverseMap(plans.push(_)) - collectUnionChildren(plans, children) - case other => collectUnionChildren(plans, children :+ other) - } - } - } -} - /** * An extractor used when planning the physical execution of an aggregation. 
Compared with a logical * aggregation, the following transformations are performed: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 7227706ab2b36..21b7f49e14bd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest { testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } + + test("Remove unnecessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + + // D - U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze + val optimized1 = Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // D - U - U - query2 + // | + // D - U - query2 + // | + // query3 + val unionQuery2 = Distinct(Union(Union(query1, query2), + Distinct(Union(query2, query3)))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } + + test("Keep necessary distincts in multiple unions") { + val query1 = OneRowRelation + .select(Literal(1).as('a)) + val query2 = OneRowRelation + .select(Literal(2).as('b)) + val query3 = OneRowRelation + .select(Literal(3).as('c)) + val query4 = OneRowRelation + .select(Literal(4).as('d)) + + // U - D - U - query1 + // | | + // query3 query2 + val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze + val optimized1 = Optimize.execute(unionQuery1) + val distinctUnionCorrectAnswer1 = + Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze + comparePlans(distinctUnionCorrectAnswer1, optimized1) + + // query1 + // | + // U - D - U - query2 + // | + // D - U - query3 + // | + // query4 + val unionQuery2 = + Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze + val optimized2 = Optimize.execute(unionQuery2) + val distinctUnionCorrectAnswer2 = + Union(Distinct(Union(query1 :: query2 :: Nil)), + Distinct(Union(query3 :: query4 :: Nil))).analyze + comparePlans(distinctUnionCorrectAnswer2, optimized2) + } } From 4ecc648ad713f9d618adf0406b5d39981779059d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 29 Sep 2016 15:30:18 -0700 Subject: [PATCH 160/306] [SPARK-17612][SQL] Support `DESCRIBE table PARTITION` SQL syntax ## What changes were proposed in this pull request? This PR implements `DESCRIBE table PARTITION` SQL Syntax again. 
It was supported until Spark 1.6.2, but was dropped since 2.0.0. **Spark 1.6.2** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res1: org.apache.spark.sql.DataFrame = [result: string] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res2: org.apache.spark.sql.DataFrame = [result: string] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) +----------------------------------------------------------------+ |result | +----------------------------------------------------------------+ |a string | |b int | |c string | |d string | | | |# Partition Information | |# col_name data_type comment | | | |c string | |d string | +----------------------------------------------------------------+ ``` **Spark 2.0** - **Before** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res1: org.apache.spark.sql.DataFrame = [] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) org.apache.spark.sql.catalyst.parser.ParseException: Unsupported SQL statement ``` - **After** ```scala scala> sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") res1: org.apache.spark.sql.DataFrame = [] scala> sql("DESC partitioned_table PARTITION (c='Us', d=1)").show(false) +-----------------------+---------+-------+ |col_name |data_type|comment| +-----------------------+---------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information| | | |# col_name |data_type|comment| |c |string |null | |d |string |null | +-----------------------+---------+-------+ scala> sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)").show(100,false) +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ |col_name |data_type|comment| +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information | | | |# col_name |data_type|comment| |c |string |null | |d |string |null | | | | | |Detailed Partition Information CatalogPartition( Partition Values: [Us, 1] Storage(Location: file:/Users/dhyun/SPARK-17612-DESC-PARTITION/spark-warehouse/partitioned_table/c=Us/d=1, InputFormat: org.apache.hadoop.mapred.TextInputFormat, OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, Serde: 
org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Properties: [serialization.format=1]) Partition Parameters:{transient_lastDdlTime=1475001066})| | | +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------+-------+ scala> sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)").show(100,false) +--------------------------------+---------------------------------------------------------------------------------------+-------+ |col_name |data_type |comment| +--------------------------------+---------------------------------------------------------------------------------------+-------+ |a |string |null | |b |int |null | |c |string |null | |d |string |null | |# Partition Information | | | |# col_name |data_type |comment| |c |string |null | |d |string |null | | | | | |# Detailed Partition Information| | | |Partition Value: |[Us, 1] | | |Database: |default | | |Table: |partitioned_table | | |Location: |file:/Users/dhyun/SPARK-17612-DESC-PARTITION/spark-warehouse/partitioned_table/c=Us/d=1| | |Partition Parameters: | | | | transient_lastDdlTime |1475001066 | | | | | | |# Storage Information | | | |SerDe Library: |org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe | | |InputFormat: |org.apache.hadoop.mapred.TextInputFormat | | |OutputFormat: |org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat | | |Compressed: |No | | |Storage Desc Parameters: | | | | serialization.format |1 | | +--------------------------------+---------------------------------------------------------------------------------------+-------+ ``` ## How was this patch tested? Pass the Jenkins tests with a new testcase. Author: Dongjoon Hyun Closes #15168 from dongjoon-hyun/SPARK-17612. 
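The parser change in the diff below hinges on one small step: `visitPartitionSpec` yields a `Map[String, Option[String]]`, and `DESC ... PARTITION` only accepts specs in which every key carries a value. A minimal standalone sketch of that check follows (illustrative only; the object and exception names are made up here, while the real code lives in `SparkSqlAstBuilder` and raises a `ParseException`):

```scala
// Simplified sketch, assuming partition values arrive as Option[String]:
// a spec is only usable when every partition column has a value.
object PartitionSpecCheck {
  def toCompleteSpec(raw: Map[String, Option[String]]): Map[String, String] =
    raw.map {
      case (key, Some(value)) => key -> value
      case (key, None) =>
        throw new IllegalArgumentException(s"PARTITION specification is incomplete: `$key`")
    }

  def main(args: Array[String]): Unit = {
    // Complete spec, e.g. DESC t PARTITION (c='Us', d=1)
    println(toCompleteSpec(Map("c" -> Some("Us"), "d" -> Some("1"))))
    // Incomplete spec, e.g. DESC t PARTITION (c='Us', d) -- would throw
    // toCompleteSpec(Map("c" -> Some("Us"), "d" -> None))
  }
}
```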
--- .../sql/catalyst/catalog/interface.scala | 13 ++- .../spark/sql/execution/SparkSqlParser.scala | 15 +++- .../spark/sql/execution/command/tables.scala | 83 ++++++++++++++--- .../resources/sql-tests/inputs/describe.sql | 27 ++++++ .../sql-tests/results/describe.sql.out | 90 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 77 +++++++++++++++- 6 files changed, 287 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e52251f960ff4..51326ca25e9cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -86,7 +86,18 @@ object CatalogStorageFormat { case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, - parameters: Map[String, String] = Map.empty) + parameters: Map[String, String] = Map.empty) { + + override def toString: String = { + val output = + Seq( + s"Partition Values: [${spec.values.mkString(", ")}]", + s"$storage", + s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + + output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + } +} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 5359cedc80974..3f34d0f25393d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -276,13 +276,24 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a [[DescribeTableCommand]] logical plan. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { - // Describe partition and column are not supported yet. Return null and let the parser decide + // Describe column are not supported yet. Return null and let the parser decide // what to do with this (create an exception or pass it on to a different system). - if (ctx.describeColName != null || ctx.partitionSpec != null) { + if (ctx.describeColName != null) { null } else { + val partitionSpec = if (ctx.partitionSpec != null) { + // According to the syntax, visitPartitionSpec returns `Map[String, Option[String]]`. 
+ visitPartitionSpec(ctx.partitionSpec).map { + case (key, Some(value)) => key -> value + case (key, _) => + throw new ParseException(s"PARTITION specification is incomplete: `$key`", ctx) + } + } else { + Map.empty[String, String] + } DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), + partitionSpec, ctx.EXTENDED != null, ctx.FORMATTED != null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 6a91c997bac63..08de6cd4242c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -390,10 +390,14 @@ case class TruncateTableCommand( /** * Command that looks like * {{{ - * DESCRIBE [EXTENDED|FORMATTED] table_name; + * DESCRIBE [EXTENDED|FORMATTED] table_name partitionSpec?; * }}} */ -case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isFormatted: Boolean) +case class DescribeTableCommand( + table: TableIdentifier, + partitionSpec: TablePartitionSpec, + isExtended: Boolean, + isFormatted: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -411,17 +415,25 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF val catalog = sparkSession.sessionState.catalog if (catalog.isTemporaryTable(table)) { + if (partitionSpec.nonEmpty) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") + } describeSchema(catalog.lookupRelation(table).schema, result) } else { val metadata = catalog.getTableMetadata(table) describeSchema(metadata.schema, result) - if (isExtended) { - describeExtended(metadata, result) - } else if (isFormatted) { - describeFormatted(metadata, result) + describePartitionInfo(metadata, result) + + if (partitionSpec.isEmpty) { + if (isExtended) { + describeExtendedTableInfo(metadata, result) + } else if (isFormatted) { + describeFormattedTableInfo(metadata, result) + } } else { - describePartitionInfo(metadata, result) + describeDetailedPartitionInfo(catalog, metadata, result) } } @@ -436,16 +448,12 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } - private def describeExtended(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - describePartitionInfo(table, buffer) - + private def describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Table Information", table.toString, "") } - private def describeFormatted(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - describePartitionInfo(table, buffer) - + private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Table Information", "", "") append(buffer, "Database:", table.database, "") @@ 
-499,6 +507,53 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF } } + private def describeDetailedPartitionInfo( + catalog: SessionCatalog, + metadata: CatalogTable, + result: ArrayBuffer[Row]): Unit = { + if (metadata.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a view: ${table.identifier}") + } + if (DDLUtils.isDatasourceTable(metadata)) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a datasource table: ${table.identifier}") + } + val partition = catalog.getPartition(table, partitionSpec) + if (isExtended) { + describeExtendedDetailedPartitionInfo(table, metadata, partition, result) + } else if (isFormatted) { + describeFormattedDetailedPartitionInfo(table, metadata, partition, result) + describeStorageInfo(metadata, result) + } + } + + private def describeExtendedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "Detailed Partition Information " + partition.toString, "", "") + } + + private def describeFormattedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Detailed Partition Information", "", "") + append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") + append(buffer, "Database:", table.database, "") + append(buffer, "Table:", tableIdentifier.table, "") + append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") + append(buffer, "Partition Parameters:", "", "") + partition.parameters.foreach { case (key, value) => + append(buffer, s" $key", value, "") + } + } + private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql new file mode 100644 index 0000000000000..3f0ae902e0529 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -0,0 +1,27 @@ +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING); + +ALTER TABLE t ADD PARTITION (c='Us', d=1); + +DESC t; + +-- Ignore these because there exist timestamp results, e.g., `Create Table`. +-- DESC EXTENDED t; +-- DESC FORMATTED t; + +DESC t PARTITION (c='Us', d=1); + +-- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. 
+-- DESC EXTENDED t PARTITION (c='Us', d=1); +-- DESC FORMATTED t PARTITION (c='Us', d=1); + +-- NoSuchPartitionException: Partition not found in table +DESC t PARTITION (c='Us', d=2); + +-- AnalysisException: Partition spec is invalid +DESC t PARTITION (c='Us'); + +-- ParseException: PARTITION specification is incomplete +DESC t PARTITION (c='Us', d); + +-- DROP TEST TABLE +DROP TABLE t; diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out new file mode 100644 index 0000000000000..37bf303f1bfe4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +DESC t +-- !query 2 schema +struct +-- !query 2 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 3 +DESC t PARTITION (c='Us', d=1) +-- !query 3 schema +struct +-- !query 3 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 4 +DESC t PARTITION (c='Us', d=2) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 't' database 'default': +c -> Us +d -> 2; + + +-- !query 5 +DESC t PARTITION (c='Us') +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; + + +-- !query 6 +DESC t PARTITION (c='Us', d) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.catalyst.parser.ParseException + +PARTITION specification is incomplete: `d`(line 1, pos 0) + +== SQL == +DESC t PARTITION (c='Us', d) +^^^ + + +-- !query 7 +DROP TABLE t +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index dc4d099f0f666..6c77a0deb52a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -341,6 +341,81 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("describe partition") { + withTable("partitioned_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name") + + checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "Detailed Partition Information CatalogPartition(", + "Partition Values: [Us, 1]", + "Storage(Location:", + "Partition Parameters") + + checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "# Detailed Partition Information", + "Partition Value:", + "Database:", + "Table:", + "Location:", + "Partition Parameters:", + "# Storage Information") + } + } + + test("describe partition - error handling") { + withTable("partitioned_table", "datasource_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + val m = intercept[NoSuchPartitionException] { + sql("DESC partitioned_table PARTITION (c='Us', d=2)") + }.getMessage() + assert(m.contains("Partition not found in table")) + + val m2 = intercept[AnalysisException] { + sql("DESC partitioned_table PARTITION (c='Us')") + }.getMessage() + assert(m2.contains("Partition spec is invalid")) + + val m3 = intercept[ParseException] { + sql("DESC partitioned_table PARTITION (c='Us', d)") + }.getMessage() + assert(m3.contains("PARTITION specification is incomplete: `d`")) + + spark + .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write + .partitionBy("d") + .saveAsTable("datasource_table") + val m4 = intercept[AnalysisException] { + sql("DESC datasource_table PARTITION (d=2)") + }.getMessage() + assert(m4.contains("DESC PARTITION is not allowed on a datasource table")) + + val m5 = intercept[AnalysisException] { + 
spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") + sql("DESC view1 PARTITION (c='Us', d=1)") + }.getMessage() + assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) + + withView("permanent_view") { + val m = intercept[AnalysisException] { + sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") + sql("DESC permanent_view PARTITION (c='Us', d=1)") + }.getMessage() + assert(m.contains("DESC PARTITION is not allowed on a view")) + } + } + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") From 29396e7d1483d027960b9a1bed47008775c4253e Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 15:39:57 -0700 Subject: [PATCH 161/306] [SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector ## What changes were proposed in this pull request? * changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical) * adds a test that was failing before this change, but succeeds with these changes. The problem in the previous implementation was that it only increments `i`, that is enumerating the columns of a row in the SparseMatrix, when the row-index of the vector matches the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i` and the remaining column-indices i+1,...,indEnd-1 are never attempted. The test cases in this PR illustrate this issue. ## How was this patch tested? I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`. ## ___ As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license. Mentioning dbtsai, viirya and brkyvz whom I can see have worked/authored on these parts before. Author: Bjarne Fruergaard Closes #15296 from bwahlgreen/bugfix-spark-17721. 
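To make the indexing argument above concrete, here is a small standalone sketch (not the Spark `BLAS.gemv` code itself; the function and array names are illustrative) of the corrected merge over one CSR matrix row and the sparse vector: both index arrays are sorted, and whichever pointer refers to the smaller index is advanced, so a non-zero at a later column can no longer be skipped the way the old advance-`k`-every-iteration loop could skip it.

```scala
// Two-pointer dot product between one matrix row and a sparse vector.
// Assumes aCols/aVals hold the row's non-zeros (sorted column indices)
// and xIdx/xVals hold the vector's non-zeros (sorted row indices).
def sparseRowDot(
    aCols: Array[Int], aVals: Array[Double],
    xIdx: Array[Int], xVals: Array[Double]): Double = {
  var i = 0   // position in the matrix row
  var k = 0   // position in the sparse vector
  var sum = 0.0
  while (i < aCols.length && k < xIdx.length) {
    if (aCols(i) == xIdx(k)) {        // indices match: accumulate, advance both
      sum += aVals(i) * xVals(k)
      i += 1
      k += 1
    } else if (xIdx(k) < aCols(i)) {  // vector index lags: advance only k
      k += 1
    } else {                          // matrix column lags: advance only i
      i += 1
    }
  }
  sum
}
```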
--- .../scala/org/apache/spark/ml/linalg/BLAS.scala | 8 ++++++-- .../org/apache/spark/ml/linalg/BLASSuite.scala | 17 +++++++++++++++++ .../org/apache/spark/mllib/linalg/BLAS.scala | 8 ++++++-- .../apache/spark/mllib/linalg/BLASSuite.scala | 17 +++++++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 41b0c6c89a647..4ca19f3387f07 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 8a9f49792c1cd..6e72a5fff0a91 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 6a85608706974..0cd68a633c0b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 80da03cc2efeb..6e68c1c9d36c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), 
Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = From 3993ebca23afa4b8770695051635933a6c9d2c11 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 29 Sep 2016 15:40:35 -0700 Subject: [PATCH 162/306] [SPARK-17676][CORE] FsHistoryProvider should ignore hidden files ## What changes were proposed in this pull request? FsHistoryProvider was writing a hidden file (to check the fs's clock). Even though it deleted the file immediately, sometimes another thread would try to scan the files on the fs in-between, and then there would be an error msg logged which was very misleading for the end-user. (The logged error was harmless, though.) ## How was this patch tested? I added one unit test, but to be clear, that test was passing before. The actual change in behavior in that test is just logging (after the change, there is no more logged error), which I just manually verified. Author: Imran Rashid Closes #15250 from squito/SPARK-17676. --- .../deploy/history/FsHistoryProvider.scala | 7 +++- .../history/FsHistoryProviderSuite.scala | 36 +++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d494ff0659bd2..c5740e4737094 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -294,7 +294,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && prevFileSize < entry.getLen() + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 39c5857b13451..01bef0a11c124 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy.history -import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, - FileOutputStream, OutputStreamWriter} +import java.io._ import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -394,6 +393,39 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("ignore hidden files") { + + // FsHistoryProvider should ignore hidden files. (It even writes out a hidden file itself + // that should be ignored). 
+ + // write out one totally bogus hidden file + val hiddenGarbageFile = new File(testDir, ".garbage") + val out = new PrintWriter(hiddenGarbageFile) + // scalastyle:off println + out.println("GARBAGE") + // scalastyle:on println + out.close() + + // also write out one real event log file, but since its a hidden file, we shouldn't read it + val tmpNewAppFile = newLogFile("hidden", None, inProgress = false) + val hiddenNewAppFile = new File(tmpNewAppFile.getParentFile, "." + tmpNewAppFile.getName) + tmpNewAppFile.renameTo(hiddenNewAppFile) + + // and write one real file, which should still get picked up just fine + val newAppComplete = newLogFile("real-app", None, inProgress = false) + writeFile(newAppComplete, true, None, + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), + SparkListenerApplicationEnd(5L) + ) + + val provider = new FsHistoryProvider(createTestConf()) + updateAndCheck(provider) { list => + list.size should be (1) + list(0).name should be ("real-app") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: From 39eb3bb1ec29aa993de13a6eba3ab27db6fc5371 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 29 Sep 2016 16:01:45 -0700 Subject: [PATCH 163/306] [SPARK-17412][DOC] All test should not be run by `root` or any admin user ## What changes were proposed in this pull request? `FsHistoryProviderSuite` fails if `root` user runs it. The test case **SPARK-3697: ignore directories that cannot be read** depends on `setReadable(false, false)` to make test data files and expects the number of accessible files is 1. But, `root` can access all files, so it returns 2. This PR adds the assumption explicitly on doc. `building-spark.md`. ## How was this patch tested? This is a documentation change. Author: Dongjoon Hyun Closes #15291 from dongjoon-hyun/SPARK-17412. --- docs/building-spark.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index 75c304a3ccecd..da7eeb8348378 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,6 +215,7 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub # Running Tests Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). +Note that tests should not be run as root or an admin user. Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: From 2f739567080d804a942cfcca0e22f91ab7cbea36 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Sep 2016 16:31:30 -0700 Subject: [PATCH 164/306] [SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting ## What changes were proposed in this pull request? In calling LogisticRegression.evaluate and GeneralizedLinearRegression.evaluate using a Dataset where the Label is not of a double type, calculations pattern match against a double and throw a MatchError. This fix casts the Label column to a DoubleType to ensure there is no MatchError. ## How was this patch tested? Added unit tests to call evaluate with a dataset that has Label as other numeric types. Author: Bryan Cutler Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697. 
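As a small illustration of the fix described above (a hedged sketch, not the ML summary code itself; `labelsAsDoubles` and its arguments are invented names): casting the label column to `DoubleType` before the `Row` pattern match means labels stored as `IntegerType`, `LongType`, or `FloatType` no longer fall through a `case Row(label: Double, ...)` clause with a `MatchError`.

```scala
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Without the cast, a label column of e.g. LongType would not match `label: Double`
// and the map below would throw a MatchError; with the cast every numeric label
// arrives as a Double regardless of its original type.
def labelsAsDoubles(predictions: DataFrame, labelCol: String): Array[Double] =
  predictions.select(col(labelCol).cast(DoubleType)).rdd.map {
    case Row(label: Double) => label
  }.collect()
```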
--- .../classification/LogisticRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 11 ++++---- .../LogisticRegressionSuite.scala | 18 ++++++++++++- .../GeneralizedLinearRegressionSuite.scala | 25 +++++++++++++++++++ 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5ab63d1de95d3..329961a25d984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 02b27fb650979..bb9e150c49772 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } else { link.unlink(0.0) } - predictions.select(col(model.getLabelCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { case Row(y: Double, weight: Double) => family.deviance(y, wtdmu, weight) }.sum() @@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") lazy val deviance: Double = { val w = weightCol - predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] ( lazy val aic: Double = { val w = weightCol val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) - val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - (label, pred, weight) + val t = predictions.select( + col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, weight) } family.aic(t, deviance, numInstances, weightSum) + 2 * rank } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8451e60144981..42b56754e0835 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.LongType class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1776,6 +1777,21 @@ class LogisticRegressionSuite summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) } + test("evaluate with labels that are not doubles") { + // Evaluate a test set with Label that is a numeric type other than Double + val lr = new LogisticRegression() + .setMaxIter(1) + .setRegParam(1.0) + val model = lr.fit(smallBinaryDataset) + val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] + + val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType), + col(model.getFeaturesCol)) + val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary] + + assert(summary.areaUnderROC === longSummary.areaUnderROC) + } + test("statistics on training data") { // Test that loss is monotonically decreasing. val lr = new LogisticRegression() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 937aa7d3c2045..ac1ef5feb95ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.FloatType class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite idx += 1 } } + + test("evaluate with labels that are not doubles") { + // Evaulate with a dataset that contains Labels not as doubles to verify correct casting + val dataset = Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 1.0, Vectors.dense(3.0, 13.0)) + ).toDF() + + val trainer = new GeneralizedLinearRegression() + .setMaxIter(1) + val model = trainer.fit(dataset) + assert(model.hasSummary) + val summary = model.summary + + val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType), + col(model.getFeaturesCol)) + val evalSummary = model.evaluate(longLabelDataset) + // The calculations below involve pattern matching with Label as a double + assert(evalSummary.nullDeviance === summary.nullDeviance) + assert(evalSummary.deviance === summary.deviance) + assert(evalSummary.aic === summary.aic) + } } object GeneralizedLinearRegressionSuite { From 74ac1c43817c0b8da70342e540ec7638dd7d01bd Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 29 Sep 2016 17:56:32 -0700 Subject: [PATCH 165/306] [SPARK-17717][SQL] Add exist/find methods to Catalog. ## What changes were proposed in this pull request? The current user facing catalog does not implement methods for checking object existence or finding objects. 
You could theoretically do this using the `list*` commands, but this is rather cumbersome and can actually be costly when there are many objects. This PR adds `exists*` and `find*` methods for Databases, Table and Functions. ## How was this patch tested? Added tests to `org.apache.spark.sql.internal.CatalogSuite` Author: Herman van Hovell Closes #15301 from hvanhovell/SPARK-17717. --- project/MimaExcludes.scala | 11 +- .../apache/spark/sql/catalog/Catalog.scala | 83 ++++++++++ .../spark/sql/internal/CatalogImpl.scala | 152 +++++++++++++++--- .../spark/sql/internal/CatalogSuite.scala | 118 ++++++++++++++ 4 files changed, 339 insertions(+), 25 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4db3edb733a56..2ffe0ac9bc982 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -46,7 +46,16 @@ object MimaExcludes { // [SPARK-16967] Move Mesos to Module ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), // [SPARK-16240] ML persistence backward compatibility for LDA - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), + // [SPARK-17717] Add Find and Exists method to Catalog. + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findDatabase"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findTable"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findFunction"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.columnExists") ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 1aed245fdd332..b439022d227cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -101,6 +101,89 @@ abstract class Catalog { @throws[AnalysisException]("database or table does not exist") def listColumns(dbName: String, tableName: String): Dataset[Column] + /** + * Find the database with the specified name. This throws an AnalysisException when the database + * cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database does not exist") + def findDatabase(dbName: String): Database + + /** + * Find the table with the specified name. This table can be a temporary table or a table in the + * current database. This throws an AnalysisException when the table cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("table does not exist") + def findTable(tableName: String): Table + + /** + * Find the table with the specified name in the specified database. This throws an + * AnalysisException when the table cannot be found. 
+ * + * @since 2.1.0 + */ + @throws[AnalysisException]("database or table does not exist") + def findTable(dbName: String, tableName: String): Table + + /** + * Find the function with the specified name. This function can be a temporary function or a + * function in the current database. This throws an AnalysisException when the function cannot + * be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("function does not exist") + def findFunction(functionName: String): Function + + /** + * Find the function with the specified name. This throws an AnalysisException when the function + * cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database or function does not exist") + def findFunction(dbName: String, functionName: String): Function + + /** + * Check if the database with the specified name exists. + * + * @since 2.1.0 + */ + def databaseExists(dbName: String): Boolean + + /** + * Check if the table with the specified name exists. This can either be a temporary table or a + * table in the current database. + * + * @since 2.1.0 + */ + def tableExists(tableName: String): Boolean + + /** + * Check if the table with the specified name exists in the specified database. + * + * @since 2.1.0 + */ + def tableExists(dbName: String, tableName: String): Boolean + + /** + * Check if the function with the specified name exists. This can either be a temporary function + * or a function in the current database. + * + * @since 2.1.0 + */ + def functionExists(functionName: String): Boolean + + /** + * Check if the function with the specified name exists in the specified database. + * + * @since 2.1.0 + */ + def functionExists(dbName: String, functionName: String): Boolean + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. 
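Before the `CatalogImpl` diff that follows, a short usage sketch of the new lookup and existence methods declared above (editorial illustration only, not part of the patch; it assumes a `SparkSession` named `spark` and a hypothetical table name):

```
// Create a table only if it is not already there, then look it up.
if (!spark.catalog.tableExists("sales")) {
  spark.range(100).write.saveAsTable("sales")
}
val table = spark.catalog.findTable("sales")   // throws AnalysisException if it cannot be found
println(s"${table.name}: tableType=${table.tableType}, temporary=${table.isTemporary}")

assert(spark.catalog.databaseExists("default"))
assert(!spark.catalog.functionExists("no_such_fn"))
```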
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index f252535765899..a1087edd03fdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -23,10 +23,10 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} -import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.types.StructType @@ -69,15 +69,18 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def listDatabases(): Dataset[Database] = { val databases = sessionCatalog.listDatabases().map { dbName => - val metadata = sessionCatalog.getDatabaseMetadata(dbName) - new Database( - name = metadata.name, - description = metadata.description, - locationUri = metadata.locationUri) + makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) } CatalogImpl.makeDataset(databases, sparkSession) } + private def makeDatabase(metadata: CatalogDatabase): Database = { + new Database( + name = metadata.name, + description = metadata.description, + locationUri = metadata.locationUri) + } + /** * Returns a list of tables in the current database. * This includes all temporary tables. @@ -94,18 +97,21 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { override def listTables(dbName: String): Dataset[Table] = { requireDatabaseExists(dbName) val tables = sessionCatalog.listTables(dbName).map { tableIdent => - val isTemp = tableIdent.database.isEmpty - val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) - new Table( - name = tableIdent.identifier, - database = metadata.flatMap(_.identifier.database).orNull, - description = metadata.flatMap(_.comment).orNull, - tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), - isTemporary = isTemp) + makeTable(tableIdent, tableIdent.database.isEmpty) } CatalogImpl.makeDataset(tables, sparkSession) } + private def makeTable(tableIdent: TableIdentifier, isTemp: Boolean): Table = { + val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) + new Table( + name = tableIdent.identifier, + database = metadata.flatMap(_.identifier.database).orNull, + description = metadata.flatMap(_.comment).orNull, + tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), + isTemporary = isTemp) + } + /** * Returns a list of functions registered in the current database. 
* This includes all temporary functions @@ -121,18 +127,22 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database does not exist") override def listFunctions(dbName: String): Dataset[Function] = { requireDatabaseExists(dbName) - val functions = sessionCatalog.listFunctions(dbName).map { case (funcIdent, _) => - val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) - new Function( - name = funcIdent.identifier, - database = funcIdent.database.orNull, - description = null, // for now, this is always undefined - className = metadata.getClassName, - isTemporary = funcIdent.database.isEmpty) + val functions = sessionCatalog.listFunctions(dbName).map { case (functIdent, _) => + makeFunction(functIdent) } CatalogImpl.makeDataset(functions, sparkSession) } + private def makeFunction(funcIdent: FunctionIdentifier): Function = { + val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) + new Function( + name = funcIdent.identifier, + database = funcIdent.database.orNull, + description = null, // for now, this is always undefined + className = metadata.getClassName, + isTemporary = funcIdent.database.isEmpty) + } + /** * Returns a list of columns for the given table in the current database. */ @@ -167,6 +177,100 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(columns, sparkSession) } + /** + * Find the database with the specified name. This throws an [[AnalysisException]] when no + * [[Database]] can be found. + */ + override def findDatabase(dbName: String): Database = { + if (sessionCatalog.databaseExists(dbName)) { + makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) + } else { + throw new AnalysisException(s"The specified database $dbName does not exist.") + } + } + + /** + * Find the table with the specified name. This table can be a temporary table or a table in the + * current database. This throws an [[AnalysisException]] when no [[Table]] can be found. + */ + override def findTable(tableName: String): Table = { + findTable(null, tableName) + } + + /** + * Find the table with the specified name in the specified database. This throws an + * [[AnalysisException]] when no [[Table]] can be found. + */ + override def findTable(dbName: String, tableName: String): Table = { + val tableIdent = TableIdentifier(tableName, Option(dbName)) + val isTemporary = sessionCatalog.isTemporaryTable(tableIdent) + if (isTemporary || sessionCatalog.tableExists(tableIdent)) { + makeTable(tableIdent, isTemporary) + } else { + throw new AnalysisException(s"The specified table $tableIdent does not exist.") + } + } + + /** + * Find the function with the specified name. This function can be a temporary function or a + * function in the current database. This throws an [[AnalysisException]] when no [[Function]] + * can be found. + */ + override def findFunction(functionName: String): Function = { + findFunction(null, functionName) + } + + /** + * Find the function with the specified name. This throws an [[AnalysisException]] when no + * [[Function]] can be found. + */ + override def findFunction(dbName: String, functionName: String): Function = { + val functionIdent = FunctionIdentifier(functionName, Option(dbName)) + if (sessionCatalog.functionExists(functionIdent)) { + makeFunction(functionIdent) + } else { + throw new AnalysisException(s"The specified function $functionIdent does not exist.") + } + } + + /** + * Check if the database with the specified name exists.
+ */ + override def databaseExists(dbName: String): Boolean = { + sessionCatalog.databaseExists(dbName) + } + + /** + * Check if the table with the specified name exists. This can either be a temporary table or a + * table in the current database. + */ + override def tableExists(tableName: String): Boolean = { + tableExists(null, tableName) + } + + /** + * Check if the table with the specified name exists in the specified database. + */ + override def tableExists(dbName: String, tableName: String): Boolean = { + val tableIdent = TableIdentifier(tableName, Option(dbName)) + sessionCatalog.isTemporaryTable(tableIdent) || sessionCatalog.tableExists(tableIdent) + } + + /** + * Check if the function with the specified name exists. This can either be a temporary function + * or a function in the current database. + */ + override def functionExists(functionName: String): Boolean = { + functionExists(null, functionName) + } + + /** + * Check if the function with the specified name exists in the specified database. + */ + override def functionExists(dbName: String, functionName: String): Boolean = { + sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) + } + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 3dc67ffafb048..783bf77f86b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -340,6 +340,124 @@ class CatalogSuite } } + test("find database") { + intercept[AnalysisException](spark.catalog.findDatabase("db10")) + withTempDatabase { db => + assert(spark.catalog.findDatabase(db).name === db) + } + } + + test("find table") { + withTempDatabase { db => + withTable(s"tbl_x", s"$db.tbl_y") { + // Try to find non existing tables. + intercept[AnalysisException](spark.catalog.findTable("tbl_x")) + intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + intercept[AnalysisException](spark.catalog.findTable(db, "tbl_y")) + + // Create objects. + createTempTable("tbl_x") + createTable("tbl_y", Some(db)) + + // Find a temporary table + assert(spark.catalog.findTable("tbl_x").name === "tbl_x") + + // Find a qualified table + assert(spark.catalog.findTable(db, "tbl_y").name === "tbl_y") + + // Find an unqualified table using the current database + intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.findTable("tbl_y").name === "tbl_y") + } + } + } + + test("find function") { + withTempDatabase { db => + withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { + // Try to find non existing functions. + intercept[AnalysisException](spark.catalog.findFunction("fn1")) + intercept[AnalysisException](spark.catalog.findFunction("fn2")) + intercept[AnalysisException](spark.catalog.findFunction(db, "fn2")) + + // Create objects. 
+ createTempFunction("fn1") + createFunction("fn2", Some(db)) + + // Find a temporary function + assert(spark.catalog.findFunction("fn1").name === "fn1") + + // Find a qualified function + assert(spark.catalog.findFunction(db, "fn2").name === "fn2") + + // Find an unqualified function using the current database + intercept[AnalysisException](spark.catalog.findFunction("fn2")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.findFunction("fn2").name === "fn2") + } + } + } + + test("database exists") { + assert(!spark.catalog.databaseExists("db10")) + createDatabase("db10") + assert(spark.catalog.databaseExists("db10")) + dropDatabase("db10") + } + + test("table exists") { + withTempDatabase { db => + withTable(s"tbl_x", s"$db.tbl_y") { + // Try to find non existing tables. + assert(!spark.catalog.tableExists("tbl_x")) + assert(!spark.catalog.tableExists("tbl_y")) + assert(!spark.catalog.tableExists(db, "tbl_y")) + + // Create objects. + createTempTable("tbl_x") + createTable("tbl_y", Some(db)) + + // Find a temporary table + assert(spark.catalog.tableExists("tbl_x")) + + // Find a qualified table + assert(spark.catalog.tableExists(db, "tbl_y")) + + // Find an unqualified table using the current database + assert(!spark.catalog.tableExists("tbl_y")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.tableExists("tbl_y")) + } + } + } + + test("function exists") { + withTempDatabase { db => + withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { + // Try to find non existing functions. + assert(!spark.catalog.functionExists("fn1")) + assert(!spark.catalog.functionExists("fn2")) + assert(!spark.catalog.functionExists(db, "fn2")) + + // Create objects. + createTempFunction("fn1") + createFunction("fn2", Some(db)) + + // Find a temporary function + assert(spark.catalog.functionExists("fn1")) + + // Find a qualified function + assert(spark.catalog.functionExists(db, "fn2")) + + // Find an unqualified function using the current database + assert(!spark.catalog.functionExists("fn2")) + spark.catalog.setCurrentDatabase(db) + assert(spark.catalog.functionExists("fn2")) + } + } + } + // TODO: add tests for the rest of them } From 1fad5596885aab8b32d2307c0edecbae50d5bd7a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 29 Sep 2016 23:55:42 -0700 Subject: [PATCH 166/306] [SPARK-14077][ML] Refactor NaiveBayes to support weighted instances ## What changes were proposed in this pull request? 1. Support weighted data. 2. Use Dataset/DataFrame instead of RDD. 3. Make the MLlib implementation a wrapper that calls the ML one. ## How was this patch tested? Local manual tests in spark-shell, plus unit tests. Author: Zheng RuiFeng Closes #12819 from zhengruifeng/weighted_nb.
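As a usage sketch of the new weighting support (editorial illustration only, not part of the patch; the training DataFrame `weightedData` with `label`, `weight`, and `features` columns is assumed):

```
import org.apache.spark.ml.classification.NaiveBayes

// Instances with weight w contribute to the sufficient statistics as if they
// occurred w times; omitting setWeightCol keeps the old behavior (all weights 1.0).
val nb = new NaiveBayes()
  .setModelType("multinomial")
  .setSmoothing(1.0)
  .setWeightCol("weight")

val model = nb.fit(weightedData)
println(model.pi)      // log prior for each class
println(model.theta)   // log conditional probabilities per class and feature
```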
--- .../spark/ml/classification/NaiveBayes.scala | 154 +++++++++++++----- .../mllib/classification/NaiveBayes.scala | 99 +++-------- .../ml/classification/NaiveBayesSuite.scala | 50 +++++- 3 files changed, 191 insertions(+), 112 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index f939a1c6808e6..0d652aa4c65a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -19,23 +19,20 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} -import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} -import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType /** * Params for Naive Bayes Classifiers. */ -private[ml] trait NaiveBayesParams extends PredictorParams { +private[ml] trait NaiveBayesParams extends PredictorParams with HasWeightCol { /** * The smoothing parameter. @@ -56,7 +53,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { */ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", - ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) + ParamValidators.inArray[String](NaiveBayes.supportedModelTypes.toArray)) /** @group getParam */ final def getModelType: String = $(modelType) @@ -64,7 +61,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { /** * Naive Bayes Classifiers. - * It supports both Multinomial NB + * It supports Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) * which can handle finitely supported discrete data. For example, by converting documents into * TF-IDF vectors, it can be used for document classification. By making every vector a @@ -78,6 +75,8 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -98,7 +97,17 @@ class NaiveBayes @Since("1.5.0") ( */ @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) - setDefault(modelType -> OldNaiveBayes.Multinomial) + setDefault(modelType -> NaiveBayes.Multinomial) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. 
+ * + * @group setParam + */ + @Since("2.1.0") + def setWeightCol(value: String): this.type = set(weightCol, value) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val numClasses = getNumClasses(dataset) @@ -109,10 +118,89 @@ class NaiveBayes @Since("1.5.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val oldDataset: RDD[OldLabeledPoint] = - extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) - val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) - NaiveBayesModel.fromOld(oldModel, this) + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } + + val requireValues: Vector => Unit = { + $(modelType) match { + case Multinomial => + requireNonnegativeValues + case Bernoulli => + requireZeroOneBernoulliValues + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + } + + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + + // Aggregates term frequencies per label. + // TODO: Calling aggregateByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. + val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) + }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + seqOp = { + case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + requireValues(features) + BLAS.axpy(weight, features, featureSum) + (weightSum + weight, featureSum) + }, + combOp = { + case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + BLAS.axpy(1.0, featureSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1) + }).collect().sortBy(_._1) + + val numLabels = aggregated.length + val numDocuments = aggregated.map(_._2._1).sum + + val piArray = Array.fill[Double](numLabels)(0.0) + val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + + val lambda = $(smoothing) + val piLogDenom = math.log(numDocuments + numLabels * lambda) + var i = 0 + aggregated.foreach { case (label, (n, sumTermFreqs)) => + piArray(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = $(modelType) match { + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. 
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + var j = 0 + while (j < numFeatures) { + thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + j += 1 + } + i += 1 + } + + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + new NaiveBayesModel(uid, pi, theta) } @Since("1.5.0") @@ -121,6 +209,14 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + /** String name for multinomial model type. */ + private[spark] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[spark] val Bernoulli: String = "bernoulli" + + /* Set of modelTypes that NaiveBayes supports */ + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) @@ -140,7 +236,7 @@ class NaiveBayesModel private[ml] ( extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { - import OldNaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes.{Bernoulli, Multinomial} /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. @@ -175,10 +271,8 @@ class NaiveBayesModel private[ml] ( private def bernoulliCalculation(features: Vector) = { features.foreachActive((_, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") - } + require(value == 0.0 || value == 1.0, + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") ) val prob = thetaMinusNegTheta.get.multiply(features) BLAS.axpy(1.0, pi, prob) @@ -238,18 +332,6 @@ class NaiveBayesModel private[ml] ( @Since("1.6.0") object NaiveBayesModel extends MLReadable[NaiveBayesModel] { - /** Convert a model from the old API */ - private[ml] def fromOld( - oldModel: OldNaiveBayesModel, - parent: NaiveBayes): NaiveBayesModel = { - val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") - val labels = Vectors.dense(oldModel.labels) - val pi = Vectors.dense(oldModel.pi) - val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, - oldModel.theta.flatten, true) - new NaiveBayesModel(uid, pi, theta) - } - @Since("1.6.0") override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader @@ -280,11 +362,9 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") - val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") - .select("pi", "theta") - .head() + val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 593a86f69ad51..32d6968a4e85f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -27,7 +27,8 @@ import 
org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} +import org.apache.spark.ml.classification.{NaiveBayes => NewNaiveBayes} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -311,8 +312,6 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { - import NaiveBayes.{Bernoulli, Multinomial} - @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) @@ -355,79 +354,33 @@ class NaiveBayes private ( */ @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - if (!values.forall(_ >= 0.0)) { - throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") - } - } + val spark = SparkSession + .builder() + .sparkContext(data.context) + .getOrCreate() - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - if (!values.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - } + import spark.implicits._ - // Aggregates term frequencies per label. - // TODO: Calling combineByKey and collect creates two stages, we can implement something - // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( - createCombiner = (v: Vector) => { - if (modelType == Bernoulli) { - requireZeroOneBernoulliValues(v) - } else { - requireNonnegativeValues(v) - } - (1L, v.copy.toDense) - }, - mergeValue = (c: (Long, DenseVector), v: Vector) => { - requireNonnegativeValues(v) - BLAS.axpy(1.0, v, c._2) - (c._1 + 1L, c._2) - }, - mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { - BLAS.axpy(1.0, c2._2, c1._2) - (c1._1 + c2._1, c1._2) - } - ).collect().sortBy(_._1) + val nb = new NewNaiveBayes() + .setModelType(modelType) + .setSmoothing(lambda) - val numLabels = aggregated.length - var numDocuments = 0L - aggregated.foreach { case (_, (n, _)) => - numDocuments += n - } - val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } - - val labels = new Array[Double](numLabels) - val pi = new Array[Double](numLabels) - val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) - - val piLogDenom = math.log(numDocuments + numLabels * lambda) - var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => - labels(i) = label - pi(i) = math.log(n + lambda) - piLogDenom - val thetaLogDenom = modelType match { - case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) - case Bernoulli => math.log(n + 2.0 * lambda) - case _ => - // This should never happen. 
- throw new UnknownError(s"Invalid modelType: $modelType.") - } - var j = 0 - while (j < numFeatures) { - theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom - j += 1 - } - i += 1 + val labels = data.map(_.label).distinct().collect().sorted + + // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be + // in range [0, numClasses). + val dataset = data.map { + case LabeledPoint(label, features) => + (labels.indexOf(label).toDouble, features.asML) + }.toDF("label", "features") + + val newModel = nb.fit(dataset) + + val pi = newModel.pi.toArray + val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) + newModel.theta.foreachActive { + case (i, j, v) => + theta(i)(j) = v } new NaiveBayesModel(labels, pi, theta, modelType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 99099324284dc..597428d036c7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -23,13 +23,13 @@ import breeze.linalg.{DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.ml.classification.NaiveBayesSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -152,6 +152,52 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "multinomial") } + test("Naive Bayes Multinomial with weighted samples") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() + val (overSampledData, weightedData) = + MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("multinomial") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + + test("Naive Bayes Bernoulli with weighted samples") { + val nPoints = 10000 + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + 
Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "bernoulli").toDF() + val (overSampledData, weightedData) = + MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("bernoulli") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + test("Naive Bayes Bernoulli") { val nPoints = 10000 val piArray = Array(0.5, 0.3, 0.2).map(math.log) From 8e491af52930886cbe0c54e7d67add3796ddb15f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 30 Sep 2016 08:18:48 -0700 Subject: [PATCH 167/306] [SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0 ## What changes were proposed in this pull request? Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0 ## How was this patch tested? local build Author: Zheng RuiFeng Closes #15313 from zhengruifeng/revert_save_load. --- .../apache/spark/ml/classification/NaiveBayes.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0d652aa4c65a1..6775745167b08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ -import org.apache.spark.sql.Dataset +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType @@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) From f327e16863371076dbd2a7f22c8895ae07f8274b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 30 Sep 2016 09:59:12 -0700 Subject: [PATCH 168/306] [SPARK-17738] [SQL] fix ARRAY/MAP in columnar cache ## What changes were proposed in this pull request? The actualSize() of array and map is different from the actual size, the header is Int, rather than Long. ## How was this patch tested? The flaky test should be fixed. 
Author: Davies Liu Closes #15305 from davies/fix_MAP. --- .../apache/spark/sql/execution/columnar/ColumnType.scala | 8 ++++---- .../spark/sql/execution/columnar/ColumnTypeSuite.scala | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index fa9619eb07fec..d27d8c362dd9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -589,7 +589,7 @@ private[columnar] case class STRUCT(dataType: StructType) private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { - override def defaultSize: Int = 16 + override def defaultSize: Int = 28 override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) @@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeArray = getField(row, ordinal) - 8 + unsafeArray.getSizeInBytes + 4 + unsafeArray.getSizeInBytes } override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { @@ -628,7 +628,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { - override def defaultSize: Int = 32 + override def defaultSize: Int = 68 override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) @@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType) override def actualSize(row: InternalRow, ordinal: Int): Int = { val unsafeMap = getField(row, ordinal) - 8 + unsafeMap.getSizeInBytes + 4 + unsafeMap.getSizeInBytes } override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 0b93c633b2d93..805b5667287ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -38,7 +38,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val checks = Map( NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, - STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) + STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) - checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8) - checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8)) + checkActualSize(ARRAY_TYPE, Array[Any](1), 4 + 8 + 8 + 8) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 4 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 
8)) checkActualSize(STRUCT_TYPE, Row("hello"), 28) } From 81455a9cd963098613bad10182e3fafc83a6e352 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 30 Sep 2016 17:31:59 -0700 Subject: [PATCH 169/306] [SPARK-17703][SQL] Add unnamed version of addReferenceObj for minor objects. ## What changes were proposed in this pull request? There are many minor objects in `references` that are extracted into generated class fields, e.g. `errMsg` in `GetExternalRowField` or `ValidateExternalType`, but the number of fields in a class is limited, so we should reduce it. This PR adds an unnamed version of `addReferenceObj` for these minor objects so that, instead of storing each object in its own field, it is referenced from the `references` array at the point of use. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15276 from ueshin/issues/SPARK-17703. --- .../expressions/codegen/CodeGenerator.scala | 15 +++++++++++++++ .../spark/sql/catalyst/expressions/misc.scala | 5 ++++- .../catalyst/expressions/objects/objects.scala | 12 +++++++++--- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 33b9b804fc601..cb808e375a35f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -84,6 +84,21 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`. + * + * Returns the code to access it. + * + * This is for minor objects not to store the object into field but refer it from the references + * field at the time of use because number of fields in class is limited so we should reduce it. + */ + def addReferenceObj(obj: Any): String = { + val idx = references.length + references += obj + val clsName = obj.getClass.getName + s"(($clsName) references[$idx])" + } + /** + * Add an object to `references`, create a class member to access it. + * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 92f8fb85fc0e2..dbb52a4bb18de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -517,7 +517,10 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the value is null or false.
+ val errMsgField = ctx.addReferenceObj(errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index faf8fecd79f4d..50e2ac3c36d93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -906,7 +906,9 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the value is null. + val errMsgField = ctx.addReferenceObj(errMsg) val code = s""" ${childGen.code} @@ -941,7 +943,9 @@ case class GetExternalRowField( private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the field is null. + val errMsgField = ctx.addReferenceObj(errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -979,7 +983,9 @@ case class ValidateExternalType(child: Expression, expected: DataType) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the type doesn't match. + val errMsgField = ctx.addReferenceObj(errMsg) val input = child.genCode(ctx) val obj = input.value From a26afd52198523dbd51dc94053424494638c7de5 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Fri, 30 Sep 2016 18:24:39 -0700 Subject: [PATCH 170/306] [SPARK-15353][CORE] Making peer selection for block replication pluggable ## What changes were proposed in this pull request? This PR makes block replication strategies pluggable. It provides two trait that can be implemented, one that maps a host to its topology and is used in the master, and the second that helps prioritize a list of peers for block replication and would run in the executors. This patch contains default implementations of these traits that make sure current Spark behavior is unchanged. ## How was this patch tested? This patch should not change Spark behavior in any way, and was tested with unit tests for storage. Author: Shubham Chopra Closes #13152 from shubhamchopra/RackAwareBlockReplication. 
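For orientation before the BlockManager diff below, here is an editorial sketch of what a custom replication policy could look like. The trait itself lives in the new `BlockReplicationPolicy.scala` file, which is not reproduced here, so the `prioritize` signature is inferred from the call sites in `BlockManager` and may differ in detail; the class and package names are hypothetical, and the sketch assumes the trait is visible to user code.

```
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.storage.{BlockId, BlockManagerId, BlockReplicationPolicy}

// Prefer peers on other hosts before falling back to peers on the same host.
class PreferOtherHostsPolicy extends BlockReplicationPolicy {
  override def prioritize(
      blockManagerId: BlockManagerId,
      peers: Seq[BlockManagerId],
      peersReplicatedTo: mutable.HashSet[BlockManagerId],
      blockId: BlockId,
      numReplicas: Int): List[BlockManagerId] = {
    val rng = new Random(blockId.hashCode)
    val (otherHosts, sameHost) = peers.partition(_.host != blockManagerId.host)
    (rng.shuffle(otherHosts) ++ rng.shuffle(sameHost)).take(numReplicas).toList
  }
}

// Selected through the configuration key shown in the diff below, which defaults to
// RandomBlockReplicationPolicy:
//   spark.storage.replication.policy=com.example.PreferOtherHostsPolicy
```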
--- .../apache/spark/storage/BlockManager.scala | 167 +++++++++--------- .../apache/spark/storage/BlockManagerId.scala | 34 +++- .../spark/storage/BlockManagerMaster.scala | 16 +- .../storage/BlockManagerMasterEndpoint.scala | 32 +++- .../storage/BlockReplicationPolicy.scala | 112 ++++++++++++ .../apache/spark/storage/TopologyMapper.scala | 86 +++++++++ .../BlockManagerReplicationSuite.scala | 2 + .../storage/BlockReplicationPolicySuite.scala | 74 ++++++++ .../spark/storage/TopologyMapperSuite.scala | 68 +++++++ 9 files changed, 492 insertions(+), 99 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aa29acfd70461..982b83324e0fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,7 +20,8 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable +import scala.collection.mutable.HashMap import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag @@ -44,6 +45,7 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer + /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], @@ -147,6 +149,8 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L + private var blockReplicationPolicy: BlockReplicationPolicy = _ + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -160,8 +164,24 @@ private[spark] class BlockManager( blockTransferService.init(this) shuffleClient.init(appId) - blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + blockReplicationPolicy = { + val priorityClass = conf.get( + "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) + val clazz = Utils.classForName(priorityClass) + val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + logInfo(s"Using $priorityClass for block replication policy") + ret + } + + val id = + BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None) + + val idFromMaster = master.registerBlockManager( + id, + maxMemory, + slaveEndpoint) + + blockManagerId = if (idFromMaster != null) idFromMaster else id shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") @@ -170,12 +190,12 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) - // Register Executors' configuration with the local shuffle service, if one should exist. 
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { registerWithExternalShuffleServer() } + + logInfo(s"Initialized BlockManager: $blockManagerId") } private def registerWithExternalShuffleServer() { @@ -1111,7 +1131,7 @@ private[spark] class BlockManager( } /** - * Replicate block to another node. Not that this is a blocking call that returns after + * Replicate block to another node. Note that this is a blocking call that returns after * the block has been replicated. */ private def replicate( @@ -1119,101 +1139,78 @@ private[spark] class BlockManager( data: ChunkedByteBuffer, level: StorageLevel, classTag: ClassTag[_]): Unit = { + val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) - val numPeersToReplicateTo = level.replication - 1 - val peersForReplication = new ArrayBuffer[BlockManagerId] - val peersReplicatedTo = new ArrayBuffer[BlockManagerId] - val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] val tLevel = StorageLevel( useDisk = level.useDisk, useMemory = level.useMemory, useOffHeap = level.useOffHeap, deserialized = level.deserialized, replication = 1) - val startTime = System.currentTimeMillis - val random = new Random(blockId.hashCode) - - var replicationFailed = false - var failures = 0 - var done = false - - // Get cached list of peers - peersForReplication ++= getPeers(forceFetch = false) - - // Get a random peer. Note that this selection of a peer is deterministic on the block id. - // So assuming the list of peers does not change and no replication failures, - // if there are multiple attempts in the same node to replicate the same block, - // the same set of peers will be selected. - def getRandomPeer(): Option[BlockManagerId] = { - // If replication had failed, then force update the cached list of peers and remove the peers - // that have been already used - if (replicationFailed) { - peersForReplication.clear() - peersForReplication ++= getPeers(forceFetch = true) - peersForReplication --= peersReplicatedTo - peersForReplication --= peersFailedToReplicateTo - } - if (!peersForReplication.isEmpty) { - Some(peersForReplication(random.nextInt(peersForReplication.size))) - } else { - None - } - } - // One by one choose a random peer and try uploading the block to it - // If replication fails (e.g., target peer is down), force the list of cached peers - // to be re-fetched from driver and then pick another random peer for replication. Also - // temporarily black list the peer for which replication failed. 
- // - // This selection of a peer and replication is continued in a loop until one of the - // following 3 conditions is fulfilled: - // (i) specified number of peers have been replicated to - // (ii) too many failures in replicating to peers - // (iii) no peer left to replicate to - // - while (!done) { - getRandomPeer() match { - case Some(peer) => - try { - val onePeerStartTime = System.currentTimeMillis - logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") - blockTransferService.uploadBlockSync( - peer.host, - peer.port, - peer.executorId, - blockId, - new NettyManagedBuffer(data.toNetty), - tLevel, - classTag) - logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms" - .format(System.currentTimeMillis - onePeerStartTime)) - peersReplicatedTo += peer - peersForReplication -= peer - replicationFailed = false - if (peersReplicatedTo.size == numPeersToReplicateTo) { - done = true // specified number of peers have been replicated to - } - } catch { - case NonFatal(e) => - logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) - failures += 1 - replicationFailed = true - peersFailedToReplicateTo += peer - if (failures > maxReplicationFailures) { // too many failures in replicating to peers - done = true - } + val numPeersToReplicateTo = level.replication - 1 + + val startTime = System.nanoTime + + var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId] + var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] + var numFailures = 0 + + var peersForReplication = blockReplicationPolicy.prioritize( + blockManagerId, + getPeers(false), + mutable.HashSet.empty, + blockId, + numPeersToReplicateTo) + + while(numFailures <= maxReplicationFailures && + !peersForReplication.isEmpty && + peersReplicatedTo.size != numPeersToReplicateTo) { + val peer = peersForReplication.head + try { + val onePeerStartTime = System.nanoTime + logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") + blockTransferService.uploadBlockSync( + peer.host, + peer.port, + peer.executorId, + blockId, + new NettyManagedBuffer(data.toNetty), + tLevel, + classTag) + logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + + s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms") + peersForReplication = peersForReplication.tail + peersReplicatedTo += peer + } catch { + case NonFatal(e) => + logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e) + peersFailedToReplicateTo += peer + // we have a failed replication, so we get the list of peers again + // we don't want peers we have already replicated to and the ones that + // have failed previously + val filteredPeers = getPeers(true).filter { p => + !peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p) } - case None => // no peer left to replicate to - done = true + + numFailures += 1 + peersForReplication = blockReplicationPolicy.prioritize( + blockManagerId, + filteredPeers, + peersReplicatedTo, + blockId, + numPeersToReplicateTo - peersReplicatedTo.size) } } - val timeTakeMs = (System.currentTimeMillis - startTime) + logDebug(s"Replicating $blockId of ${data.size} bytes to " + - s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms") + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { logWarning(s"Block $blockId replicated to only " + s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers") } + + logDebug(s"block 
$blockId replicated to ${peersReplicatedTo.mkString(", ")}") } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index f255f5be63fcf..c37a3604d28fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -37,10 +37,11 @@ import org.apache.spark.util.Utils class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int) + private var port_ : Int, + private var topologyInfo_ : Option[String]) extends Externalizable { - private def this() = this(null, null, 0) // For deserialization only + private def this() = this(null, null, 0, None) // For deserialization only def executorId: String = executorId_ @@ -60,6 +61,8 @@ class BlockManagerId private ( def port: Int = port_ + def topologyInfo: Option[String] = topologyInfo_ + def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER || executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -69,24 +72,33 @@ class BlockManagerId private ( out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) + out.writeBoolean(topologyInfo_.isDefined) + // we only write topologyInfo if we have it + topologyInfo.foreach(out.writeUTF(_)) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() + val isTopologyInfoAvailable = in.readBoolean() + topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString: String = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)" - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + override def hashCode: Int = + ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode override def equals(that: Any): Boolean = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host + executorId == id.executorId && + port == id.port && + host == id.host && + topologyInfo == id.topologyInfo case _ => false } @@ -101,10 +113,18 @@ private[spark] object BlockManagerId { * @param execId ID of the executor. * @param host Host name of the block manager. * @param port Port of the block manager. + * @param topologyInfo topology information for the blockmanager, if available + * This can be network topology information for use while choosing peers + * while replicating data blocks. More information available here: + * [[org.apache.spark.storage.TopologyMapper]] * @return A new [[org.apache.spark.storage.BlockManagerId]]. 
*/ - def apply(execId: String, host: String, port: Int): BlockManagerId = - getCachedBlockManagerId(new BlockManagerId(execId, host, port)) + def apply( + execId: String, + host: String, + port: Int, + topologyInfo: Option[String] = None): BlockManagerId = + getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo)) def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8655cf10fc28f..7a600068912b1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -50,12 +50,20 @@ class BlockManagerMaster( logInfo("Removal of executor " + execId + " requested") } - /** Register the BlockManager's id with the driver. */ + /** + * Register the BlockManager's id with the driver. The input BlockManagerId does not contain + * topology information. This information is obtained from the master and we respond with an + * updated BlockManagerId fleshed out with this information. + */ def registerBlockManager( - blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { + blockManagerId: BlockManagerId, + maxMemSize: Long, + slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) - logInfo(s"Registered BlockManager $blockManagerId") + val updatedId = driverEndpoint.askWithRetry[BlockManagerId]( + RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + logInfo(s"Registered BlockManager $updatedId") + updatedId } def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 8fa12150114db..145c434a4f0cf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -55,10 +55,21 @@ class BlockManagerMasterEndpoint( private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) + private val topologyMapper = { + val topologyMapperClassName = conf.get( + "spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName) + val clazz = Utils.classForName(topologyMapperClassName) + val mapper = + clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper] + logInfo(s"Using $topologyMapperClassName for getting topology information") + mapper + } + + logInfo("BlockManagerMasterEndpoint up") + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - register(blockManagerId, maxMemSize, slaveEndpoint) - context.reply(true) + context.reply(register(blockManagerId, maxMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -298,7 +309,21 @@ class BlockManagerMasterEndpoint( ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + /** + * Returns the BlockManagerId with topology information populated, 
if available. + */ + private def register( + idWithoutTopologyInfo: BlockManagerId, + maxMemSize: Long, + slaveEndpoint: RpcEndpointRef): BlockManagerId = { + // the dummy id is not expected to contain the topology information. + // we get that info here and respond back with a more fleshed out block manager id + val id = BlockManagerId( + idWithoutTopologyInfo.executorId, + idWithoutTopologyInfo.host, + idWithoutTopologyInfo.port, + topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host)) + val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -318,6 +343,7 @@ class BlockManagerMasterEndpoint( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + id } private def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala new file mode 100644 index 0000000000000..bf087af16a5b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging + +/** + * ::DeveloperApi:: + * BlockReplicationPrioritization provides logic for prioritizing a sequence of peers for + * replicating blocks. BlockManager will replicate to each peer returned in order until the + * desired replication order is reached. If a replication fails, prioritize() will be called + * again to get a fresh prioritization. + */ +@DeveloperApi +trait BlockReplicationPolicy { + + /** + * Method to prioritize a bunch of candidate peers of a block + * + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority. + * This returns a list of size at most `numPeersToReplicateTo`. 
+ */ + def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] +} + +@DeveloperApi +class RandomBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { + + /** + * Method to prioritize a bunch of candidate peers of a block. This is a basic implementation, + * that just makes sure we put blocks on different hosts, if possible + * + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @return A prioritized list of peers. Lower the index of a peer, higher its priority + */ + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + val random = new Random(blockId.hashCode) + logDebug(s"Input peers : ${peers.mkString(", ")}") + val prioritizedPeers = if (peers.size > numReplicas) { + getSampleIds(peers.size, numReplicas, random).map(peers(_)) + } else { + if (peers.size < numReplicas) { + logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") + } + random.shuffle(peers).toList + } + logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") + prioritizedPeers + } + + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage + * [[http://math.stackexchange.com/questions/178690/ + * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + // we shuffle the result to ensure a random arrangement within the sample + // to avoid any bias from set implementations + r.shuffle(indices.map(_ - 1).toList) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala new file mode 100644 index 0000000000000..a0f0fdef8e948 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * ::DeveloperApi:: + * TopologyMapper provides topology information for a given host + * @param conf SparkConf to get required properties, if needed + */ +@DeveloperApi +abstract class TopologyMapper(conf: SparkConf) { + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return topology information for the given hostname. One can use a 'topology delimiter' + * to make this topology information nested. + * For example : ‘/myrack/myhost’, where ‘/’ is the topology delimiter, + * ‘myrack’ is the topology identifier, and ‘myhost’ is the individual host. + * This function only returns the topology information without the hostname. + * This information can be used when choosing executors for block replication + * to discern executors from a different rack than a candidate executor, for example. + * + * An implementation can choose to use empty strings or None in case topology info + * is not available. This would imply that all such executors belong to the same rack. + */ + def getTopologyForHost(hostname: String): Option[String] +} + +/** + * A TopologyMapper that assumes all nodes are in the same rack + */ +@DeveloperApi +class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + override def getTopologyForHost(hostname: String): Option[String] = { + logDebug(s"Got a request for $hostname") + None + } +} + +/** + * A simple file based topology mapper. This expects topology information provided as a + * [[java.util.Properties]] file. The name of the file is obtained from SparkConf property + * `spark.storage.replication.topologyFile`. To use this topology mapper, set the + * `spark.storage.replication.topologyMapper` property to + * [[org.apache.spark.storage.FileBasedTopologyMapper]] + * @param conf SparkConf object + */ +@DeveloperApi +class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + val topologyFile = conf.getOption("spark.storage.replication.topologyFile") + require(topologyFile.isDefined, "Please specify topology file via " + + "spark.storage.replication.topologyFile for FileBasedTopologyMapper.") + val topologyMap = Utils.getPropertiesFromFile(topologyFile.get) + + override def getTopologyForHost(hostname: String): Option[String] = { + val topology = topologyMap.get(hostname) + if (topology.isDefined) { + logDebug(s"$hostname -> ${topology.get}") + } else { + logWarning(s"$hostname does not have any topology information") + } + topology + } +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index e1c1787cbd15e..f4bfdc2fd69a9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -346,6 +346,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite } } + + /** * Test replication of blocks with different storage levels (various combinations of * memory, disk & serialization). 
For each storage level, this function tests every store diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala new file mode 100644 index 0000000000000..800c3899f1a72 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} + +class BlockReplicationPolicySuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + // Implicitly convert strings to BlockIds for test clarity. + private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + + /** + * Test if we get the required number of peers when using random sampling from + * RandomBlockReplicationPolicy + */ + test(s"block replication - random block replication policy") { + val numBlockManagers = 10 + val storeSize = 1000 + val blockManagers = (1 to numBlockManagers).map { i => + BlockManagerId(s"store-$i", "localhost", 1000 + i, None) + } + val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) + val replicationPolicy = new RandomBlockReplicationPolicy + val blockId = "test-block" + + (1 to 10).foreach {numReplicas => + logDebug(s"Num replicas : $numReplicas") + val randomPeers = replicationPolicy.prioritize( + candidateBlockManager, + blockManagers, + mutable.HashSet.empty[BlockManagerId], + blockId, + numReplicas + ) + logDebug(s"Random peers : ${randomPeers.mkString(", ")}") + assert(randomPeers.toSet.size === numReplicas) + + // choosing n peers out of n + val secondPass = replicationPolicy.prioritize( + candidateBlockManager, + randomPeers, + mutable.HashSet.empty[BlockManagerId], + blockId, + numReplicas + ) + logDebug(s"Random peers : ${secondPass.mkString(", ")}") + assert(secondPass.toSet.size === numReplicas) + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala new file mode 100644 index 0000000000000..bbd252d7be7ea --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{File, FileOutputStream} + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +class TopologyMapperSuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { + + test("File based Topology Mapper") { + val numHosts = 100 + val numRacks = 4 + val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap + val propsFile = createPropertiesFile(props) + + val sparkConf = (new SparkConf(false)) + sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath) + val topologyMapper = new FileBasedTopologyMapper(sparkConf) + + props.foreach {case (host, topology) => + val obtainedTopology = topologyMapper.getTopologyForHost(host) + assert(obtainedTopology.isDefined) + assert(obtainedTopology.get === topology) + } + + // we get None for hosts not in the file + assert(topologyMapper.getTopologyForHost("host").isEmpty) + + cleanup(propsFile) + } + + def createPropertiesFile(props: Map[String, String]): File = { + val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile + val fileOS = new FileOutputStream(testFile) + props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)} + fileOS.close + testFile + } + + def cleanup(testFile: File): Unit = { + testFile.getParentFile.listFiles.filter { file => + file.getName.startsWith(testFile.getName) + }.foreach { _.delete() } + } + +} From aef506e39a41cfe7198162c324a11ef2f01136c3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 30 Sep 2016 21:05:06 -0700 Subject: [PATCH 171/306] [SPARK-17739][SQL] Collapse adjacent similar Window operators ## What changes were proposed in this pull request? Currently, Spark does not collapse adjacent windows with the same partitioning and sorting. This PR implements `CollapseWindow` optimizer to do the followings. 1. If the partition specs and order specs are the same, collapse into the parent. 2. If the partition specs are the same and one order spec is a prefix of the other, collapse to the more specific one. 
For example: ```scala val df = spark.range(1000).select($"id" % 100 as "grp", $"id", rand() as "col1", rand() as "col2") // Add summary statistics for all columns import org.apache.spark.sql.expressions.Window val cols = Seq("id", "col1", "col2") val window = Window.partitionBy($"grp").orderBy($"id") val result = cols.foldLeft(df) { (base, name) => base.withColumn(s"${name}_avg", avg(col(name)).over(window)) .withColumn(s"${name}_stddev", stddev(col(name)).over(window)) .withColumn(s"${name}_min", min(col(name)).over(window)) .withColumn(s"${name}_max", max(col(name)).over(window)) } ``` **Before** ```scala scala> result.explain == Physical Plan == Window [max(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_max#234], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_min#216], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [stddev_samp(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_stddev#191], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [avg(col2#19) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_avg#167], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [max(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_max#152], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_min#138], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [stddev_samp(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_stddev#117], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [avg(col1#18) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_avg#97], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [max(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_max#86L], [grp#17L], [id#14L ASC NULLS FIRST] +- Window [min(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_min#76L], [grp#17L], [id#14L ASC NULLS FIRST] +- *Project [grp#17L, id#14L, col1#18, col2#19, id_avg#26, id_stddev#42] +- Window [stddev_samp(_w0#59) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_stddev#42], [grp#17L], [id#14L ASC NULLS FIRST] +- *Project [grp#17L, id#14L, col1#18, col2#19, id_avg#26, cast(id#14L as double) AS _w0#59] +- Window [avg(id#14L) windowspecdefinition(grp#17L, id#14L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_avg#26], [grp#17L], [id#14L ASC NULLS FIRST] +- *Sort [grp#17L ASC NULLS FIRST, id#14L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(grp#17L, 200) +- *Project [(id#14L % 100) AS grp#17L, id#14L, rand(-6329949029880411066) AS col1#18, rand(-7251358484380073081) AS col2#19] +- *Range (0, 1000, step=1, splits=Some(8)) ``` **After** ```scala scala> result.explain == Physical Plan == Window [max(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_max#220, min(col2#5) 
windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_min#202, stddev_samp(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_stddev#177, avg(col2#5) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col2_avg#153, max(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_max#138, min(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_min#124, stddev_samp(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_stddev#103, avg(col1#4) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS col1_avg#83, max(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_max#72L, min(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_min#62L], [grp#3L], [id#0L ASC NULLS FIRST] +- *Project [grp#3L, id#0L, col1#4, col2#5, id_avg#12, id_stddev#28] +- Window [stddev_samp(_w0#45) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_stddev#28], [grp#3L], [id#0L ASC NULLS FIRST] +- *Project [grp#3L, id#0L, col1#4, col2#5, id_avg#12, cast(id#0L as double) AS _w0#45] +- Window [avg(id#0L) windowspecdefinition(grp#3L, id#0L ASC NULLS FIRST, RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS id_avg#12], [grp#3L], [id#0L ASC NULLS FIRST] +- *Sort [grp#3L ASC NULLS FIRST, id#0L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(grp#3L, 200) +- *Project [(id#0L % 100) AS grp#3L, id#0L, rand(6537478539664068821) AS col1#4, rand(-8961093871295252795) AS col2#5] +- *Range (0, 1000, step=1, splits=Some(8)) ``` ## How was this patch tested? Pass the Jenkins tests with a newly added testsuite. Author: Dongjoon Hyun Closes #15317 from dongjoon-hyun/SPARK-17739. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../optimizer/CollapseWindowSuite.scala | 78 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9df8ce1fa3b28..e5e2cd7d27d15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -88,6 +88,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // Operator combine CollapseRepartition, CollapseProject, + CollapseWindow, CombineFilters, CombineLimits, CombineUnions, @@ -537,6 +538,17 @@ object CollapseRepartition extends Rule[LogicalPlan] { } } +/** + * Collapse Adjacent Window Expression. + * - If the partition specs and order specs are the same, collapse into the parent. 
+ */ +object CollapseWindow extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case w @ Window(we1, ps1, os1, Window(we2, ps2, os2, grandChild)) if ps1 == ps2 && os1 == os2 => + w.copy(windowExpressions = we1 ++ we2, child = grandChild) + } +} + /** * Generate a list of additional filters from an operator's existing constraint but remove those * that are either already part of the operator's condition or are part of the operator's child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala new file mode 100644 index 0000000000000..797076e55cfcc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class CollapseWindowSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CollapseWindow", FixedPoint(10), + CollapseWindow) :: Nil + } + + val testRelation = LocalRelation('a.double, 'b.double, 'c.string) + val a = testRelation.output(0) + val b = testRelation.output(1) + val c = testRelation.output(2) + val partitionSpec1 = Seq(c) + val partitionSpec2 = Seq(c + 1) + val orderSpec1 = Seq(c.asc) + val orderSpec2 = Seq(c.desc) + + test("collapse two adjacent windows with the same partition/order") { + val query = testRelation + .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1) + .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1) + .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.window(Seq( + avg(b).as('avg_b), + sum(b).as('sum_b), + max(a).as('max_a), + min(a).as('min_a)), partitionSpec1, orderSpec1) + + comparePlans(optimized, correctAnswer) + } + + test("Don't collapse adjacent windows with different partitions or orders") { + val query1 = testRelation + .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = query1.analyze + + comparePlans(optimized1, correctAnswer1) + + val query2 = testRelation + 
.window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) + .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1) + + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) + } +} From 15e9bbb49e00b3982c428d39776725d0dea2cdfa Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 30 Sep 2016 22:05:59 -0700 Subject: [PATCH 172/306] [MINOR][DOC] Add an up-to-date description for default serialization during shuffling ## What changes were proposed in this pull request? This PR aims to make the doc up-to-date. The documentation is generally correct, but after https://issues.apache.org/jira/browse/SPARK-13926, Spark starts to choose Kryo as a default serialization library during shuffling of simple types, arrays of simple types, or string type. ## How was this patch tested? This is a documentation update. Author: Dongjoon Hyun Closes #15315 from dongjoon-hyun/SPARK-DOC-SERIALIZER. --- docs/tuning.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tuning.md b/docs/tuning.md index cbf37213aa724..9c43b315bbb9e 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -45,6 +45,7 @@ and calling `conf.set("spark.serializer", "org.apache.spark.serializer.KryoSeria This setting configures the serializer used for not only shuffling data between worker nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. +Since Spark 2.0.0, we internally use Kryo serializer when shuffling RDDs with simple types, arrays of simple types, or string type. Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library. From 4bcd9b728b8df74756d16b27725c2db7c523d4b2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 30 Sep 2016 23:51:36 -0700 Subject: [PATCH 173/306] [SPARK-17740] Spark tests should mock / interpose HDFS to ensure that streams are closed ## What changes were proposed in this pull request? As a followup to SPARK-17666, ensure filesystem connections are not leaked at least in unit tests. This is done here by intercepting filesystem calls as suggested by JoshRosen. At the end of each test, we assert no filesystem streams are left open. This applies to all tests using SharedSQLContext or SharedSparkContext. ## How was this patch tested? I verified that tests in sql and core are indeed using the filesystem backend, and fixed the detected leaks. I also checked that reverting https://github.com/apache/spark/pull/15245 causes many actual test failures due to connection leaks. Author: Eric Liang Author: Eric Liang Closes #15306 from ericl/sc-4672.
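The interposition amounts to routing local file access through a stream-tracking filesystem and asserting at test teardown that every opened stream was closed. A minimal usage sketch of that pattern, using the configuration key and helper methods introduced by this patch (the input path and workload here are hypothetical):

```scala
import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext}

// Route "file://" access through the stream-tracking filesystem, as the patch
// does for SharedSparkContext and SharedSQLContext.
val conf = new SparkConf()
  .setMaster("local[4]")
  .setAppName("leak-check")
  .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
val sc = new SparkContext(conf)

DebugFilesystem.clearOpenStreams()              // reset tracking before the test body
try {
  sc.textFile("/tmp/some-input.txt").count()    // hypothetical workload under test
} finally {
  DebugFilesystem.assertNoOpenStreams()         // throws if any stream opened above was leaked
  sc.stop()
}
```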
--- .../org/apache/spark/DebugFilesystem.scala | 114 ++++++++++++++++++ .../org/apache/spark/SharedSparkContext.scala | 17 ++- .../parquet/ParquetEncodingSuite.scala | 1 + .../streaming/HDFSMetadataLogSuite.scala | 3 +- .../spark/sql/test/SharedSQLContext.scala | 19 ++- 5 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/DebugFilesystem.scala diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala new file mode 100644 index 0000000000000..fb8d701ebda8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{FileDescriptor, InputStream} +import java.lang +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.hadoop.fs._ + +import org.apache.spark.internal.Logging + +object DebugFilesystem extends Logging { + // Stores the set of active streams and their creation sites. + private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + + def clearOpenStreams(): Unit = { + openStreams.clear() + } + + def assertNoOpenStreams(): Unit = { + val numOpen = openStreams.size() + if (numOpen > 0) { + for (exc <- openStreams.values().asScala) { + logWarning("Leaked filesystem connection created at:") + exc.printStackTrace() + } + throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + } + } +} + +/** + * DebugFilesystem wraps file open calls to track all open connections. This can be used in tests + * to check that connections are not leaked. 
+ */ +// TODO(ekl) we should consider always interposing this to expose num open conns as a metric +class DebugFilesystem extends LocalFileSystem { + import DebugFilesystem._ + + override def open(f: Path, bufferSize: Int): FSDataInputStream = { + val wrapped: FSDataInputStream = super.open(f, bufferSize) + openStreams.put(wrapped, new Throwable()) + + new FSDataInputStream(wrapped.getWrappedStream) { + override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) + + override def getWrappedStream: InputStream = wrapped.getWrappedStream + + override def getFileDescriptor: FileDescriptor = wrapped.getFileDescriptor + + override def getPos: Long = wrapped.getPos + + override def seekToNewSource(targetPos: Long): Boolean = wrapped.seekToNewSource(targetPos) + + override def seek(desired: Long): Unit = wrapped.seek(desired) + + override def setReadahead(readahead: lang.Long): Unit = wrapped.setReadahead(readahead) + + override def read(position: Long, buffer: Array[Byte], offset: Int, length: Int): Int = + wrapped.read(position, buffer, offset, length) + + override def read(buf: ByteBuffer): Int = wrapped.read(buf) + + override def readFully(position: Long, buffer: Array[Byte], offset: Int, length: Int): Unit = + wrapped.readFully(position, buffer, offset, length) + + override def readFully(position: Long, buffer: Array[Byte]): Unit = + wrapped.readFully(position, buffer) + + override def available(): Int = wrapped.available() + + override def mark(readlimit: Int): Unit = wrapped.mark(readlimit) + + override def skip(n: Long): Long = wrapped.skip(n) + + override def markSupported(): Boolean = wrapped.markSupported() + + override def close(): Unit = { + wrapped.close() + openStreams.remove(wrapped) + } + + override def read(): Int = wrapped.read() + + override def reset(): Unit = wrapped.reset() + + override def toString: String = wrapped.toString + + override def equals(obj: scala.Any): Boolean = wrapped.equals(obj) + + override def hashCode(): Int = wrapped.hashCode() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 858bc742e07cf..6aedcb1271ff6 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -17,11 +17,11 @@ package org.apache.spark -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.Suite /** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ -trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => +trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { self: Suite => @transient private var _sc: SparkContext = _ @@ -31,7 +31,8 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => override def beforeAll() { super.beforeAll() - _sc = new SparkContext("local[4]", "test", conf) + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } override def afterAll() { @@ -42,4 +43,14 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => super.afterAll() } } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + DebugFilesystem.assertNoOpenStreams() + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index c7541889f202e..00799301ca8d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -104,6 +104,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(column.getUTF8String(3 * i + 1).toString == i.toString) assert(column.getUTF8String(3 * i + 2).toString == i.toString) } + reader.close() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4259384f0bc61..9c1d26dcb2241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -203,13 +203,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } // Open and delete - fm.open(path) + val f1 = fm.open(path) fm.delete(path) assert(!fm.exists(path)) intercept[IOException] { fm.open(path) } fm.delete(path) // should not throw exception + f1.close() // Rename val path1 = new Path(s"$dir/file1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 79c37faa4e9ba..db24ee8b46dd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.test -import org.apache.spark.SparkConf +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected val sparkConf = new SparkConf() @@ -52,7 +54,8 @@ trait SharedSQLContext extends SQLTestUtils { protected override def beforeAll(): Unit = { SparkSession.sqlListener.set(null) if (_spark == null) { - _spark = new TestSparkSession(sparkConf) + _spark = new TestSparkSession( + sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -71,4 +74,14 @@ trait SharedSQLContext extends SQLTestUtils { super.afterAll() } } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + DebugFilesystem.assertNoOpenStreams() + } } From af6ece33d39cf305bd4a211d08a2f8e910c69bc1 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 1 Oct 2016 00:50:16 -0700 Subject: [PATCH 174/306] [SPARK-17717][SQL] Add Exist/find methods to Catalog [FOLLOW-UP] ## What changes were proposed in this pull request? We added find and exists methods for Databases, Tables and Functions to the user facing Catalog in PR https://github.com/apache/spark/pull/15301. 
However, it was brought up that the semantics of the `find` methods are more in line with a `get` method (get an object or else fail). So we rename these in this PR. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #15308 from hvanhovell/SPARK-17717-2. --- project/MimaExcludes.scala | 10 +-- .../apache/spark/sql/catalog/Catalog.scala | 31 +++---- .../spark/sql/internal/CatalogImpl.scala | 80 ++++++++----------- .../spark/sql/internal/CatalogSuite.scala | 38 ++++----- 4 files changed, 71 insertions(+), 88 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2ffe0ac9bc982..7362041428b1f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -48,14 +48,12 @@ object MimaExcludes { // [SPARK-16240] ML persistence backward compatibility for LDA ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), // [SPARK-17717] Add Find and Exists method to Catalog. - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findDatabase"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findTable"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findFunction"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.findColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getDatabase"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getTable"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.columnExists") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists") ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index b439022d227cc..7f2762c7dac92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -102,50 +102,51 @@ abstract class Catalog { def listColumns(dbName: String, tableName: String): Dataset[Column] /** - * Find the database with the specified name. This throws an AnalysisException when the database + * Get the database with the specified name. This throws an AnalysisException when the database * cannot be found. * * @since 2.1.0 */ @throws[AnalysisException]("database does not exist") - def findDatabase(dbName: String): Database + def getDatabase(dbName: String): Database /** - * Find the table with the specified name. This table can be a temporary table or a table in the - * current database. This throws an AnalysisException when the table cannot be found. + * Get the table or view with the specified name. This table can be a temporary view or a + * table/view in the current database. This throws an AnalysisException when no Table + * can be found.
* * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") - def findTable(tableName: String): Table + def getTable(tableName: String): Table /** - * Find the table with the specified name in the specified database. This throws an - * AnalysisException when the table cannot be found. + * Get the table or view with the specified name in the specified database. This throws an + * AnalysisException when no Table can be found. * * @since 2.1.0 */ @throws[AnalysisException]("database or table does not exist") - def findTable(dbName: String, tableName: String): Table + def getTable(dbName: String, tableName: String): Table /** - * Find the function with the specified name. This function can be a temporary function or a + * Get the function with the specified name. This function can be a temporary function or a * function in the current database. This throws an AnalysisException when the function cannot * be found. * * @since 2.1.0 */ @throws[AnalysisException]("function does not exist") - def findFunction(functionName: String): Function + def getFunction(functionName: String): Function /** - * Find the function with the specified name. This throws an AnalysisException when the function + * Get the function with the specified name. This throws an AnalysisException when the function * cannot be found. * * @since 2.1.0 */ @throws[AnalysisException]("database or function does not exist") - def findFunction(dbName: String, functionName: String): Function + def getFunction(dbName: String, functionName: String): Function /** * Check if the database with the specified name exists. @@ -155,15 +156,15 @@ abstract class Catalog { def databaseExists(dbName: String): Boolean /** - * Check if the table with the specified name exists. This can either be a temporary table or a - * table in the current database. + * Check if the table or view with the specified name exists. This can either be a temporary + * view or a table/view in the current database. * * @since 2.1.0 */ def tableExists(tableName: String): Boolean /** - * Check if the table with the specified name exists in the specified database. + * Check if the table or view with the specified name exists in the specified database. * * @since 2.1.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index a1087edd03fdf..e412e1b4b302a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -68,13 +68,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Returns a list of databases available across all sessions. 
*/ override def listDatabases(): Dataset[Database] = { - val databases = sessionCatalog.listDatabases().map { dbName => - makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) - } + val databases = sessionCatalog.listDatabases().map(makeDatabase) CatalogImpl.makeDataset(databases, sparkSession) } - private def makeDatabase(metadata: CatalogDatabase): Database = { + private def makeDatabase(dbName: String): Database = { + val metadata = sessionCatalog.getDatabaseMetadata(dbName) new Database( name = metadata.name, description = metadata.description, @@ -96,20 +95,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database does not exist") override def listTables(dbName: String): Dataset[Table] = { requireDatabaseExists(dbName) - val tables = sessionCatalog.listTables(dbName).map { tableIdent => - makeTable(tableIdent, tableIdent.database.isEmpty) - } + val tables = sessionCatalog.listTables(dbName).map(makeTable) CatalogImpl.makeDataset(tables, sparkSession) } - private def makeTable(tableIdent: TableIdentifier, isTemp: Boolean): Table = { - val metadata = if (isTemp) None else Some(sessionCatalog.getTableMetadata(tableIdent)) + private def makeTable(tableIdent: TableIdentifier): Table = { + val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val database = metadata.identifier.database new Table( - name = tableIdent.identifier, - database = metadata.flatMap(_.identifier.database).orNull, - description = metadata.flatMap(_.comment).orNull, - tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), - isTemporary = isTemp) + name = tableIdent.table, + database = database.orNull, + description = metadata.comment.orNull, + tableType = if (database.isEmpty) "TEMPORARY" else metadata.tableType.name, + isTemporary = database.isEmpty) } /** @@ -178,59 +176,45 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Find the database with the specified name. This throws an [[AnalysisException]] when no + * Get the database with the specified name. This throws an [[AnalysisException]] when no * [[Database]] can be found. */ - override def findDatabase(dbName: String): Database = { - if (sessionCatalog.databaseExists(dbName)) { - makeDatabase(sessionCatalog.getDatabaseMetadata(dbName)) - } else { - throw new AnalysisException(s"The specified database $dbName does not exist.") - } + override def getDatabase(dbName: String): Database = { + makeDatabase(dbName) } /** - * Find the table with the specified name. This table can be a temporary table or a table in the - * current database. This throws an [[AnalysisException]] when no [[Table]] can be found. + * Get the table or view with the specified name. This table can be a temporary view or a + * table/view in the current database. This throws an [[AnalysisException]] when no [[Table]] + * can be found. */ - override def findTable(tableName: String): Table = { - findTable(null, tableName) + override def getTable(tableName: String): Table = { + getTable(null, tableName) } /** - * Find the table with the specified name in the specified database. This throws an + * Get the table or view with the specified name in the specified database. This throws an * [[AnalysisException]] when no [[Table]] can be found. 
*/ - override def findTable(dbName: String, tableName: String): Table = { - val tableIdent = TableIdentifier(tableName, Option(dbName)) - val isTemporary = sessionCatalog.isTemporaryTable(tableIdent) - if (isTemporary || sessionCatalog.tableExists(tableIdent)) { - makeTable(tableIdent, isTemporary) - } else { - throw new AnalysisException(s"The specified table $tableIdent does not exist.") - } + override def getTable(dbName: String, tableName: String): Table = { + makeTable(TableIdentifier(tableName, Option(dbName))) } /** - * Find the function with the specified name. This function can be a temporary function or a + * Get the function with the specified name. This function can be a temporary function or a * function in the current database. This throws an [[AnalysisException]] when no [[Function]] * can be found. */ - override def findFunction(functionName: String): Function = { - findFunction(null, functionName) + override def getFunction(functionName: String): Function = { + getFunction(null, functionName) } /** - * Find the function with the specified name. This returns [[None]] when no [[Function]] can be + * Get the function with the specified name. This returns [[None]] when no [[Function]] can be * found. */ - override def findFunction(dbName: String, functionName: String): Function = { - val functionIdent = FunctionIdentifier(functionName, Option(dbName)) - if (sessionCatalog.functionExists(functionIdent)) { - makeFunction(functionIdent) - } else { - throw new AnalysisException(s"The specified function $functionIdent does not exist.") - } + override def getFunction(dbName: String, functionName: String): Function = { + makeFunction(FunctionIdentifier(functionName, Option(dbName))) } /** @@ -241,15 +225,15 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the table with the specified name exists. This can either be a temporary table or a - * table in the current database. + * Check if the table or view with the specified name exists. This can either be a temporary + * view or a table/view in the current database. */ override def tableExists(tableName: String): Boolean = { tableExists(null, tableName) } /** - * Check if the table with the specified name exists in the specified database. + * Check if the table or view with the specified name exists in the specified database. */ override def tableExists(dbName: String, tableName: String): Boolean = { val tableIdent = TableIdentifier(tableName, Option(dbName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 783bf77f86b46..214bc736bd4de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -340,61 +340,61 @@ class CatalogSuite } } - test("find database") { - intercept[AnalysisException](spark.catalog.findDatabase("db10")) + test("get database") { + intercept[AnalysisException](spark.catalog.getDatabase("db10")) withTempDatabase { db => - assert(spark.catalog.findDatabase(db).name === db) + assert(spark.catalog.getDatabase(db).name === db) } } - test("find table") { + test("get table") { withTempDatabase { db => withTable(s"tbl_x", s"$db.tbl_y") { // Try to find non existing tables. 
- intercept[AnalysisException](spark.catalog.findTable("tbl_x")) - intercept[AnalysisException](spark.catalog.findTable("tbl_y")) - intercept[AnalysisException](spark.catalog.findTable(db, "tbl_y")) + intercept[AnalysisException](spark.catalog.getTable("tbl_x")) + intercept[AnalysisException](spark.catalog.getTable("tbl_y")) + intercept[AnalysisException](spark.catalog.getTable(db, "tbl_y")) // Create objects. createTempTable("tbl_x") createTable("tbl_y", Some(db)) // Find a temporary table - assert(spark.catalog.findTable("tbl_x").name === "tbl_x") + assert(spark.catalog.getTable("tbl_x").name === "tbl_x") // Find a qualified table - assert(spark.catalog.findTable(db, "tbl_y").name === "tbl_y") + assert(spark.catalog.getTable(db, "tbl_y").name === "tbl_y") // Find an unqualified table using the current database - intercept[AnalysisException](spark.catalog.findTable("tbl_y")) + intercept[AnalysisException](spark.catalog.getTable("tbl_y")) spark.catalog.setCurrentDatabase(db) - assert(spark.catalog.findTable("tbl_y").name === "tbl_y") + assert(spark.catalog.getTable("tbl_y").name === "tbl_y") } } } - test("find function") { + test("get function") { withTempDatabase { db => withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { // Try to find non existing functions. - intercept[AnalysisException](spark.catalog.findFunction("fn1")) - intercept[AnalysisException](spark.catalog.findFunction("fn2")) - intercept[AnalysisException](spark.catalog.findFunction(db, "fn2")) + intercept[AnalysisException](spark.catalog.getFunction("fn1")) + intercept[AnalysisException](spark.catalog.getFunction("fn2")) + intercept[AnalysisException](spark.catalog.getFunction(db, "fn2")) // Create objects. createTempFunction("fn1") createFunction("fn2", Some(db)) // Find a temporary function - assert(spark.catalog.findFunction("fn1").name === "fn1") + assert(spark.catalog.getFunction("fn1").name === "fn1") // Find a qualified function - assert(spark.catalog.findFunction(db, "fn2").name === "fn2") + assert(spark.catalog.getFunction(db, "fn2").name === "fn2") // Find an unqualified function using the current database - intercept[AnalysisException](spark.catalog.findFunction("fn2")) + intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) - assert(spark.catalog.findFunction("fn2").name === "fn2") + assert(spark.catalog.getFunction("fn2").name === "fn2") } } } From b88cb63da39786c07cb4bfa70afed32ec5eb3286 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 1 Oct 2016 16:10:39 -0400 Subject: [PATCH 175/306] [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Partial revert of #15277 to instead sort and store input to model rather than require sorted input ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15299 from srowen/SPARK-17704.2. 
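The gist of the change, as a rough standalone sketch (plain Scala over dense arrays, not the actual MLlib `SparseVector` code; class and method names are illustrative): the model sorts the selected indices once at construction time and filters against that sorted copy, so callers no longer have to pass the indices in ascending order.

```
// Minimal sketch of the idea, assuming dense features as Array[Double];
// names are illustrative, not the real MLlib API.
class SimpleSelectorModel(val selectedFeatures: Array[Int]) {
  // Sorted copy built once and reused for every transform call.
  private val filterIndices: Array[Int] = selectedFeatures.sorted

  // Keep only the selected entries, in ascending index order.
  def compress(features: Array[Double]): Array[Double] =
    filterIndices.map(i => features(i))
}

object SimpleSelectorModel {
  def main(args: Array[String]): Unit = {
    // Indices may now be supplied in any order.
    val model = new SimpleSelectorModel(Array(3, 0, 2))
    val out = model.compress(Array(10.0, 11.0, 12.0, 13.0))
    println(out.mkString(", ")) // 10.0, 12.0, 13.0
  }
}
```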
--- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 22 +++++++++---------- python/pyspark/ml/feature.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 9c131a41850cc..d0385e220e1e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ - /** list of indices to select (filter). Must be ordered asc */ + /** list of indices to select (filter). */ @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 706ce78f260a6..c305b36278e87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") + private val filterIndices = selectedFeatures.sorted + @deprecated("not intended for subclasses to use", "2.1.0") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( */ @Since("1.3.0") override def transform(vector: Vector): Vector = { - compress(vector, selectedFeatures) + compress(vector) } /** @@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. 
* Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc */ - private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + private def compress(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => val newSize = filterIndices.length @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val chiSqTestResult = Statistics.chiSqTest(data) + val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { case ChiSqSelector.KBest => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take(numTopFeatures) case ChiSqSelector.Percentile => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => - chiSqTestResult.zipWithIndex - .filter{ case (res, _) => res.pValue < alpha } + chiSqTestResult + .filter { case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices }.sorted + val indices = features.map { case (_, index) => index } new ChiSqSelectorModel(indices) } } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 12a13849dc9bc..64b21caa616ec 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): @since("2.0.0") def selectedFeatures(self): """ - List of indices to select (filter). Must be ordered asc. + List of indices to select (filter). """ return self._call_java("selectedFeatures") From f8d7fade4b9a78ae87b6012e3d6f71eef3032b22 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Sun, 2 Oct 2016 15:47:36 -0700 Subject: [PATCH 176/306] =?UTF-8?q?[SPARK-17509][SQL]=20When=20wrapping=20?= =?UTF-8?q?catalyst=20datatype=20to=20Hive=20data=20type=20avoid=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When wrapping catalyst datatypes to Hive data type, wrap function was doing an expensive pattern matching which was consuming around 11% of cpu time. Avoid the pattern matching by returning the wrapper only once and reuse it. ## How was this patch tested? Tested by running the job on cluster and saw around 8% cpu improvements. Author: Sital Kedia Closes #15064 from sitalkedia/skedia/hive_wrapper. --- .../spark/sql/hive/HiveInspectors.scala | 307 ++++++++---------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 15 +- 2 files changed, 145 insertions(+), 177 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index e4b963efeaf18..c3c4351cf58a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -238,102 +238,161 @@ private[hive] trait HiveInspectors { case c => throw new AnalysisException(s"Unsupported java type $c") } + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + /** * Wraps with Hive types based on object inspector. 
- * TODO: Consolidate all hive OI/data interface code. */ protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { - case _: JavaHiveVarcharObjectInspector => - (o: Any) => - if (o != null) { - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.length) - } else { - null - } - - case _: JavaHiveCharObjectInspector => - (o: Any) => - if (o != null) { - val s = o.asInstanceOf[UTF8String].toString - new HiveChar(s, s.length) - } else { - null - } - - case _: JavaHiveDecimalObjectInspector => - (o: Any) => - if (o != null) { - HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) - } else { - null - } - - case _: JavaDateObjectInspector => - (o: Any) => - if (o != null) { - DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) - } else { - null - } - - case _: JavaTimestampObjectInspector => + case x: ConstantObjectInspector => (o: Any) => - if (o != null) { - DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) - } else { - null + x.getWritableConstantValue + case x: PrimitiveObjectInspector => x match { + // TODO we don't support the HiveVarcharObjectInspector yet. + case _: StringObjectInspector if x.preferWritable() => + withNullSafe(o => getStringWritable(o)) + case _: StringObjectInspector => + withNullSafe(o => o.asInstanceOf[UTF8String].toString()) + case _: IntObjectInspector if x.preferWritable() => + withNullSafe(o => getIntWritable(o)) + case _: IntObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Integer]) + case _: BooleanObjectInspector if x.preferWritable() => + withNullSafe(o => getBooleanWritable(o)) + case _: BooleanObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Boolean]) + case _: FloatObjectInspector if x.preferWritable() => + withNullSafe(o => getFloatWritable(o)) + case _: FloatObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Float]) + case _: DoubleObjectInspector if x.preferWritable() => + withNullSafe(o => getDoubleWritable(o)) + case _: DoubleObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Double]) + case _: LongObjectInspector if x.preferWritable() => + withNullSafe(o => getLongWritable(o)) + case _: LongObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Long]) + case _: ShortObjectInspector if x.preferWritable() => + withNullSafe(o => getShortWritable(o)) + case _: ShortObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Short]) + case _: ByteObjectInspector if x.preferWritable() => + withNullSafe(o => getByteWritable(o)) + case _: ByteObjectInspector => + withNullSafe(o => o.asInstanceOf[java.lang.Byte]) + case _: JavaHiveVarcharObjectInspector => + withNullSafe { o => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.length) } + case _: JavaHiveCharObjectInspector => + withNullSafe { o => + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.length) + } + case _: JavaHiveDecimalObjectInspector => + withNullSafe(o => + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: JavaDateObjectInspector => + withNullSafe(o => + DateTimeUtils.toJavaDate(o.asInstanceOf[Int])) + case _: JavaTimestampObjectInspector => + withNullSafe(o => + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + case _: HiveDecimalObjectInspector if x.preferWritable() => + withNullSafe(o => getDecimalWritable(o.asInstanceOf[Decimal])) + case _: HiveDecimalObjectInspector => + withNullSafe(o => + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: BinaryObjectInspector if x.preferWritable() => + 
withNullSafe(o => getBinaryWritable(o)) + case _: BinaryObjectInspector => + withNullSafe(o => o.asInstanceOf[Array[Byte]]) + case _: DateObjectInspector if x.preferWritable() => + withNullSafe(o => getDateWritable(o)) + case _: DateObjectInspector => + withNullSafe(o => DateTimeUtils.toJavaDate(o.asInstanceOf[Int])) + case _: TimestampObjectInspector if x.preferWritable() => + withNullSafe(o => getTimestampWritable(o)) + case _: TimestampObjectInspector => + withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + } case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) } - (o: Any) => { - if (o != null) { - val struct = soi.create() - val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { - case ((field, wrapper), i) => - soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) - } - struct - } else { - null + withNullSafe { o => + val struct = soi.create() + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) + } + struct + } + + case ssoi: SettableStructObjectInspector => + val structType = dataType.asInstanceOf[StructType] + val wrappers = ssoi.getAllStructFieldRefs.asScala.zip(structType).map { + case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType) + } + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + // 1. create the pojo (most likely) object + val result = ssoi.create() + ssoi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + val tpe = structType(i).dataType + ssoi.setStructFieldData( + result, + field, + wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) } + result + } + + case soi: StructObjectInspector => + val structType = dataType.asInstanceOf[StructType] + val wrappers = soi.getAllStructFieldRefs.asScala.zip(structType).map { + case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType) + } + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](wrappers.size) + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + val tpe = structType(i).dataType + result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) + } + result } case loi: ListObjectInspector => val elementType = dataType.asInstanceOf[ArrayType].elementType val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType) - (o: Any) => { - if (o != null) { - val array = o.asInstanceOf[ArrayData] - val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => values.add(wrapper(e))) - values - } else { - null - } + withNullSafe { o => + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) + values } case moi: MapObjectInspector => val mt = dataType.asInstanceOf[MapType] val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType) val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType) - - (o: Any) => { - if (o != null) { + withNullSafe { o => val map = o.asInstanceOf[MapData] val jmap = new 
java.util.HashMap[Any, Any](map.numElements()) map.foreach(mt.keyType, mt.valueType, (k, v) => jmap.put(keyWrapper(k), valueWrapper(v))) jmap - } else { - null } - } case _ => identity[Any] @@ -648,119 +707,19 @@ private[hive] trait HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) } - /** - * Converts native catalyst types to the types expected by Hive - * @param a the value to be wrapped - * @param oi This ObjectInspector associated with the value returned by this function, and - * the ObjectInspector should also be consistent with those returned from - * toInspector: DataType => ObjectInspector and - * toInspector: Expression => ObjectInspector - * - * Strictly follows the following order in wrapping (constant OI has the higher priority): - * Constant object inspector => return the bundled value of Constant object inspector - * Check whether the `a` is null => return null if true - * If object inspector prefers writable object => return a Writable for the given data `a` - * Map the catalyst data to the boxed java primitive - * - * NOTICE: the complex data type requires recursive wrapping. - */ - def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { - case x: ConstantObjectInspector => x.getWritableConstantValue - case _ if a == null => null - case x: PrimitiveObjectInspector => x match { - // TODO we don't support the HiveVarcharObjectInspector yet. - case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) - case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) - case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) - case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) - case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) - case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) - case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) - case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] - case _: HiveDecimalObjectInspector if x.preferWritable() => - getDecimalWritable(a.asInstanceOf[Decimal]) - case _: HiveDecimalObjectInspector => - HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) - case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) - case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) - } - case x: SettableStructObjectInspector => - val fieldRefs = x.getAllStructFieldRefs - val structType = dataType.asInstanceOf[StructType] - val row = a.asInstanceOf[InternalRow] - // 1. 
create the pojo (most likely) object - val result = x.create() - var i = 0 - val size = fieldRefs.size - while (i < size) { - // 2. set the property for the pojo - val tpe = structType(i).dataType - x.setStructFieldData( - result, - fieldRefs.get(i), - wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) - i += 1 - } - - result - case x: StructObjectInspector => - val fieldRefs = x.getAllStructFieldRefs - val structType = dataType.asInstanceOf[StructType] - val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.size) - var i = 0 - val size = fieldRefs.size - while (i < size) { - val tpe = structType(i).dataType - result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) - i += 1 - } - - result - case x: ListObjectInspector => - val list = new java.util.ArrayList[Object] - val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => - list.add(wrap(e, x.getListElementObjectInspector, tpe)) - ) - list - case x: MapObjectInspector => - val keyType = dataType.asInstanceOf[MapType].keyType - val valueType = dataType.asInstanceOf[MapType].valueType - val map = a.asInstanceOf[MapData] - - // Some UDFs seem to assume we pass in a HashMap. - val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - - map.foreach(keyType, valueType, (k, v) => - hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), - wrap(v, x.getMapValueObjectInspector, valueType)) - ) - - hashMap + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = { + wrapperFor(oi, dataType)(a).asInstanceOf[AnyRef] } def wrap( row: InternalRow, - inspectors: Seq[ObjectInspector], + wrappers: Array[(Any) => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - val length = inspectors.length + val length = wrappers.length while (i < length) { - cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) + cache(i) = wrappers(i)(row.get(i, dataTypes(i))).asInstanceOf[AnyRef] i += 1 } cache @@ -768,13 +727,13 @@ private[hive] trait HiveInspectors { def wrap( row: Seq[Any], - inspectors: Seq[ObjectInspector], + wrappers: Array[(Any) => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - val length = inspectors.length + val length = wrappers.length while (i < length) { - cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) + cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef] i += 1 } cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 962dd5a52ebc0..d54913518bb33 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -70,6 +70,9 @@ private[hive] case class HiveSimpleUDF( override lazy val dataType = javaClassToDataType(method.getReturnType) + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) @@ -82,7 +85,7 @@ private[hive] case class HiveSimpleUDF( // TODO: Finish input output types. 
override def eval(input: InternalRow): Any = { - val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes) + val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes) val ret = FunctionRegistry.invoke( method, function, @@ -214,6 +217,9 @@ private[hive] case class HiveGenericUDTF( @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient private lazy val unwrapper = unwrapperFor(outputInspector) @@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF( val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) + function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes)) collector.collectRows() } @@ -296,6 +302,9 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val function = functionAndInspector._1 + @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + @transient private lazy val returnInspector = functionAndInspector._2 @@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction( override def update(_buffer: MutableRow, input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) + function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { From 76dc2d9073e5e5c45c8b806a474beacb8415d506 Mon Sep 17 00:00:00 2001 From: Tao LI Date: Sun, 2 Oct 2016 16:01:02 -0700 Subject: [PATCH 177/306] [SPARK-14914][CORE][SQL] Skip/fix some test cases on Windows due to limitation of Windows ## What changes were proposed in this pull request? This PR proposes to fix/skip some tests failed on Windows. This PR takes over https://github.com/apache/spark/pull/12696. **Before** - **SparkSubmitSuite** ``` [info] - launch simple application with spark-submit *** FAILED *** (202 milliseconds) [info] java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "C:\projects\spark"): CreateProcess error=2, The system cannot find the file specifie [info] - includes jars passed in through --jars *** FAILED *** (1 second, 625 milliseconds) [info] java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "C:\projects\spark"): CreateProcess error=2, The system cannot find the file specified ``` - **DiskStoreSuite** ``` [info] - reads of memory-mapped and non memory-mapped files are equivalent *** FAILED *** (1 second, 78 milliseconds) [info] diskStoreMapped.remove(blockId) was false (DiskStoreSuite.scala:41) ``` **After** - **SparkSubmitSuite** ``` [info] - launch simple application with spark-submit (578 milliseconds) [info] - includes jars passed in through --jars (1 second, 875 milliseconds) ``` - **DiskStoreSuite** ``` [info] DiskStoreSuite: [info] - reads of memory-mapped and non memory-mapped files are equivalent !!! CANCELED !!! (766 milliseconds ``` For `CreateTableAsSelectSuite` and `FsHistoryProviderSuite`, I could not reproduce as the Java version seems higher than the one that has the bugs about `setReadable(..)` and `setWritable(...)` but as they are bugs reported clearly, it'd be sensible to skip those. We should revert the changes for both back as soon as we drop the support of Java 7. ## How was this patch tested? 
Manually tested via AppVeyor. Closes #12696 Author: Tao LI Author: U-FAREAST\tl Author: hyukjinkwon Closes #15320 from HyukjinKwon/SPARK-14914. --- .../src/main/scala/org/apache/spark/util/Utils.scala | 12 ++---------- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 7 ++++++- .../deploy/history/FsHistoryProviderSuite.scala | 2 ++ .../org/apache/spark/storage/DiskStoreSuite.scala | 4 ++++ .../spark/sql/sources/CreateTableAsSelectSuite.scala | 3 ++- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f3493bd96b1ee..ef832756ce3b7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -23,7 +23,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -1014,15 +1014,7 @@ private[spark] object Utils extends Logging { * Check to see if file is a symbolic link. */ def isSymlink(file: File): Boolean = { - if (file == null) throw new NullPointerException("File must not be null") - if (isWindows) return false - val fileInCanonicalDir = if (file.getParent() == null) { - file - } else { - new File(file.getParentFile().getCanonicalFile(), file.getName()) - } - - !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()) + return Files.isSymbolicLink(Paths.get(file.toURI)) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 31c8fb26460df..732cbfaaeea46 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -649,8 +649,13 @@ class SparkSubmitSuite // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File("..\\bin\\spark-submit.cmd") + } else { + new File("../bin/spark-submit") + } val process = Utils.executeCommand( - Seq("./bin/spark-submit") ++ args, + Seq(sparkSubmitFile.getCanonicalPath) ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 01bef0a11c124..a5eda7b5a5a75 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -126,6 +126,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("SPARK-3697: ignore directories that cannot be read.") { + // setReadable(...) does not work on Windows. Please refer JDK-6728842. 
+ assume(!Utils.isWindows) val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9ed5016510d56..9e6b02b9eac4d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -22,10 +22,14 @@ import java.util.Arrays import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.io.ChunkedByteBuffer +import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + // It will cause error when we tried to re-open the filestore and the + // memory-mapped byte buffer tot he file has not been GC on Windows. + assume(!Utils.isWindows) val confKey = "spark.storage.memoryMapThreshold" // Create a non-trivial (not all zeros) byte array diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 344d4aa6cfea4..c39005f6a1063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -83,6 +82,8 @@ class CreateTableAsSelectSuite } test("CREATE TABLE USING AS SELECT based on the file without write permission") { + // setWritable(...) does not work on Windows. Please refer JDK-6728842. + assume(!Utils.isWindows) val childPath = new File(path.toString, "child") path.mkdir() path.setWritable(false) From de3f71ed7a301387e870a38c14dad9508efc9743 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 3 Oct 2016 10:24:30 +0100 Subject: [PATCH 178/306] [SPARK-17598][SQL][WEB UI] User-friendly name for Spark Thrift Server in web UI ## What changes were proposed in this pull request? The name of Spark Thrift JDBC/ODBC Server in web UI reflects the name of the class, i.e. org.apache.spark.sql.hive.thrift.HiveThriftServer2. I changed it to Thrift JDBC/ODBC Server (like Spark shell for spark-shell) as recommended by jaceklaskowski. Note the user can still change the name adding `--name "App Name"` parameter to the start script as before ## How was this patch tested? By running the script with various parameters and checking the web ui ![screen shot 2016-09-27 at 12 19 12 pm](https://cloud.githubusercontent.com/assets/13952758/18888329/aebca47c-84ac-11e6-93d0-6e98684977c5.png) Author: Alex Bozarth Closes #15268 from ajbozarth/spark17598. 
--- sbin/start-thriftserver.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index ad7e7c5277eb1..f02f31793e346 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -53,4 +53,4 @@ fi export SUBMIT_USAGE_FUNCTION=usage -exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 "$@" +exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 --name "Thrift JDBC/ODBC Server" "$@" From a27033c0bbaae8f31db9b91693947ed71738ed11 Mon Sep 17 00:00:00 2001 From: Jagadeesan Date: Mon, 3 Oct 2016 10:46:38 +0100 Subject: [PATCH 179/306] =?UTF-8?q?[SPARK-17736][DOCUMENTATION][SPARKR]=20?= =?UTF-8?q?Update=20R=20README=20for=20rmarkdown,=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? To build R docs (which are built when R tests are run), users need to install pandoc and rmarkdown. This was done for Jenkins in ~~[SPARK-17420](https://issues.apache.org/jira/browse/SPARK-17420)~~ … pandoc] Author: Jagadeesan Closes #15309 from jagadeesanas2/SPARK-17736. --- docs/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 8b515e187379c..ffd3b5712b618 100644 --- a/docs/README.md +++ b/docs/README.md @@ -19,8 +19,8 @@ installed. Also install the following libraries: $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs - $ sudo pip install sphinx - $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat"), repos="http://cran.stat.ucla.edu/")' + $ sudo pip install sphinx pypandoc + $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' ``` (Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) From 7bf92127643570e4eb3610fa3ffd36839eba2718 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 3 Oct 2016 10:12:02 -0700 Subject: [PATCH 180/306] [SPARK-17073][SQL] generate column-level statistics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Generate basic column statistics for all the atomic types: - numeric types: max, min, num of nulls, ndv (number of distinct values) - date/timestamp types: they are also represented as numbers internally, so they have the same stats as above. - string: avg length, max length, num of nulls, ndv - binary: avg length, max length, num of nulls - boolean: num of nulls, num of trues, num of falsies Also support storing and loading these statistics. One thing to notice: We support analyzing columns independently, e.g.: sql1: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key;` sql2: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS value;` when running sql2 to collect column stats for `value`, we don’t remove stats of columns `key` which are analyzed in sql1 and not in sql2. As a result, **users need to guarantee consistency** between sql1 and sql2. If the table has been changed before sql2, users should re-analyze column `key` when they want to analyze column `value`: `ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value;` ## How was this patch tested? add unit tests Author: Zhenhua Wang Closes #15090 from wzhfy/colStats. 
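A minimal usage sketch of the new command from Scala (hypothetical object and app name; it assumes a table `src` with columns `key` and `value` already exists, matching the SQL examples above):

```
import org.apache.spark.sql.SparkSession

object AnalyzeColumnsExample {
  def main(args: Array[String]): Unit = {
    // Local master is only for trying this out; assumes `src` already exists
    // in the current database with columns `key` and `value`.
    val spark = SparkSession.builder()
      .appName("AnalyzeColumnsExample")
      .master("local[*]")
      .getOrCreate()

    // Columns can be analyzed independently; stats collected for `key` are kept
    // when `value` is analyzed later.
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key")
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS value")

    // If the table changed in between, re-analyze the columns together so both
    // sets of statistics describe the same data.
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value")

    spark.stop()
  }
}
```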
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../catalyst/plans/logical/Statistics.scala | 69 +++- .../spark/sql/execution/SparkSqlParser.scala | 18 +- .../command/AnalyzeColumnCommand.scala | 175 +++++++++ .../command/AnalyzeTableCommand.scala | 112 +++--- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../spark/sql/internal/SessionState.scala | 8 +- .../spark/sql/StatisticsColumnSuite.scala | 334 ++++++++++++++++++ .../apache/spark/sql/StatisticsSuite.scala | 16 +- .../org/apache/spark/sql/StatisticsTest.scala | 129 +++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 28 +- .../spark/sql/hive/StatisticsSuite.scala | 119 +++++-- .../sql/hive/execution/SQLViewSuite.scala | 1 + 13 files changed, 906 insertions(+), 114 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index de2f9ee6bc7a2..1284681fe80b4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -86,7 +86,7 @@ statement | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq?)? #analyze + (identifier | FOR COLUMNS identifierSeq)? #analyze | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 3cf20385dd712..43455c989c0f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,6 +17,12 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.commons.codec.binary.Base64 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -32,12 +38,15 @@ package org.apache.spark.sql.catalyst.plans.logical * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. + * @param colStats Column-level statistics. * @param isBroadcastable If true, output is small enough to be used in a broadcast join. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, + colStats: Map[String, ColumnStat] = Map.empty, isBroadcastable: Boolean = false) { + override def toString: String = "Statistics(" + simpleString + ")" /** Readable string representation for the Statistics. 
*/ @@ -45,6 +54,64 @@ case class Statistics( Seq(s"sizeInBytes=$sizeInBytes", if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", s"isBroadcastable=$isBroadcastable" - ).filter(_.nonEmpty).mkString("", ", ", "") + ).filter(_.nonEmpty).mkString(", ") + } +} + +/** + * Statistics for a column. + */ +case class ColumnStat(statRow: InternalRow) { + + def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { + NumericColumnStat(statRow, dataType) + } + def forString: StringColumnStat = StringColumnStat(statRow) + def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) + def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) + + override def toString: String = { + // use Base64 for encoding + Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) } } + +object ColumnStat { + def apply(numFields: Int, str: String): ColumnStat = { + // use Base64 for decoding + val bytes = Base64.decodeBase64(str) + val unsafeRow = new UnsafeRow(numFields) + unsafeRow.pointTo(bytes, bytes.length) + ColumnStat(unsafeRow) + } +} + +case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { + // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. + val numNulls: Long = statRow.getLong(0) + val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] + val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] + val ndv: Long = statRow.getLong(3) +} + +case class StringColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. + val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) + val ndv: Long = statRow.getLong(3) +} + +case class BinaryColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. + val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) +} + +case class BooleanColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. + val numNulls: Long = statRow.getLong(0) + val numTrues: Long = statRow.getLong(1) + val numFalses: Long = statRow.getLong(2) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3f34d0f25393d..7f1e23e665eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -87,19 +87,27 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN - * option (other options are passed on to Hive) e.g.: + * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. 
+ * Example SQL for analyzing table : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN]; + * }}} + * Example SQL for analyzing columns : + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2; * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec == null && ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + } else if (ctx.identifierSeq() == null) { + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) + AnalyzeColumnCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala new file mode 100644 index 0000000000000..7066378279971 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.types._ + + +/** + * Analyzes the given columns of the given table to generate statistics, which will be used in + * query optimizations. 
+ */ +case class AnalyzeColumnCommand( + tableIdent: TableIdentifier, + columnNames: Seq[String]) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) + + relation match { + case catalogRel: CatalogRelation => + updateStats(catalogRel.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) + + case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => + updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + + case otherRelation => + throw new AnalysisException("ANALYZE TABLE is not supported for " + + s"${otherRelation.nodeName}.") + } + + def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { + val (rowCount, columnStats) = computeColStats(sparkSession, relation) + val statistics = Statistics( + sizeInBytes = newTotalSize, + rowCount = Some(rowCount), + colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) + } + + Seq.empty[Row] + } + + def computeColStats( + sparkSession: SparkSession, + relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + + // check correctness of column names + val attributesToAnalyze = mutable.MutableList[Attribute]() + val duplicatedColumns = mutable.MutableList[String]() + val resolver = sparkSession.sessionState.conf.resolver + columnNames.foreach { col => + val exprOption = relation.output.find(attr => resolver(attr.name, col)) + val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + // do deduplication + if (!attributesToAnalyze.contains(expr)) { + attributesToAnalyze += expr + } else { + duplicatedColumns += col + } + } + if (duplicatedColumns.nonEmpty) { + logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + + s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + } + + // Collect statistics per column. + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the ColumnStats. 
+ val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError + val expressions = Count(Literal(1)).toAggregateExpression() +: + attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .queryExecution.toRdd.collect().head + + // unwrap the result + val rowCount = statsRow.getLong(0) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => + val numFields = ColumnStatStruct.numStatFields(expr.dataType) + (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + }.toMap + (rowCount, columnStats) + } +} + +object ColumnStatStruct { + val zero = Literal(0, LongType) + val one = Literal(1, LongType) + + def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + def max(e: Expression): Expression = Max(e) + def min(e: Expression): Expression = Min(e) + def ndv(e: Expression, relativeSD: Double): Expression = { + // the approximate ndv should never be larger than the number of rows + Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) + } + def avgLength(e: Expression): Expression = Average(Length(e)) + def maxLength(e: Expression): Expression = Max(Length(e)) + def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + + def getStruct(exprs: Seq[Expression]): CreateStruct = { + CreateStruct(exprs.map { expr: Expression => + expr.transformUp { + case af: AggregateFunction => af.toAggregateExpression() + } + }) + } + + def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) + } + + def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) + } + + def binaryColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e)) + } + + def booleanColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), numTrues(e), numFalses(e)) + } + + def numStatFields(dataType: DataType): Int = { + dataType match { + case BinaryType | BooleanType => 3 + case _ => 4 + } + } + + def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + // Use aggregate functions to compute statistics we need. 
+ case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) + case StringType => getStruct(stringColumnStat(e, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(e)) + case BooleanType => getStruct(booleanColumnStat(e)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${e.name} of data type: ${e.dataType}.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 40aecafecf5bb..7b0e49b665f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -21,81 +21,40 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SessionState /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. + * Analyzes the given table to generate statistics, which will be used in query optimizations. */ -case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extends RunnableCommand { +case class AnalyzeTableCommand( + tableIdent: TableIdentifier, + noscan: Boolean = true) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case relation: CatalogRelation => - val catalogTable: CatalogTable = relation.catalogTable - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. 
- val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateTableSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - val newTotalSize = - catalogTable.storage.locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - calculateTableSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - - updateTableStats(catalogTable, newTotalSize) + updateTableStats(relation.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) case otherRelation => - throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + throw new AnalysisException("ANALYZE TABLE is not supported for " + s"${otherRelation.nodeName}.") } @@ -125,10 +84,57 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) } } Seq.empty[Row] } } + +object AnalyzeTableCommand extends Logging { + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. 
+ val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + catalogTable.storage.locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + calculateTableSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${catalogTable.identifier.table} in the " + + s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e67140fefef9a..fecdf792fd14a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -581,6 +581,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val NDV_MAX_ERROR = + SQLConfigBuilder("spark.sql.statistics.ndv.maxError") + .internal() + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doubleConf + .createWithDefault(0.05) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -757,6 +764,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c899773b6b36f..9f7d0019c6b92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -188,11 +189,8 @@ private[sql] class SessionState(sparkSession: SparkSession) { /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. - * - * Right now, it only supports catalog tables and it only updates the size of a catalog table - * in the external catalog. 
*/ - def analyze(tableName: String, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableName, noscan).run(sparkSession) + def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { + AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala new file mode 100644 index 0000000000000..0ee0547c45591 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + +class StatisticsColumnSuite extends StatisticsTest { + import testImplicits._ + + test("parse analyze column commands") { + val tableName = "tbl" + + // we need to specify column names + intercept[ParseException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") + } + + val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) + val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) + comparePlans(parsed, expected) + } + + test("analyzing columns of non-atomic types is not supported") { + val tableName = "tbl" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + val err = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err.message.contains("Analyzing columns is not supported")) + } + } + + test("check correctness of columns") { + val table = "tbl" + val colName1 = "abc" + val colName2 = "x.yz" + withTable(table) { + sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") + + val invalidColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") + } + assert(invalidColError.message == "Invalid column name: key.") + + withSQLConf("spark.sql.caseSensitive" -> "true") { + val invalidErr = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") + } + assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") + } + + 
withSQLConf("spark.sql.caseSensitive" -> "false") { + val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + assert(columnStats.contains(colName1)) + assert(columnStats.contains(colName2)) + // check deduplication + assert(columnStats.size == 2) + assert(!columnStats.contains(colName2.toUpperCase)) + } + } + } + + private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { + values.filter(_.isDefined).map(_.get) + } + + test("column-level statistics for integral type columns") { + val values = (0 to 5).map { i => + if (i % 2 == 0) None else Some(i) + } + val data = values.map { i => + (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) + } + + val df = data.toDF("c1", "c2", "c3", "c4") + val nonNullValues = getNonNullValues[Int](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.max, + nonNullValues.min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for fractional type columns") { + val values: Seq[Option[Decimal]] = (0 to 5).map { i => + if (i == 0) None else Some(Decimal(i + i * 0.01)) + } + val data = values.map { i => + (i.map(_.toFloat), i.map(_.toDouble), i) + } + + val df = data.toDF("c1", "c2", "c3") + val nonNullValues = getNonNullValues[Decimal](values) + val numNulls = values.count(_.isEmpty).toLong + val ndv = nonNullValues.distinct.length.toLong + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case floatType: FloatType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, + ndv)) + case doubleType: DoubleType => + ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, + ndv)) + case decimalType: DecimalType => + ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for string column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[String](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for binary column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Array[Byte]](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for boolean column") { + val values = Seq(None, Some(true), Some(false), Some(true)) + val df = values.toDF("c1") + val nonNullValues = 
getNonNullValues[Boolean](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.count(_.equals(true)).toLong, + nonNullValues.count(_.equals(false)).toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for date column") { + val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Date](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, DateType is represented as the number of days from 1970-01-01. + nonNullValues.map(DateTimeUtils.fromJavaDate).max, + nonNullValues.map(DateTimeUtils.fromJavaDate).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for timestamp column") { + val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => + i.map(Timestamp.valueOf) + } + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Timestamp](values) + val expectedColStatsSeq = df.schema.map { f => + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, TimestampType is represented as the number of days from 1970-01-01 + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for null columns") { + val values = Seq(None, None) + val data = values.map { i => + (i.map(_.toString), i.map(_.toString.toInt)) + } + val df = data.toDF("c1", "c2") + val expectedColStatsSeq = df.schema.map { f => + (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) + } + checkColStats(df, expectedColStatsSeq) + } + + test("column-level statistics for columns with different types") { + val intSeq = Seq(1, 2) + val doubleSeq = Seq(1.01d, 2.02d) + val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) + val booleanSeq = Seq(true, false) + val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) + val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) + val longSeq = Seq(5L, 4L) + + val data = intSeq.indices.map { i => + (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), + timestampSeq(i), longSeq(i)) + } + val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case DoubleType => + ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, + doubleSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + case DateType => + ColumnStat(InternalRow(0L, 
dateSeq.map(DateTimeUtils.fromJavaDate).max, + dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) + case TimestampType => + ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, + timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, + timestampSeq.distinct.length.toLong)) + case LongType => + ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) + } + (f, colStat) + } + checkColStats(df, expectedColStatsSeq) + } + + test("update table-level stats while collecting column-level stats") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int) USING PARQUET") + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + checkTableStats(tableName = table, expectedRowCount = Some(1)) + + // update table-level stats between analyze table and analyze column commands + sql(s"INSERT INTO $table SELECT 1") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) + + val colStat = fetchedStats.get.colStats("c1") + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = colStat, + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + } + + test("analyze column stats independently") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + assert(fetchedStats1.get.colStats.size == 1) + val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) + val rsd = spark.sessionState.conf.ndvMaxError + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats1.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) + // column c1 is kept in the stats + assert(fetchedStats2.get.colStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats2.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) + val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) + StatisticsTest.checkColStat( + dataType = LongType, + colStat = fetchedStats2.get.colStats("c2"), + expectedColStat = expected2, + rsd = rsd) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 264a2ffbebebd..8cf42e9248c2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with SharedSQLContext { +class StatisticsSuite extends StatisticsTest { import testImplicits._ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { @@ -77,20 +75,10 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { } test("test table-level statistics for data source table created in InMemoryCatalog") { - def 
checkTableStats(tableName: String, expectedRowCount: Option[BigInt]): Unit = { - val df = sql(s"SELECT * FROM $tableName") - val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.isDefined) - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel - } - assert(relations.size === 1) - } - val tableName = "tbl" withTable(tableName) { sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala new file mode 100644 index 0000000000000..5134ac0e7e5b3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +trait StatisticsTest extends QueryTest with SharedSQLContext { + + def checkColStats( + df: DataFrame, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val table = "tbl" + withTable(table) { + df.write.format("json").saveAsTable(table) + val columns = expectedColStatsSeq.map(_._1) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = colStat, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + + // check if we get the same colStat after encoding and decoding + val encodedCS = colStat.toString + val numFields = ColumnStatStruct.numStatFields(field.dataType) + val decodedCS = ColumnStat(numFields, encodedCS) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = decodedCS, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } + } + + def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } +} + +object StatisticsTest { + def checkColStat( + dataType: DataType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + dataType match { + case StringType => + val cs = colStat.forString + val expectedCS = expectedColStat.forString + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + case BinaryType => + val cs = colStat.forBinary + val expectedCS = expectedColStat.forBinary + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + case BooleanType => + val cs = colStat.forBoolean + val expectedCS = expectedColStat.forBoolean + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.numTrues == expectedCS.numTrues) + assert(cs.numFalses == expectedCS.numFalses) + case atomicType: AtomicType => + checkNumericColStats( + dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) + } + } + + private def checkNumericColStats( + dataType: AtomicType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { + val cs = colStat.forNumeric(dataType) + val expectedCS = expectedColStat.forNumeric(dataType) + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.max == expectedCS.max) + assert(cs.min == expectedCS.min) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) + } + + 
private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { + // ndv is an approximate value, so we make sure we have the value, and it should be + // within 3*SD's of the given rsd. + if (expectedNdv == 0) { + assert(ndv == 0) + } else if (expectedNdv > 0) { + assert(ndv > 0) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index d35a681b67e38..261cc6feff090 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} +import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe @@ -401,7 +401,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var statsProperties: Map[String, String] = Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) if (stats.rowCount.isDefined) { - statsProperties += (STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()) + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + stats.colStats.foreach { case (colName, colStat) => + statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -473,15 +476,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } // construct Spark's statistics from information in Hive metastore - if (catalogTable.properties.contains(STATISTICS_TOTAL_SIZE)) { - val totalSize = BigInt(catalogTable.properties.get(STATISTICS_TOTAL_SIZE).get) - // TODO: we will compute "estimatedSize" when we have column stats: - // average size of row * number of rows + val statsProps = catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + if (statsProps.nonEmpty) { + val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) + .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } + val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { + case f if colStatsProps.contains(f.name) => + val numFields = ColumnStatStruct.numStatFields(f.dataType) + (f.name, ColumnStat(numFields, colStatsProps(f.name))) + }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), stats = Some(Statistics( - sizeInBytes = totalSize, - rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_))))) + sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), + rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats))) } else { catalogTable } @@ -693,6 +702,7 @@ object HiveExternalCatalog { val STATISTICS_PREFIX = "spark.sql.statistics." 
val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" + val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9956706929cd1..99dd080683d40 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,16 +21,16 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -171,7 +171,27 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } - private def checkStats( + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + + private def checkTableStats( stats: Option[Statistics], hasSizeInBytes: Boolean, expectedRowCounts: Option[Int]): Unit = { @@ -184,7 +204,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - private def checkStats( + private def checkTableStats( tableName: String, isDataSourceTable: Boolean, hasSizeInBytes: Boolean, @@ -192,12 +212,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val df = sql(s"SELECT * FROM $tableName") val stats = df.queryExecution.analyzed.collect { case rel: MetastoreRelation => - checkStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a data source table, but got a Hive serde table") + checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") rel.catalogTable.stats case rel: LogicalRelation => - checkStats(rel.catalogTable.get.stats, hasSizeInBytes, 
expectedRowCounts) - assert(isDataSourceTable, "Expected a Hive serde table, but got a data source table") + checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") rel.catalogTable.get.stats } assert(stats.size == 1) @@ -210,13 +230,13 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // Currently Spark's statistics are self-contained, we don't have statistics until we use // the `ANALYZE TABLE` command. sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, @@ -224,12 +244,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // noscan won't count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } @@ -241,19 +261,19 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // when the total size is not changed, the old row count is kept - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // update total size and remove the old and invalid row count - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) } @@ -271,20 +291,20 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts 
= None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } } @@ -298,23 +318,23 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils assert(DDLUtils.isDatasourceTable(catalogTable)) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, @@ -330,7 +350,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) dfNoCols.write.format("json").saveAsTable(table_no_cols) sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkStats( + checkTableStats( table_no_cols, isDataSourceTable = true, hasSizeInBytes = true, @@ -338,6 +358,53 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + test("generate column-level statistics and load them from hive metastore") { + import testImplicits._ + + val intSeq = Seq(1, 2) + val stringSeq = Seq("a", "bb") + val booleanSeq = Seq(true, false) + + val data = intSeq.indices.map { i => + (intSeq(i), stringSeq(i), booleanSeq(i)) + } + val tableName = "table" + withTable(tableName) { + val df = data.toDF("c1", "c2", "c3") + df.write.format("parquet").saveAsTable(tableName) + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + } + (f, colStat) + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = colStat, + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + rel + } + assert(relations.size == 1) + } + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") 
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index a215c70da0c52..f5c605fe5e2fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -123,6 +123,7 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assertNoSuchTable(s"SHOW CREATE TABLE $viewName") assertNoSuchTable(s"SHOW PARTITIONS $viewName") assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") } } From 1dd68d3827133d203e85294405400b04904879e0 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 3 Oct 2016 18:09:36 +0000 Subject: [PATCH 181/306] [SPARK-17718][DOCS][MLLIB] Make loss function formulation label note clearer in MLlib docs ## What changes were proposed in this pull request? Move note about labels being +1/-1 in formulation only to be just under the table of formulations. ## How was this patch tested? Doc build Author: Sean Owen Closes #15330 from srowen/SPARK-17718. --- docs/mllib-linear-methods.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 6fcd3ae85700c..816bdf1317000 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -78,6 +78,11 @@ methods `spark.mllib` supports:
      +Note that, in the mathematical formulation above, a binary label $y$ is denoted as either +$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with +multiclass labeling. + ### Regularizers The purpose of the @@ -136,10 +141,6 @@ multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. -Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either -$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with -multiclass labeling. ### Linear Support Vector Machines (SVMs) From 1f31bdaef670dd43999613deae3620f4ddcd1fbf Mon Sep 17 00:00:00 2001 From: Jason White Date: Mon, 3 Oct 2016 14:12:03 -0700 Subject: [PATCH 182/306] [SPARK-17679] [PYSPARK] remove unnecessary Py4J ListConverter patch ## What changes were proposed in this pull request? This PR removes a patch on ListConverter from https://github.com/apache/spark/pull/5570, as it is no longer necessary. The underlying issue in Py4J https://github.com/bartdag/py4j/issues/160 was patched in https://github.com/bartdag/py4j/commit/224b94b6665e56a93a064073886e1d803a4969d2 and is present in 0.10.3, the version currently in use in Spark. ## How was this patch tested? The original test added in https://github.com/apache/spark/pull/5570 remains. Author: Jason White Closes #15254 from JasonMWhite/remove_listconverter_patch. 
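For reference, a minimal sanity check (a sketch, not code from this patch; it assumes an already-running `SparkContext` named `sc`): every Python job still ships its pickled command to the JVM as a `bytearray` through Py4J, which is exactly the value the removed `can_convert` patch used to guard, so any small job exercises the affected path.

```python
# Sketch only: the .map forces a PythonFunction to be built, so the pickled command
# (a bytearray) and plain Python lists/dicts cross the Py4J gateway without any
# explicit ListConverter/MapConverter calls.
data = [bytearray(b"ab"), bytearray(b"cd")]
assert sc.parallelize(data).map(lambda x: x).collect() == data
```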
--- python/pyspark/java_gateway.py | 9 --------- python/pyspark/ml/common.py | 4 ++-- python/pyspark/mllib/common.py | 4 ++-- python/pyspark/rdd.py | 13 ++----------- 4 files changed, 6 insertions(+), 24 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 527ca82d31f1b..f76cadcf62438 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,18 +29,9 @@ xrange = range from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.java_collections import ListConverter - from pyspark.serializers import read_int -# patching ListConverter, or it will convert bytearray into Java ArrayList -def can_convert_list(self, obj): - return isinstance(obj, (list, tuple, xrange)) - -ListConverter.can_convert = can_convert_list - - def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index aec860fca7057..387c5d7309dea 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -23,7 +23,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import ListConverter, JavaArray, JavaList +from py4j.java_collections import JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -76,7 +76,7 @@ def _py2java(sc, obj): elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): - obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, bytes, unicode)): diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 21f0e09ea7742..bac8f350563ec 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -23,7 +23,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import ListConverter, JavaArray, JavaList +from py4j.java_collections import JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -78,7 +78,7 @@ def _py2java(sc, obj): elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): - obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, bytes, unicode)): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 5fb10f86f4692..ed81eb16df3cd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,8 +52,6 @@ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from py4j.java_collections import ListConverter, MapConverter - __all__ = ["RDD"] @@ -2317,16 +2315,9 @@ def _prepare_for_python_RDD(sc, command): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # There is a bug in py4j.java_gateway.JavaClass with auto_convert - # https://github.com/bartdag/py4j/issues/161 - # TODO: use auto_convert once py4j fix the bug - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in 
sc._pickled_broadcast_vars], - sc._gateway._gateway_client) + broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_vars] sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) - includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) - return pickled_command, broadcast_vars, env, includes + return pickled_command, broadcast_vars, sc.environment, sc._python_includes def _wrap_function(sc, func, deserializer, serializer, profiler=None): From d8399b600cef706c22d381b01fab19c610db439a Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 3 Oct 2016 17:57:54 -0700 Subject: [PATCH 183/306] [SPARK-17587][PYTHON][MLLIB] SparseVector __getitem__ should follow __getitem__ contract ## What changes were proposed in this pull request? Replaces` ValueError` with `IndexError` when index passed to `ml` / `mllib` `SparseVector.__getitem__` is out of range. This ensures correct iteration behavior. Replaces `ValueError` with `IndexError` for `DenseMatrix` and `SparkMatrix` in `ml` / `mllib`. ## How was this patch tested? PySpark `ml` / `mllib` unit tests. Additional unit tests to prove that the problem has been resolved. Author: zero323 Closes #15144 from zero323/SPARK-17587. --- python/pyspark/ml/linalg/__init__.py | 10 +++++----- python/pyspark/ml/tests.py | 16 +++++++++++++--- python/pyspark/mllib/linalg/__init__.py | 10 +++++----- python/pyspark/mllib/tests.py | 16 +++++++++++++--- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 05c0ac862fb7f..a5df727fdb418 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -713,7 +713,7 @@ def __getitem__(self, index): "Indices must be of type integer, got type %s" % type(index)) if index >= self.size or index < -self.size: - raise ValueError("Index %d out of bounds." % index) + raise IndexError("Index %d out of bounds." % index) if index < 0: index += self.size @@ -960,10 +960,10 @@ def toSparse(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j >= self.numCols or j < 0: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) if self.isTransposed: @@ -1090,10 +1090,10 @@ def __reduce__(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j < 0 or j >= self.numCols: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) # If a CSR matrix is given, then the row index should be searched diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6886ed321ee82..e233549850888 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1316,7 +1316,7 @@ def test_sparse_vector_indexing(self): self.assertEqual(sv[-3], 0.) self.assertEqual(sv[-5], 0.) 
for ind in [5, -6]: - self.assertRaises(ValueError, sv.__getitem__, ind) + self.assertRaises(IndexError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) @@ -1324,11 +1324,15 @@ def test_sparse_vector_indexing(self): self.assertEqual(zeros[0], 0.0) self.assertEqual(zeros[3], 0.0) for ind in [4, -5]: - self.assertRaises(ValueError, zeros.__getitem__, ind) + self.assertRaises(IndexError, zeros.__getitem__, ind) empty = SparseVector(0, {}) for ind in [-1, 0, 1]: - self.assertRaises(ValueError, empty.__getitem__, ind) + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -1337,6 +1341,9 @@ def test_matrix_indexing(self): for j in range(2): self.assertEqual(mat[i, j], expected[i][j]) + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) self.assertTrue( @@ -1408,6 +1415,9 @@ def test_sparse_matrix(self): self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() self.assertEqual(sm1.numRows, smnew.numRows) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 9672dbde823f2..d37e715c8d8ec 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -802,7 +802,7 @@ def __getitem__(self, index): "Indices must be of type integer, got type %s" % type(index)) if index >= self.size or index < -self.size: - raise ValueError("Index %d out of bounds." % index) + raise IndexError("Index %d out of bounds." % index) if index < 0: index += self.size @@ -1115,10 +1115,10 @@ def asML(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j >= self.numCols or j < 0: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) if self.isTransposed: @@ -1245,10 +1245,10 @@ def __reduce__(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j < 0 or j >= self.numCols: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) # If a CSR matrix is given, then the row index should be searched diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3f3dfd186c10d..c519883cdd73b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -260,7 +260,7 @@ def test_sparse_vector_indexing(self): self.assertEqual(sv[-3], 0.) self.assertEqual(sv[-5], 0.) 
for ind in [5, -6]: - self.assertRaises(ValueError, sv.__getitem__, ind) + self.assertRaises(IndexError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) @@ -268,11 +268,15 @@ def test_sparse_vector_indexing(self): self.assertEqual(zeros[0], 0.0) self.assertEqual(zeros[3], 0.0) for ind in [4, -5]: - self.assertRaises(ValueError, zeros.__getitem__, ind) + self.assertRaises(IndexError, zeros.__getitem__, ind) empty = SparseVector(0, {}) for ind in [-1, 0, 1]: - self.assertRaises(ValueError, empty.__getitem__, ind) + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -281,6 +285,9 @@ def test_matrix_indexing(self): for j in range(2): self.assertEqual(mat[i, j], expected[i][j]) + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) self.assertTrue( @@ -352,6 +359,9 @@ def test_sparse_matrix(self): self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() self.assertEqual(sm1.numRows, smnew.numRows) From 2bbecdec2023143fd144e4242ff70822e0823986 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 3 Oct 2016 19:32:59 -0700 Subject: [PATCH 184/306] [SPARK-17753][SQL] Allow a complex expression as the input a value based case statement ## What changes were proposed in this pull request? We currently only allow relatively simple expressions as the input for a value based case statement. Expressions like `case (a > 1) or (b = 2) when true then 1 when false then 0 end` currently fail. This PR adds support for such expressions. ## How was this patch tested? Added a test to the ExpressionParserSuite. Author: Herman van Hovell Closes #15322 from hvanhovell/SPARK-17753. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 12 ++++++------ .../spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/parser/ExpressionParserSuite.scala | 4 ++++ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1284681fe80b4..c336a0c8eab7a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -527,16 +527,16 @@ valueExpression ; primaryExpression - : constant #constantDefault - | name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? 
#functionCall | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 12a70b7769ef6..cd0c70a49150d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1138,7 +1138,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) + val e = expression(ctx.value) val branches = ctx.whenClause.asScala.map { wCtx => (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index f319215f05681..3718ac5f1e77b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest { test("case when") { assertEqual("case a when 1 then b when 2 then c else d end", CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case (a or b) when true then c when false then d else e end", + CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e))) + assertEqual("case 'a'='a' when true then 1 end", + CaseKeyWhen("a" === "a", Seq(true, 1))) assertEqual("case when a = 1 then b when a = 2 then c else d end", CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) } From c571cfb2d0e1e224107fc3f0c672730cae9804cb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 3 Oct 2016 21:28:16 -0700 Subject: [PATCH 185/306] [SPARK-17112][SQL] "select null" via JDBC triggers IllegalArgumentException in Thriftserver ## What changes were proposed in this pull request? Currently, Spark Thrift Server raises `IllegalArgumentException` for queries whose column types are `NullType`, e.g., `SELECT null` or `SELECT if(true,null,null)`. This PR fixes that by returning `void` like Hive 1.2. 
**Before** ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ Error: java.lang.IllegalArgumentException: Unrecognized type name: null (state=,code=0) Closing: 0: jdbc:hive2://localhost:10000 $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ Error: java.lang.IllegalArgumentException: Unrecognized type name: null (state=,code=0) Closing: 0: jdbc:hive2://localhost:10000 ``` **After** ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ +-------+--+ | NULL | +-------+--+ | NULL | +-------+--+ 1 row selected (3.242 seconds) Beeline version 1.2.1.spark2 by Apache Hive Closing: 0: jdbc:hive2://localhost:10000 $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" Connecting to jdbc:hive2://localhost:10000 Connected to: Spark SQL (version 2.1.0-SNAPSHOT) Driver: Hive JDBC (version 1.2.1.spark2) Transaction isolation: TRANSACTION_REPEATABLE_READ +-------------------------+--+ | (IF(true, NULL, NULL)) | +-------------------------+--+ | NULL | +-------------------------+--+ 1 row selected (0.201 seconds) Beeline version 1.2.1.spark2 by Apache Hive Closing: 0: jdbc:hive2://localhost:10000 ``` ## How was this patch tested? * Pass the Jenkins test with a new testsuite. * Also, Manually, after starting Spark Thrift Server, run the following command. ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select null" $ bin/beeline -u jdbc:hive2://localhost:10000 -e "select if(true,null,null)" ``` **Hive 1.2** ```sql hive> create table null_table as select null; hive> desc null_table; OK _c0 void ``` Author: Dongjoon Hyun Closes #15325 from dongjoon-hyun/SPARK-17112. 
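As a hedged illustration of the type-name mapping this patch introduces (the real change, in the diff below, builds Hive `FieldSchema`/`TableSchema` objects; this standalone sketch assumes only `spark-catalyst` on the classpath, and the helper name `hiveTypeName` is made up for the example):

```scala
// Illustrative sketch only: NullType is reported under Hive's "void" type name,
// while every other Catalyst type keeps its catalogString.
import org.apache.spark.sql.types.{DataType, NullType, StringType, StructType}

object NullTypeNameSketch {
  // Hypothetical helper, not an API added by this patch.
  def hiveTypeName(dt: DataType): String =
    if (dt == NullType) "void" else dt.catalogString

  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("NULL", NullType)
      .add("name", StringType)
    schema.fields.foreach(f => println(s"${f.name} -> ${hiveTypeName(f.dataType)}"))
    // prints:
    //   NULL -> void
    //   name -> string
  }
}
```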
--- .../SparkExecuteStatementOperation.scala | 19 +++++++---- .../SparkExecuteStatementOperationSuite.scala | 33 +++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e555ebd623f72..aeabd6a15881d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -56,14 +56,11 @@ private[hive] class SparkExecuteStatementOperation( private var statementId: String = _ private lazy val resultSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.schema.isEmpty) { new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, attr.dataType.catalogString, "") - } - new TableSchema(schema.asJava) + logInfo(s"Result Schema: ${result.schema}") + SparkExecuteStatementOperation.getTableSchema(result.schema) } } @@ -282,3 +279,13 @@ private[hive] class SparkExecuteStatementOperation( } } } + +object SparkExecuteStatementOperation { + def getTableSchema(structType: StructType): TableSchema = { + val schema = structType.map { field => + val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString + new FieldSchema(field.name, attrTypeString, "") + } + new TableSchema(schema.asJava) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala new file mode 100644 index 0000000000000..32ded0d254ef8 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{NullType, StructField, StructType} + +class SparkExecuteStatementOperationSuite extends SparkFunSuite { + test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { + val field1 = StructField("NULL", NullType) + val field2 = StructField("(IF(true, NULL, NULL))", NullType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + } +} From b1b47274bfeba17a9e4e9acebd7385289f31f6c8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 3 Oct 2016 21:48:58 -0700 Subject: [PATCH 186/306] [SPARK-17702][SQL] Code generation including too many mutable states exceeds JVM size limit. ## What changes were proposed in this pull request? Code generation including too many mutable states exceeds JVM size limit to extract values from `references` into fields in the constructor. We should split the generated extractions in the constructor into smaller functions. ## How was this patch tested? I added some tests to check if the generated codes for the expressions exceed or not. Author: Takuya UESHIN Closes #15275 from ueshin/issues/SPARK-17702. --- .../expressions/codegen/CodeGenerator.scala | 18 +++++++++++----- .../codegen/GenerateMutableProjection.scala | 3 ++- .../codegen/GenerateOrdering.scala | 3 ++- .../codegen/GeneratePredicate.scala | 4 +++- .../codegen/GenerateSafeProjection.scala | 4 +++- .../codegen/GenerateUnsafeProjection.scala | 3 ++- .../expressions/CodeGenerationSuite.scala | 21 ++++++++++++++++++- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 8 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cb808e375a35f..574943d3d21f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -178,7 +178,10 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map(_._3).mkString("\n") + val initCodes = mutableStates.distinct.map(_._3 + "\n") + // The generated initialization code may exceed 64kb function size limit in JVM if there are too + // many mutable states, so split it into multiple functions. + splitExpressions(initCodes, "init", Nil) } /** @@ -604,6 +607,11 @@ class CodegenContext { // Cannot split these expressions because they are not created from a row object. 
return expressions.mkString("\n") } + splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil) + } + + private def splitExpressions( + expressions: Seq[String], funcName: String, arguments: Seq[(String, String)]): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { @@ -623,11 +631,11 @@ class CodegenContext { // inline execution if only one block blocks.head } else { - val apply = freshName("apply") + val func = freshName(funcName) val functions = blocks.zipWithIndex.map { case (body, i) => - val name = s"${apply}_$i" + val name = s"${func}_$i" val code = s""" - |private void $name(InternalRow $row) { + |private void $name(${arguments.map { case (t, name) => s"$t $name" }.mkString(", ")}) { | $body |} """.stripMargin @@ -635,7 +643,7 @@ class CodegenContext { name } - functions.map(name => s"$name($row);").mkString("\n") + functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")});").mkString("\n") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 0f82d2e613c73..13d61af1c9b40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -104,7 +104,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP private Object[] references; private MutableRow mutableRow; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificMutableProjection(Object[] references) { this.references = references; @@ -112,6 +111,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { mutableRow = row; return this; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f1c30ef6c7fb8..1cef95654a17b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -133,13 +133,14 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificOrdering(Object[] references) { this.references = references; ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public int compare(InternalRow a, InternalRow b) { InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. 
$comparisons diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 106bb27964cab..39aa7b17de6c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) + val codeBody = s""" public SpecificPredicate generate(Object[] references) { return new SpecificPredicate(references); @@ -48,13 +49,14 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificPredicate(Object[] references) { this.references = references; ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public boolean eval(InternalRow ${ctx.INPUT_ROW}) { ${eval.code} return !${eval.isNull} && ${eval.value}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b891f94673752..1c98c9ed10705 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -155,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] """ } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) + val codeBody = s""" public java.lang.Object generate(Object[] references) { return new SpecificSafeProjection(references); @@ -165,7 +166,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private Object[] references; private MutableRow mutableRow; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificSafeProjection(Object[] references) { this.references = references; @@ -173,6 +173,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 75bb6936b49e0..7cc45372daa5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -374,13 +374,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public SpecificUnsafeProjection(Object[] references) { this.references = references; 
${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + // Scala.Function1 need this public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 45dcfcaf23132..5588b4429164c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row @@ -24,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ThreadUtils @@ -164,6 +166,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-17702: split wide constructor into blocks due to JVM code size limit") { + val length = 5000 + val expressions = Seq.fill(length) { + ToUTCTimestamp( + Literal.create(Timestamp.valueOf("2015-07-24 00:00:00"), TimestampType), + Literal.create("PST", StringType)) + } + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq.fill(length)( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fb57ed7692de4..62bf6f4a81eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -316,14 +316,16 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; + private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} public GeneratedIterator(Object[] references) { this.references = references; } - public void init(int index, scala.collection.Iterator inputs[]) { + public void init(int index, scala.collection.Iterator[] inputs) { partitionIndex = index; + this.inputs = inputs; ${ctx.initMutableStates()} } From d2dc8c4a162834818190ffd82894522c524ca3e5 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Mon, 3 Oct 2016 23:28:39 -0700 Subject: [PATCH 187/306] [SPARK-17773] Input/Output] Add 
VoidObjectInspector ## What changes were proposed in this pull request? Added VoidObjectInspector to the list of PrimitiveObjectInspectors ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Executing following query was failing. select SOME_UDAF*(a.arr) from ( select Array(null) as arr from dim_one_row ) a After the fix, I am getting the correct output: res0: Array[org.apache.spark.sql.Row] = Array([null]) Author: Ergin Seyfe Closes #15337 from seyfe/add_void_object_inspector. --- .../main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2 ++ .../scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 1 + 2 files changed, 3 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c3c4351cf58a9..fe34caa0a3e48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -319,6 +319,8 @@ private[hive] trait HiveInspectors { withNullSafe(o => getTimestampWritable(o)) case _: TimestampObjectInspector => withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])) + case _: VoidObjectInspector => + (_: Any) => null // always be null for void object inspector } case soi: StandardStructObjectInspector => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index bc51bcb07ec2a..3de1f4aeb74dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -81,6 +81,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val data = Literal(true) :: + Literal(null) :: Literal(0.asInstanceOf[Byte]) :: Literal(0.asInstanceOf[Short]) :: Literal(0) :: From 126baa8d32bc0e7bf8b43f9efa84f2728f02347d Mon Sep 17 00:00:00 2001 From: ding Date: Tue, 4 Oct 2016 00:00:10 -0700 Subject: [PATCH 188/306] [SPARK-17559][MLLIB] persist edges if their storage level is non in PeriodicGraphCheckpointer ## What changes were proposed in this pull request? When use PeriodicGraphCheckpointer to persist graph, sometimes the edges isn't persisted. As currently only when vertices's storage level is none, graph is persisted. However there is a chance vertices's storage level is not none while edges's is none. Eg. graph created by a outerJoinVertices operation, vertices is automatically cached while edges is not. In this way, edges will not be persisted if we use PeriodicGraphCheckpointer do persist. We need separately check edges's storage level and persisted it if it's none. ## How was this patch tested? manual tests Author: ding Closes #15124 from dding3/spark-persisitEdge. 
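As a hedged, standalone illustration of the guard pattern in the diff below (assumptions: `spark-core` and `spark-graphx` on the classpath, a local master, and a toy graph made up for the example; this is not part of the patch):

```scala
// Persist the vertices and the edges of a graph independently, each only when its
// own storage level is NONE, then print both levels so any mismatch is visible.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.storage.StorageLevel

object PersistGraphPartsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("persist-graph-parts").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1)))
    val base = Graph.fromEdges(edges, defaultValue = 0)
    // A graph derived via outerJoinVertices: its two sides do not necessarily share
    // the same storage level, which is the situation this fix targets.
    val joined = base.outerJoinVertices(base.degrees) { (id, oldAttr, degOpt) =>
      degOpt.getOrElse(0)
    }

    // The guard pattern from the fix: check each side separately.
    if (joined.vertices.getStorageLevel == StorageLevel.NONE) joined.vertices.persist()
    if (joined.edges.getStorageLevel == StorageLevel.NONE) joined.edges.persist()

    println(s"vertices level: ${joined.vertices.getStorageLevel}")
    println(s"edges level:    ${joined.edges.getStorageLevel}")
    sc.stop()
  }
}
```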
--- .../apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 20db6084d0e0d..80074897567eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -87,7 +87,10 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.persist() + data.vertices.persist() + } + if (data.edges.getStorageLevel == StorageLevel.NONE) { + data.edges.persist() } } From 8e8de0073d71bb00baeb24c612d7841b6274f652 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 4 Oct 2016 10:29:22 +0100 Subject: [PATCH 189/306] [SPARK-17671][WEBUI] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications ## What changes were proposed in this pull request? Return Iterator of applications internally in history server, for consistency and performance. See https://github.com/apache/spark/pull/15248 for some back-story. The code called by and calling HistoryServer.getApplicationList wants an Iterator, but this method materializes an Iterable, which potentially causes a performance problem. It's simpler too to make this internal method also pass through an Iterator. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15321 from srowen/SPARK-17671. --- .../history/ApplicationHistoryProvider.scala | 2 +- .../deploy/history/FsHistoryProvider.scala | 2 +- .../spark/deploy/history/HistoryPage.scala | 5 +-- .../spark/deploy/history/HistoryServer.scala | 4 +- .../api/v1/ApplicationListResource.scala | 38 +++++++------------ .../deploy/history/HistoryServerSuite.scala | 4 +- project/MimaExcludes.scala | 2 + 7 files changed, 22 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index ba42b4862aa90..ad7a0972ef9d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -77,7 +77,7 @@ private[history] abstract class ApplicationHistoryProvider { * * @return List of all know applications. */ - def getListing(): Iterable[ApplicationHistoryInfo] + def getListing(): Iterator[ApplicationHistoryInfo] /** * Returns the Spark UI for a specific application. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index c5740e4737094..3c2d169f3270e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -222,7 +222,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values + override def getListing(): Iterator[FsApplicationHistoryInfo] = applications.values.iterator override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { applications.get(appId) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index b4f5a6114f3de..95b72224e0f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -29,10 +29,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean - val allApps = parent.getApplicationList() - .filter(_.completed != requestedIncomplete) - val allAppsSize = allApps.size - + val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) val providerConfig = parent.getProviderConfig() val content =
      diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 735aa43cfc994..087c69e6489dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -174,12 +174,12 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList(): Iterable[ApplicationHistoryInfo] = { + def getApplicationList(): Iterator[ApplicationHistoryInfo] = { provider.getListing() } def getApplicationInfoList: Iterator[ApplicationInfo] = { - getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } def getApplicationInfo(appId: String): Option[ApplicationInfo] = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 075b9ba37dc84..76779290d45e6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import java.util.{Arrays, Date, List => JList} +import java.util.{Date, List => JList} import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam} import javax.ws.rs.core.MediaType @@ -32,33 +32,21 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam, @QueryParam("limit") limit: Integer) : Iterator[ApplicationInfo] = { - val allApps = uiRoot.getApplicationInfoList - val adjStatus = { - if (status.isEmpty) { - Arrays.asList(ApplicationStatus.values(): _*) - } else { - status - } - } - val includeCompleted = adjStatus.contains(ApplicationStatus.COMPLETED) - val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING) - val appList = allApps.filter { app => + + val numApps = Option(limit).map(_.toInt).getOrElse(Integer.MAX_VALUE) + val includeCompleted = status.isEmpty || status.contains(ApplicationStatus.COMPLETED) + val includeRunning = status.isEmpty || status.contains(ApplicationStatus.RUNNING) + + uiRoot.getApplicationInfoList.filter { app => val anyRunning = app.attempts.exists(!_.completed) - // if any attempt is still running, we consider the app to also still be running - val statusOk = (!anyRunning && includeCompleted) || - (anyRunning && includeRunning) + // if any attempt is still running, we consider the app to also still be running; // keep the app if *any* attempts fall in the right time window - val dateOk = app.attempts.exists { attempt => - attempt.startTime.getTime >= minDate.timestamp && - attempt.startTime.getTime <= maxDate.timestamp + ((!anyRunning && includeCompleted) || (anyRunning && includeRunning)) && + app.attempts.exists { attempt => + val start = attempt.startTime.getTime + start >= minDate.timestamp && start <= maxDate.timestamp } - statusOk && dateOk - } - if (limit != null) { - appList.take(limit) - } else { - appList - } + }.take(numApps) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index ae3f5d9c012ea..5b316b2f6b4b7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala 
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -447,7 +447,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(4 === getNumJobsRestful(), s"two jobs back-to-back not updated, server=$server\n") } val jobcount = getNumJobs("/jobs") - assert(!provider.getListing().head.completed) + assert(!provider.getListing().next.completed) listApplications(false) should contain(appId) @@ -455,7 +455,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers resetSparkContext() // check the app is now found as completed eventually(stdTimeout, stdInterval) { - assert(provider.getListing().head.completed, + assert(provider.getListing().next.completed, s"application never completed, server=$server\n") } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7362041428b1f..163e3f2fdea40 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,6 +37,8 @@ object MimaExcludes { // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { Seq( + // [SPARK-17671] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.deploy.history.HistoryServer.getApplicationList"), // [SPARK-14743] Improve delegation token handling in secure cluster ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"), // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter From 7d5160883542f3d9dcb3babda92880985398e9af Mon Sep 17 00:00:00 2001 From: sumansomasundar Date: Tue, 4 Oct 2016 10:31:56 +0100 Subject: [PATCH 190/306] [SPARK-16962][CORE][SQL] Fix misaligned record accesses for SPARC architectures ## What changes were proposed in this pull request? Made changes to record length offsets to make them uniform throughout various areas of Spark core and unsafe ## How was this patch tested? This change affects only SPARC architectures and was tested on X86 architectures as well for regression. Author: sumansomasundar Closes #14762 from sumansomasundar/master. --- .../spark/unsafe/UnsafeAlignedOffset.java | 58 +++++++++++++++++++ .../spark/unsafe/array/ByteArrayMethods.java | 31 +++++++--- .../spark/unsafe/map/BytesToBytesMap.java | 57 +++++++++--------- .../unsafe/sort/UnsafeExternalSorter.java | 19 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 14 +++-- .../CompressibleColumnBuilder.scala | 11 +++- 6 files changed, 144 insertions(+), 46 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java new file mode 100644 index 0000000000000..be62e40412f83 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +/** + * Class to make changes to record length offsets uniform through out + * various areas of Apache Spark core and unsafe. The SPARC platform + * requires this because using a 4 byte Int for record lengths causes + * the entire record of 8 byte Items to become misaligned by 4 bytes. + * Using a 8 byte long for record length keeps things 8 byte aligned. + */ +public class UnsafeAlignedOffset { + + private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8; + + public static int getUaoSize() { + return UAO_SIZE; + } + + public static int getSize(Object object, long offset) { + switch (UAO_SIZE) { + case 4: + return Platform.getInt(object, offset); + case 8: + return (int)Platform.getLong(object, offset); + default: + throw new AssertionError("Illegal UAO_SIZE"); + } + } + + public static void putSize(Object object, long offset, int value) { + switch (UAO_SIZE) { + case 4: + Platform.putInt(object, offset, value); + break; + case 8: + Platform.putLong(object, offset, value); + break; + default: + throw new AssertionError("Illegal UAO_SIZE"); + } + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf42877bf9fd4..9c551ab19e9aa 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise @@ -47,17 +48,33 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - while (i <= length - 8) { - if (Platform.getLong(leftBase, leftOffset + i) != - Platform.getLong(rightBase, rightOffset + i)) { - return false; + + // check if stars align and we can get both offsets to be aligned + if ((leftOffset % 8) == (rightOffset % 8)) { + while ((leftOffset + i) % 8 != 0 && i < length) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; + } + } + // for architectures that suport unaligned accesses, chew it up 8 bytes at a time + if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { + while (i <= length - 8) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { + return false; + } + i += 8; } - i += 8; } + // this will finish off the unaligned comparisons, or do the entire aligned + // comparison whichever is needed. 
while (i < length) { if (Platform.getByte(leftBase, leftOffset + i) != - Platform.getByte(rightBase, rightOffset + i)) { - return false; + Platform.getByte(rightBase, rightOffset + i)) { + return false; } i += 1; } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index e4289818f1e75..d2fcdea4f2cee 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -35,6 +35,7 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -273,8 +274,8 @@ private void advanceToNextPage() { currentPage = dataPages.get(nextIdx); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); - recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); - offsetInPage += 4; + recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); + offsetInPage += UnsafeAlignedOffset.getUaoSize(); } else { currentPage = null; if (reader != null) { @@ -321,10 +322,10 @@ public Location next() { } numRecords--; if (currentPage != null) { - int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + int totalLength = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); loc.with(currentPage, offsetInPage); // [total size] [key size] [key] [value] [pointer to next] - offsetInPage += 4 + totalLength + 8; + offsetInPage += UnsafeAlignedOffset.getUaoSize() + totalLength + 8; recordsInPage --; return loc; } else { @@ -367,14 +368,15 @@ public long spill(long numBytes) throws IOException { Object base = block.getBaseObject(); long offset = block.getBaseOffset(); - int numRecords = Platform.getInt(base, offset); - offset += 4; + int numRecords = UnsafeAlignedOffset.getSize(base, offset); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + offset += uaoSize; final UnsafeSorterSpillWriter writer = new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); while (numRecords > 0) { - int length = Platform.getInt(base, offset); - writer.write(base, offset + 4, length, 0); - offset += 4 + length + 8; + int length = UnsafeAlignedOffset.getSize(base, offset); + writer.write(base, offset + uaoSize, length, 0); + offset += uaoSize + length + 8; numRecords--; } writer.close(); @@ -530,13 +532,14 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object base, long offset) { baseObject = base; - final int totalLength = Platform.getInt(base, offset); - offset += 4; - keyLength = Platform.getInt(base, offset); - offset += 4; + final int totalLength = UnsafeAlignedOffset.getSize(base, offset); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + offset += uaoSize; + keyLength = UnsafeAlignedOffset.getSize(base, offset); + offset += uaoSize; keyOffset = offset; valueOffset = offset + keyLength; - valueLength = totalLength - keyLength - 4; + valueLength = totalLength - keyLength - uaoSize; } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -565,10 +568,11 @@ private Location with(Object base, long offset, int length) { this.isDefined = true; this.memoryPage = null; baseObject = base; - keyOffset = offset + 
4; - keyLength = Platform.getInt(base, offset); - valueOffset = offset + 4 + keyLength; - valueLength = length - 4 - keyLength; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + keyOffset = offset + uaoSize; + keyLength = UnsafeAlignedOffset.getSize(base, offset); + valueOffset = offset + uaoSize + keyLength; + valueLength = length - uaoSize - keyLength; return this; } @@ -699,9 +703,10 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) - final long recordLength = 8 + klen + vlen + 8; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final long recordLength = (2 * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { - if (!acquireNewPage(recordLength + 4L)) { + if (!acquireNewPage(recordLength + uaoSize)) { return false; } } @@ -710,9 +715,9 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff final Object base = currentPage.getBaseObject(); long offset = currentPage.getBaseOffset() + pageCursor; final long recordOffset = offset; - Platform.putInt(base, offset, klen + vlen + 4); - Platform.putInt(base, offset + 4, klen); - offset += 8; + UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize); + UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen); + offset += (2 * uaoSize); Platform.copyMemory(kbase, koff, base, offset, klen); offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); @@ -722,7 +727,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // --- Update bookkeeping data structures ---------------------------------------------------- offset = currentPage.getBaseOffset(); - Platform.putInt(base, offset, Platform.getInt(base, offset) + 1); + UnsafeAlignedOffset.putSize(base, offset, UnsafeAlignedOffset.getSize(base, offset) + 1); pageCursor += recordLength; final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); @@ -757,8 +762,8 @@ private boolean acquireNewPage(long required) { return false; } dataPages.add(currentPage); - Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); - pageCursor = 4; + UnsafeAlignedOffset.putSize(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); + pageCursor = UnsafeAlignedOffset.getUaoSize(); return true; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8ca29a58f8f64..428ff72e71a43 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.TaskCompletionListener; @@ -392,14 +393,15 @@ public void insertRecord( } growPointerArrayIfNecessary(); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 bytes to store the record length. 
- final int required = length + 4; + final int required = length + uaoSize; acquireNewPageIfNecessary(required); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, length); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); @@ -418,15 +420,16 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, throws IOException { growPointerArrayIfNecessary(); - final int required = keyLen + valueLen + 4 + 4; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final int required = keyLen + valueLen + (2 * uaoSize); acquireNewPageIfNecessary(required); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, keyLen + valueLen + 4); - pageCursor += 4; - Platform.putInt(base, pageCursor, keyLen); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen + valueLen + uaoSize); + pageCursor += uaoSize; + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen); + pageCursor += uaoSize; Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen); pageCursor += keyLen; Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 8ecd20910ab73..2a71e68adafad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -56,11 +57,14 @@ private static final class SortComparator implements Comparator Date: Tue, 4 Oct 2016 06:54:48 -0700 Subject: [PATCH 191/306] [SPARK-17744][ML] Parity check between the ml and mllib test suites for NB ## What changes were proposed in this pull request? 1,parity check and add missing test suites for ml's NB 2,remove some unused imports ## How was this patch tested? manual tests in spark-shell Author: Zheng RuiFeng Closes #15312 from zhengruifeng/nb_test_parity. 
--- .../spark/ml/feature/LabeledPoint.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 2 +- .../org/apache/spark/ml/python/MLSerDe.scala | 5 -- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/LinearRegression.scala | 1 - .../ml/classification/NaiveBayesSuite.scala | 69 ++++++++++++++++++- python/pyspark/ml/classification.py | 1 - 7 files changed, 70 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index 6cefa7086c881..7d8e4adcc2259 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.Vector /** * :: Experimental :: * - * Class that represents the features and labels of a data point. + * Class that represents the features and label of a data point. * * @param label Label for this data point. * @param features List of features for this data point. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1e59d71a70955..05e034d90f6a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.StructType /** * Params for [[QuantileDiscretizer]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala index 4b805e145482a..da62f8518e363 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala @@ -19,17 +19,12 @@ package org.apache.spark.ml.python import java.io.OutputStream import java.nio.{ByteBuffer, ByteOrder} -import java.util.{ArrayList => JArrayList} - -import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.ml.linalg._ import org.apache.spark.mllib.api.python.SerDeBase -import org.apache.spark.rdd.RDD /** * SerDe utility functions for pyspark.ml. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index ce355938ec1c7..bb01f9d5a364c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -21,7 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7fddfd9b10f84..536c58f998080 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -37,7 +37,6 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 597428d036c7a..e934e5ea42b16 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -22,10 +22,10 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.ml.classification.NaiveBayesSuite._ -import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -106,6 +106,11 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -228,6 +233,66 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "bernoulli") } + test("detect negative values") { + val dense = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + intercept[SparkException] { + new NaiveBayes().fit(dense) + } + val sparse = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), 
Array(-1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))) + intercept[SparkException] { + new NaiveBayes().fit(sparse) + } + val nan = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))) + intercept[SparkException] { + new NaiveBayes().fit(nan) + } + } + + test("detect non zero or one values in Bernoulli") { + val badTrain = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + + intercept[SparkException] { + new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(badTrain) + } + + val okTrain = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)))) + + val model = new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(okTrain) + + val badPredict = spark.createDataFrame(Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0)))) + + intercept[SparkException] { + model.transform(badPredict).collect() + } + } + test("read/write") { def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { assert(model.pi === model2.pi) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 505e7bffd1763..ea60fab029582 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,7 +16,6 @@ # import operator -import warnings from pyspark import since, keyword_only from pyspark.ml import Estimator, Model From 068c198e956346b90968a4d74edb7bc820c4be28 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 4 Oct 2016 09:22:26 -0700 Subject: [PATCH 192/306] [SPARKR][DOC] minor formatting and output cleanup for R vignettes ## What changes were proposed in this pull request? Clean up output, format table, truncate long example output, hide warnings (new - Left; existing - Right) ![image](https://cloud.githubusercontent.com/assets/8969467/19064018/5dcde4d0-89bc-11e6-857b-052df3f52a4e.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064034/6db09956-89bc-11e6-8e43-232d5c3fe5e6.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064058/88f09590-89bc-11e6-9993-61639e29dfdd.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064066/95ccbf64-89bc-11e6-877f-45af03ddcadc.png) ![image](https://cloud.githubusercontent.com/assets/8969467/19064082/a8445404-89bc-11e6-8532-26d8bc9b206f.png) ## How was this patch tested? Run create-doc.sh manually Author: Felix Cheung Closes #15340 from felixcheung/vignettes. 
--- R/pkg/vignettes/sparkr-vignettes.Rmd | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index aea52db8b8556..80e876027bddb 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -26,7 +26,7 @@ library(SparkR) We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). -```{r, message=FALSE} +```{r, message=FALSE, results="hide"} sparkR.session() ``` @@ -114,10 +114,12 @@ In particular, the following Spark driver properties can be set in `sparkConfig` Property Name | Property group | spark-submit equivalent ---------------- | ------------------ | ---------------------- -spark.driver.memory | Application Properties | --driver-memory -spark.driver.extraClassPath | Runtime Environment | --driver-class-path -spark.driver.extraJavaOptions | Runtime Environment | --driver-java-options -spark.driver.extraLibraryPath | Runtime Environment | --driver-library-path +`spark.driver.memory` | Application Properties | `--driver-memory` +`spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path` +`spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options` +`spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path` +`spark.yarn.keytab` | Application Properties | `--keytab` +`spark.yarn.principal` | Application Properties | `--principal` **For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. @@ -161,7 +163,7 @@ head(df) ### Data Sources SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session'.` +The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. ```{r, eval=FALSE} sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") @@ -406,10 +408,17 @@ class(model.summaries) ``` -To avoid lengthy display, we only present the result of the second fitted model. You are free to inspect other models as well. 
+To avoid lengthy display, we only present the partial result of the second fitted model. You are free to inspect other models as well. +```{r, include=FALSE} +ops <- options() +options(max.print=40) +``` ```{r} print(model.summaries[[2]]) ``` +```{r, include=FALSE} +options(ops) +``` ### SQL Queries @@ -544,7 +553,7 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. -```{r} +```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) @@ -678,7 +687,7 @@ MLPC employs backpropagation for learning the model. We use the logistic loss fu * `tol`: convergence tolerance of iterations. -* `stepSize`: step size for `"gd"`. +* `stepSize`: step size for `"gd"`. * `seed`: seed parameter for weights initialization. @@ -763,8 +772,8 @@ We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in t ### Model Persistence The following example shows how to save/load an ML model by SparkR. -```{r} -irisDF <- suppressWarnings(createDataFrame(iris)) +```{r, warning=FALSE} +irisDF <- createDataFrame(iris) gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") # Save and then load a fitted MLlib model From 8d969a2125d915da1506c17833aa98da614a257f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 4 Oct 2016 09:38:44 -0700 Subject: [PATCH 193/306] [SPARK-17549][SQL] Only collect table size stat in driver for cached relation. This reverts commit 9ac68dbc5720026ea92acc61d295ca64d0d3d132. Turns out the original fix was correct. Original change description: The existing code caches all stats for all columns for each partition in the driver; for a large relation, this causes extreme memory usage, which leads to gc hell and application failures. It seems that only the size in bytes of the data is actually used in the driver, so instead just colllect that. In executors, the full stats are still kept, but that's not a big problem; we expect the data to be distributed and thus not really incur in too much memory pressure in each individual executor. There are also potential improvements on the executor side, since the data being stored currently is very wasteful (e.g. storing boxed types vs. primitive types for stats). But that's a separate issue. Author: Marcelo Vanzin Closes #15304 from vanzin/SPARK-17549.2. 
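A minimal standalone sketch of the pattern this change relies on, assuming a simplified `CachedBatch` record and a toy payload-size estimate (both hypothetical, not the actual `InMemoryRelation` code): a `LongAccumulator` carries only a per-partition byte count back to the driver, while anything more detailed stays with the executor-side data.
```scala
// Sketch only: a LongAccumulator aggregates per-partition byte counts on the
// driver, while detailed per-batch information stays inside executor records.
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator

object CachedSizeSketch {
  // Hypothetical, simplified stand-in for an executor-side cached batch.
  case class CachedBatch(rowCount: Int, bytes: Array[Byte])

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("size-stat-sketch").getOrCreate()
    val sc = spark.sparkContext

    // The driver only ever sees one long per partition through this accumulator.
    val sizeInBytes: LongAccumulator = sc.longAccumulator("cachedBatchBytes")

    val batches = sc.parallelize(1 to 10, 5).mapPartitions { iter =>
      val rows = iter.toArray
      val payload = rows.flatMap(i => BigInt(i).toByteArray)
      sizeInBytes.add(payload.length.toLong) // scalar size, not per-column stats
      Iterator(CachedBatch(rows.length, payload))
    }

    batches.cache().count() // materialize the "cache", which populates the accumulator

    println(s"estimated cached size: ${sizeInBytes.value} bytes")
    spark.stop()
  }
}
```
Because the accumulator value is a single long per partition, driver memory stays bounded no matter how many columns or partitions the cached relation has, which is the point of the revert described above.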
--- .../execution/columnar/InMemoryRelation.scala | 24 +++++-------------- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +++++++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 479934a7afc75..56bd5c1891e8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.columnar -import scala.collection.JavaConverters._ - import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils @@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.CollectionAccumulator +import org.apache.spark.util.LongAccumulator object InMemoryRelation { @@ -63,8 +61,7 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: CollectionAccumulator[InternalRow] = - child.sqlContext.sparkContext.collectionAccumulator[InternalRow]) + val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) @@ -74,21 +71,12 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override lazy val statistics: Statistics = { - if (batchStats.value.isEmpty) { + if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) } else { - // Underlying columnar RDD has been materialized, required information has also been - // collected via the `batchStats` accumulator. 
- val sizeOfRow: Expression = - BindReferences.bindReference( - output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), - partitionStatistics.schema) - - val sizeInBytes = - batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum - Statistics(sizeInBytes = sizeInBytes) + Statistics(sizeInBytes = batchStats.value.longValue) } } @@ -139,10 +127,10 @@ case class InMemoryRelation( rowCount += 1 } + batchStats.add(totalSize) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) .flatMap(_.values)) - - batchStats.add(stats) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 937839644ad5f..0daa29b666f62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -232,4 +232,18 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val columnTypes2 = List.fill(length2)(IntegerType) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) } + + test("SPARK-17549: cached table size should be correctly calculated") { + val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) + + // Materialize the data. + val expectedAnswer = data.collect() + checkAnswer(cached, expectedAnswer) + + // Check that the right size was calculated. + assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + } + } From a99743d053e84f695dc3034550939555297b0a05 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Tue, 4 Oct 2016 18:59:31 -0700 Subject: [PATCH 194/306] [SPARK-17495][SQL] Add Hash capability semantically equivalent to Hive's ## What changes were proposed in this pull request? Jira : https://issues.apache.org/jira/browse/SPARK-17495 Spark internally uses Murmur3Hash for partitioning. This is different from the one used by Hive. For queries which use bucketing this leads to different results if one tries the same query on both engines. For us, we want users to have backward compatibility to that one can switch parts of applications across the engines without observing regressions. This PR includes `HiveHash`, `HiveHashFunction`, `HiveHasher` which mimics Hive's hashing at https://github.com/apache/hive/blob/master/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java#L638 I am intentionally not introducing any usages of this hash function in rest of the code to keep this PR small. My eventual goal is to have Hive bucketing support in Spark. Once this PR gets in, I will make hash function pluggable in relevant areas (eg. `HashPartitioning`'s `partitionIdExpression` has Murmur3 hardcoded : https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala#L265) ## How was this patch tested? Added `HiveHashSuite` Author: Tejas Patil Closes #15047 from tejasapatil/SPARK-17495_hive_hash. 
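A compact sketch of the hashing rules that the `HiveHasher`/`HiveHash` code in the patch below follows: integers hash to themselves, longs fold their two halves together, byte sequences accumulate with a multiplier of 31, and per-column hashes are combined the same way. The object name and the sample values are illustrative only, not part of the patch.
```scala
// Standalone illustration of the Hive-compatible hashing arithmetic.
import java.nio.charset.StandardCharsets

object HiveHashSketch {
  def hashInt(i: Int): Int = i

  def hashLong(l: Long): Int = ((l >>> 32) ^ l).toInt

  // Bytes hash one at a time: result = result * 31 + byte
  def hashBytes(bytes: Array[Byte]): Int =
    bytes.foldLeft(0)((result, b) => result * 31 + b.toInt)

  // Row-level hash: fold the column hashes with the same 31 multiplier.
  def hashRow(columnHashes: Seq[Int]): Int =
    columnHashes.foldLeft(0)((hash, h) => 31 * hash + h)

  def main(args: Array[String]): Unit = {
    val col1 = hashInt(42)                // 42
    val col2 = hashLong(Long.MinValue)    // -2147483648, matching HiveHasherSuite
    val col3 = hashBytes("val_84".getBytes(StandardCharsets.UTF_8))
    println(s"per-column: $col1, $col2, $col3; row: ${hashRow(Seq(col1, col2, col3))}")
  }
}
```
Since the column hashes are folded with the same 31-based scheme Hive applies, rows end up distributed the way Hive would distribute them, which is the bucketing compatibility goal stated in the commit message.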
--- .../sql/catalyst/expressions/HiveHasher.java | 49 +++ .../spark/sql/catalyst/expressions/misc.scala | 391 +++++++++++++++--- .../catalyst/expressions/HiveHasherSuite.java | 128 ++++++ .../org/apache/spark/sql/HashBenchmark.scala | 93 +++-- .../spark/sql/HashByteArrayBenchmark.scala | 118 +++--- .../expressions/MiscFunctionsSuite.scala | 3 +- 6 files changed, 631 insertions(+), 151 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java new file mode 100644 index 0000000000000..c7ea9085eba66 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; + +/** + * Simulates Hive's hashing function at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() + */ +public class HiveHasher { + + @Override + public String toString() { + return HiveHasher.class.getSimpleName(); + } + + public static int hashInt(int input) { + return input; + } + + public static int hashLong(long input) { + return (int) ((input >>> 32) ^ input); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int result = 0; + for (int i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) Platform.getByte(base, offset + i); + } + return result; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dbb52a4bb18de..138ef2a1dcc01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -259,7 +259,7 @@ abstract class HashExpression[E] extends Expression { $childrenHash""") } - private def nullSafeElementHash( + protected def nullSafeElementHash( input: String, index: String, nullable: Boolean, @@ -276,76 +276,127 @@ abstract class HashExpression[E] extends Expression { } } - @tailrec - private def computeHash( + protected def genHashInt(i: String, result: String): String = + s"$result = $hasherClassName.hashInt($i, $result);" + + protected def genHashLong(l: String, result: String): String = + s"$result = $hasherClassName.hashLong($l, $result);" + + protected def genHashBytes(b: String, result: String): String = { + val offset = "Platform.BYTE_ARRAY_OFFSET" + s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);" + } + + protected def genHashBoolean(input: String, result: String): String = + genHashInt(s"$input ? 1 : 0", result) + + protected def genHashFloat(input: String, result: String): String = + genHashInt(s"Float.floatToIntBits($input)", result) + + protected def genHashDouble(input: String, result: String): String = + genHashLong(s"Double.doubleToLongBits($input)", result) + + protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, input: String, - dataType: DataType, - result: String, - ctx: CodegenContext): String = { - val hasher = hasherClassName - - def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" - def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" - def hashBytes(b: String): String = - s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" - - dataType match { - case NullType => "" - case BooleanType => hashInt(s"$input ? 
1 : 0") - case ByteType | ShortType | IntegerType | DateType => hashInt(input) - case LongType | TimestampType => hashLong(input) - case FloatType => hashInt(s"Float.floatToIntBits($input)") - case DoubleType => hashLong(s"Double.doubleToLongBits($input)") - case d: DecimalType => - if (d.precision <= Decimal.MAX_LONG_DIGITS) { - hashLong(s"$input.toUnscaledLong()") - } else { - val bytes = ctx.freshName("bytes") - s""" + result: String): String = { + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + genHashLong(s"$input.toUnscaledLong()", result) + } else { + val bytes = ctx.freshName("bytes") + s""" final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); - ${hashBytes(bytes)} + ${genHashBytes(bytes, result)} """ + } + } + + protected def genHashCalendarInterval(input: String, result: String): String = { + val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)" + s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" + } + + protected def genHashString(input: String, result: String): String = { + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + } + + protected def genHashForMap( + ctx: CodegenContext, + input: String, + result: String, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): String = { + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, keyType, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)} } - case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" - s"$result = $hasher.hashInt($input.months, $microsecondsHash);" - case BinaryType => hashBytes(input) - case StringType => - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - - case ArrayType(et, containsNull) => - val index = ctx.freshName("index") - s""" - for (int $index = 0; $index < $input.numElements(); $index++) { - ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} - } - """ - - case MapType(kt, vt, valueContainsNull) => - val index = ctx.freshName("index") - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - s""" - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - ${nullSafeElementHash(keys, index, false, kt, result, ctx)} - ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} - } - """ + """ + } + + protected def genHashForArray( + ctx: CodegenContext, + input: String, + result: String, + elementType: DataType, + containsNull: Boolean): String = { + val index = ctx.freshName("index") + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)} + } + """ + } - case StructType(fields) => - fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, 
field.nullable, field.dataType, result, ctx) - }.mkString("\n") + protected def genHashForStruct( + ctx: CodegenContext, + input: String, + result: String, + fields: Array[StructField]): String = { + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + }.mkString("\n") + } - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) - } + @tailrec + private def computeHashWithTailRec( + input: String, + dataType: DataType, + result: String, + ctx: CodegenContext): String = dataType match { + case NullType => "" + case BooleanType => genHashBoolean(input, result) + case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result) + case LongType | TimestampType => genHashLong(input, result) + case FloatType => genHashFloat(input, result) + case DoubleType => genHashDouble(input, result) + case d: DecimalType => genHashDecimal(ctx, d, input, result) + case CalendarIntervalType => genHashCalendarInterval(input, result) + case BinaryType => genHashBytes(input, result) + case StringType => genHashString(input, result) + case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull) + case MapType(kt, vt, valueContainsNull) => + genHashForMap(ctx, input, result, kt, vt, valueContainsNull) + case StructType(fields) => genHashForStruct(ctx, input, result, fields) + case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx) } + protected def computeHash( + input: String, + dataType: DataType, + result: String, + ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx) + protected def hasherClassName: String } @@ -568,3 +619,217 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def foldable: Boolean = true override def nullable: Boolean = false } + +/** + * Simulates Hive's hashing function at + * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive + * + * We should use this hash function for both shuffle and bucket of Hive tables, so that + * we can guarantee shuffle and bucketing have same data distribution + * + * TODO: Support Decimal and date related types + */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) 
- Returns a hash value of the arguments.") +case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { + override val seed = 0 + + override def dataType: DataType = IntegerType + + override def prettyName: String = "hive-hash" + + override protected def hasherClassName: String = classOf[HiveHasher].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + HiveHashFunction.hash(value, dataType, seed).toInt + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.isNull = "false" + val childHash = ctx.freshName("childHash") + val childrenHash = children.map { child => + val childGen = child.genCode(ctx) + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, childHash, ctx) + } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + }.mkString(s"int $childHash = 0;", s"\n$childHash = 0;\n", "") + + ev.copy(code = s""" + ${ctx.javaType(dataType)} ${ev.value} = $seed; + $childrenHash""") + } + + override def eval(input: InternalRow): Int = { + var hash = seed + var i = 0 + val len = children.length + while (i < len) { + hash = (31 * hash) + computeHash(children(i).eval(input), children(i).dataType, hash) + i += 1 + } + hash + } + + override protected def genHashInt(i: String, result: String): String = + s"$result = $hasherClassName.hashInt($i);" + + override protected def genHashLong(l: String, result: String): String = + s"$result = $hasherClassName.hashLong($l);" + + override protected def genHashBytes(b: String, result: String): String = + s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + + override protected def genHashCalendarInterval(input: String, result: String): String = { + s""" + $result = (31 * $hasherClassName.hashInt($input.months)) + + $hasherClassName.hashLong($input.microseconds);" + """ + } + + override protected def genHashString(input: String, result: String): String = { + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + } + + override protected def genHashForArray( + ctx: CodegenContext, + input: String, + result: String, + elementType: DataType, + containsNull: Boolean): String = { + val index = ctx.freshName("index") + val childResult = ctx.freshName("childResult") + s""" + int $childResult = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + $childResult = 0; + ${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)}; + $result = (31 * $result) + $childResult; + } + """ + } + + override protected def genHashForMap( + ctx: CodegenContext, + input: String, + result: String, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): String = { + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val keyResult = ctx.freshName("keyResult") + val valueResult = ctx.freshName("valueResult") + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + int $keyResult = 0; + int $valueResult = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + $keyResult = 0; + ${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)} + $valueResult = 0; + ${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)} + $result += 
$keyResult ^ $valueResult; + } + """ + } + + override protected def genHashForStruct( + ctx: CodegenContext, + input: String, + result: String, + fields: Array[StructField]): String = { + val localResult = ctx.freshName("localResult") + val childResult = ctx.freshName("childResult") + fields.zipWithIndex.map { case (field, index) => + s""" + $childResult = 0; + ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType, + childResult, ctx)} + $localResult = (31 * $localResult) + $childResult; + """ + }.mkString( + s""" + int $localResult = 0; + int $childResult = 0; + """, + "", + s"$result = (31 * $result) + $localResult;" + ) + } +} + +object HiveHashFunction extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = { + HiveHasher.hashInt(i) + } + + override protected def hashLong(l: Long, seed: Long): Long = { + HiveHasher.hashLong(l) + } + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + HiveHasher.hashUnsafeBytes(base, offset, len) + } + + override def hash(value: Any, dataType: DataType, seed: Long): Long = { + value match { + case null => 0 + case array: ArrayData => + val elementType = dataType match { + case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType + case ArrayType(et, _) => et + } + + var result = 0 + var i = 0 + val length = array.numElements() + while (i < length) { + result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt + i += 1 + } + result + + case map: MapData => + val (kt, vt) = dataType match { + case udt: UserDefinedType[_] => + val mapType = udt.sqlType.asInstanceOf[MapType] + mapType.keyType -> mapType.valueType + case MapType(_kt, _vt, _) => _kt -> _vt + } + val keys = map.keyArray() + val values = map.valueArray() + + var result = 0 + var i = 0 + val length = map.numElements() + while (i < length) { + result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt + i += 1 + } + result + + case struct: InternalRow => + val types: Array[DataType] = dataType match { + case udt: UserDefinedType[_] => + udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray + case StructType(fields) => fields.map(_.dataType) + } + + var result = 0 + var i = 0 + val length = struct.numFields + while (i < length) { + result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt + i += 1 + } + result + + case _ => super.hash(value, dataType, seed) + } + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java new file mode 100644 index 0000000000000..67a5eb0c7fe8f --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +public class HiveHasherSuite { + private final static HiveHasher hasher = new HiveHasher(); + + @Test + public void testKnownIntegerInputs() { + int[] inputs = {0, Integer.MIN_VALUE, Integer.MAX_VALUE, 593689054, -189366624}; + for (int input : inputs) { + Assert.assertEquals(input, HiveHasher.hashInt(input)); + } + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(0, HiveHasher.hashLong(0L)); + Assert.assertEquals(41, HiveHasher.hashLong(-42L)); + Assert.assertEquals(42, HiveHasher.hashLong(42L)); + Assert.assertEquals(-2147483648, HiveHasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2147483648, HiveHasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void testKnownStringAndIntInputs() { + int[] inputs = {84, 19, 8}; + int[] expected = {-823832826, -823835053, 111972242}; + + for (int i = 0; i < inputs.length; i++) { + UTF8String s = UTF8String.fromString("val_" + inputs[i]); + int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); + } + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(HiveHasher.hashInt(vint), HiveHasher.hashInt(vint)); + Assert.assertEquals(HiveHasher.hashLong(lint), HiveHasher.hashLong(lint)); + + hashcodes.add(HiveHasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(HiveHasher.hashUnsafeBytes( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. 
+ Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(HiveHasher.hashUnsafeBytes( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index c6a1a2be0d071..2d94b66a1e122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -42,8 +42,8 @@ object HashBenchmark { val benchmark = new Benchmark("Hash For " + name, iters * numRows) benchmark.addCase("interpreted version") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += rows(i).hashCode() @@ -54,8 +54,8 @@ object HashBenchmark { val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) benchmark.addCase("codegen version") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += getHashCode(rows(i)).getInt(0) @@ -66,8 +66,8 @@ object HashBenchmark { val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) benchmark.addCase("codegen version 64-bit") { _: Int => + var sum = 0 for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numRows) { sum += getHashCode64b(rows(i)).getInt(0) @@ -76,30 +76,44 @@ object HashBenchmark { } } + val getHiveHashCode = UnsafeProjection.create(new HiveHash(attrs) :: Nil, attrs) + benchmark.addCase("codegen HiveHash version") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += getHiveHashCode(rows(i)).getInt(0) + i += 1 + } + } + } + benchmark.run() } def main(args: Array[String]): Unit = { val singleInt = new StructType().add("i", IntegerType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1006 / 1011 133.4 7.5 1.0X - codegen version 1835 / 1839 73.1 13.7 0.5X - codegen version 64-bit 1627 / 1628 82.5 12.1 0.6X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3262 / 3267 164.6 6.1 1.0X + codegen version 6448 / 6718 83.3 12.0 0.5X + codegen version 64-bit 6088 / 6154 88.2 11.3 0.5X + codegen HiveHash version 4732 / 4745 113.5 8.8 0.7X + */ test("single ints", singleInt, 1 << 15, 1 << 14) val singleLong = new StructType().add("i", LongType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1196 / 1209 112.2 8.9 1.0X - codegen version 2178 / 2181 61.6 16.2 0.5X - codegen version 
64-bit 1752 / 1753 76.6 13.1 0.7X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3716 / 3726 144.5 6.9 1.0X + codegen version 7706 / 7732 69.7 14.4 0.5X + codegen version 64-bit 6370 / 6399 84.3 11.9 0.6X + codegen HiveHash version 4924 / 5026 109.0 9.2 0.8X + */ test("single longs", singleLong, 1 << 15, 1 << 14) val normal = new StructType() @@ -118,13 +132,14 @@ object HashBenchmark { .add("date", DateType) .add("timestamp", TimestampType) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 2713 / 2715 0.8 1293.5 1.0X - codegen version 2015 / 2018 1.0 960.9 1.3X - codegen version 64-bit 735 / 738 2.9 350.7 3.7X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 2985 / 3013 0.7 1423.4 1.0X + codegen version 2422 / 2434 0.9 1155.1 1.2X + codegen version 64-bit 856 / 920 2.5 408.0 3.5X + codegen HiveHash version 4501 / 4979 0.5 2146.4 0.7X + */ test("normal", normal, 1 << 10, 1 << 11) val arrayOfInt = ArrayType(IntegerType) @@ -132,13 +147,14 @@ object HashBenchmark { .add("array", arrayOfInt) .add("arrayOfArray", ArrayType(arrayOfInt)) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1498 / 1499 0.1 11432.1 1.0X - codegen version 2642 / 2643 0.0 20158.4 0.6X - codegen version 64-bit 2421 / 2424 0.1 18472.5 0.6X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 3100 / 3555 0.0 23651.8 1.0X + codegen version 5779 / 5865 0.0 44088.4 0.5X + codegen version 64-bit 4738 / 4821 0.0 36151.7 0.7X + codegen HiveHash version 2200 / 2246 0.1 16785.9 1.4X + */ test("array", array, 1 << 8, 1 << 9) val mapOfInt = MapType(IntegerType, IntegerType) @@ -146,13 +162,14 @@ object HashBenchmark { .add("map", mapOfInt) .add("mapOfMap", MapType(IntegerType, mapOfInt)) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - interpreted version 1612 / 1618 0.0 393553.4 1.0X - codegen version 149 / 150 0.0 36381.2 10.8X - codegen version 64-bit 144 / 145 0.0 35122.1 11.2X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + interpreted version 0 / 0 48.1 20.8 1.0X + codegen version 257 / 275 0.0 62768.7 0.0X + codegen version 64-bit 226 / 240 0.0 55224.5 0.0X + codegen HiveHash version 89 / 96 0.0 21708.8 0.0X + */ test("map", map, 1 << 6, 1 << 6) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 53f21a8442429..2a753a0c84ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.util.Random -import org.apache.spark.sql.catalyst.expressions.XXH64 +import org.apache.spark.sql.catalyst.expressions.{HiveHasher, XXH64} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.util.Benchmark @@ -38,8 +38,8 @@ object HashByteArrayBenchmark { val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) benchmark.addCase("Murmur3_x86_32") { _: Int => + var sum = 0L for (_ <- 0L until iters) { - var sum = 0 var i = 0 while (i < numArrays) { sum += Murmur3_x86_32.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) @@ -49,8 +49,8 @@ object HashByteArrayBenchmark { } benchmark.addCase("xxHash 64-bit") { _: Int => + var sum = 0L for (_ <- 0L until iters) { - var sum = 0L var i = 0 while (i < numArrays) { sum += XXH64.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) @@ -59,90 +59,110 @@ object HashByteArrayBenchmark { } } + benchmark.addCase("HiveHasher") { _: Int => + var sum = 0L + for (_ <- 0L until iters) { + var i = 0 + while (i < numArrays) { + sum += HiveHasher.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length) + i += 1 + } + } + } + benchmark.run() } def main(args: Array[String]): Unit = { /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 11 / 12 185.1 5.4 1.0X - xxHash 64-bit 17 / 18 120.0 8.3 0.6X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 12 / 16 174.3 5.7 1.0X + xxHash 64-bit 17 / 22 120.0 8.3 0.7X + HiveHasher 13 / 15 162.1 6.2 0.9X */ test(8, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 18 / 18 118.6 8.4 1.0X - xxHash 64-bit 20 / 21 102.5 9.8 0.9X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 19 / 22 107.6 9.3 1.0X + xxHash 64-bit 20 / 24 104.6 9.6 1.0X + HiveHasher 24 / 28 87.0 11.5 0.8X */ test(16, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 24 / 24 86.6 11.5 1.0X - xxHash 64-bit 23 / 23 93.2 10.7 1.1X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 28 / 32 74.8 13.4 1.0X + xxHash 64-bit 24 / 
29 87.3 11.5 1.2X + HiveHasher 36 / 41 57.7 17.3 0.8X */ test(24, 42L, 1 << 10, 1 << 11) // Add 31 to all arrays to create worse case alignment for xxHash. /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 38 / 39 54.7 18.3 1.0X - xxHash 64-bit 33 / 33 64.4 15.5 1.2X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 41 / 45 51.1 19.6 1.0X + xxHash 64-bit 36 / 44 58.8 17.0 1.2X + HiveHasher 49 / 54 42.6 23.5 0.8X */ test(31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 91 / 94 22.9 43.6 1.0X - xxHash 64-bit 68 / 69 30.6 32.7 1.3X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 100 / 110 21.0 47.7 1.0X + xxHash 64-bit 74 / 78 28.2 35.5 1.3X + HiveHasher 189 / 196 11.1 90.3 0.5X */ test(64 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 268 / 268 7.8 127.6 1.0X - xxHash 64-bit 108 / 109 19.4 51.6 2.5X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 299 / 311 7.0 142.4 1.0X + xxHash 64-bit 113 / 122 18.5 54.1 2.6X + HiveHasher 620 / 624 3.4 295.5 0.5X */ test(256 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 942 / 945 2.2 449.4 1.0X - xxHash 64-bit 276 / 276 7.6 131.4 3.4X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 1068 / 1070 2.0 509.1 1.0X + xxHash 64-bit 306 / 315 6.9 145.9 3.5X + HiveHasher 2316 / 2369 0.9 1104.3 0.5X */ test(1024 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 1839 / 1843 1.1 876.8 1.0X - xxHash 64-bit 445 / 448 4.7 212.1 4.1X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 2252 / 2274 0.9 1074.1 1.0X + xxHash 64-bit 534 / 580 3.9 254.6 4.2X 
+ HiveHasher 4739 / 4786 0.4 2259.8 0.5X */ test(2048 + 31, 42L, 1 << 10, 1 << 11) /* - Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz - Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Murmur3_x86_32 7307 / 7310 0.3 3484.4 1.0X - xxHash 64-bit 1487 / 1488 1.4 709.1 4.9X - */ + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Murmur3_x86_32 9249 / 9586 0.2 4410.5 1.0X + xxHash 64-bit 2897 / 3241 0.7 1381.6 3.2X + HiveHasher 19392 / 20211 0.1 9246.6 0.5X + */ test(8192 + 31, 42L, 1 << 10, 1 << 11) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 33916c0891866..13ce588462028 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -145,7 +145,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val encoder = RowEncoder(inputSchema) val seed = scala.util.Random.nextInt() - test(s"murmur3/xxHash64 hash: ${inputSchema.simpleString}") { + test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") { for (_ <- 1 to 10) { val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { @@ -154,6 +154,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Only test the interpreted version has same result with codegen version. checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) + checkEvaluation(HiveHash(literals), HiveHash(literals).eval()) } } } From c9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Oct 2016 22:58:43 -0700 Subject: [PATCH 195/306] [SPARK-17658][SPARKR] read.df/write.df API taking path optionally in SparkR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `write.df`/`read.df` API require path which is not actually always necessary in Spark. Currently, it only affects the datasources implementing `CreatableRelationProvider`. Currently, Spark currently does not have internal data sources implementing this but it'd affect other external datasources. In addition we'd be able to use this way in Spark's JDBC datasource after https://github.com/apache/spark/pull/12601 is merged. **Before** - `read.df` ```r > read.df(source = "json") Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(path = c(1, 2)) Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(c(1, 2)) Error in invokeJava(isStatic = TRUE, className, methodName, ...) 
: java.lang.ClassCastException: java.lang.Double cannot be cast to java.lang.String at org.apache.spark.sql.execution.datasources.DataSource.hasMetadata(DataSource.scala:300) at ... In if (is.na(object)) { : ... ``` - `write.df` ```r > write.df(df, source = "json") Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"function", "missing"’ ``` ```r > write.df(df, source = c(1, 2)) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` ```r > write.df(df, mode = TRUE) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` **After** - `read.df` ```r > read.df(source = "json") Error in loadDF : analysis error - Unable to infer schema for JSON at . It must be specified manually; ``` ```r > read.df(path = c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` ```r > read.df(c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` - `write.df` ```r > write.df(df, source = "json") Error in save : illegal argument - 'path' is not specified ``` ```r > write.df(df, source = c(1, 2)) Error in .local(df, path, ...) : source should be charactor, null or omitted. It is 'parquet' by default. ``` ```r > write.df(df, mode = TRUE) Error in .local(df, path, ...) : mode should be charactor or omitted. It is 'error' by default. ``` ## How was this patch tested? Unit tests in `test_sparkSQL.R` Author: hyukjinkwon Closes #15231 from HyukjinKwon/write-default-r. --- R/pkg/R/DataFrame.R | 20 ++++++--- R/pkg/R/SQLContext.R | 19 ++++++--- R/pkg/R/generics.R | 4 +- R/pkg/R/utils.R | 52 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 35 +++++++++++++++ R/pkg/inst/tests/testthat/test_utils.R | 10 +++++ 6 files changed, 127 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 40f1f0f4429e0..75861d5de7092 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2608,7 +2608,7 @@ setMethod("except", #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions -#' @aliases write.df,SparkDataFrame,character-method +#' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df #' @export @@ -2622,21 +2622,31 @@ setMethod("except", #' } #' @note write.df since 1.4.0 setMethod("write.df", - signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...) { + signature(df = "SparkDataFrame"), + function(df, path = NULL, source = NULL, mode = "error", ...) { + if (!is.null(path) && !is.character(path)) { + stop("path should be charactor, NULL or omitted.") + } + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the datasource specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.character(mode)) { + stop("mode should be charactor or omitted. It is 'error' by default.") + } if (is.null(source)) { source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) 
if (!is.null(path)) { - options[["path"]] <- path + options[["path"]] <- path } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - write <- callJMethod(write, "save", path) + write <- handledCallJMethod(write, "save") }) #' @rdname write.df diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ce531c3f88863..baa87824beb91 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -771,6 +771,13 @@ dropTempView <- function(viewName) { #' @method read.df default #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { + if (!is.null(path) && !is.character(path)) { + stop("path should be charactor, NULL or omitted.") + } + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the datasource specified ", + "in 'spark.sql.sources.default' configuration by default.") + } sparkSession <- getSparkSession() options <- varargsToEnv(...) if (!is.null(path)) { @@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string } if (!is.null(schema)) { stopifnot(class(schema) == "structType") - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, - schema$jobj, options) + sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, + source, schema$jobj, options) } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "loadDF", sparkSession, source, options) + sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, + source, options) } dataFrame(sdf) } -read.df <- function(x, ...) { +read.df <- function(x = NULL, ...) { dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } @@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { read.df(path, source, schema, ...) } -loadDF <- function(x, ...) { +loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 67a999da9bc26..90a02e2778310 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) { +setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) @@ -732,7 +732,7 @@ setGeneric("withColumnRenamed", #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit #' @export diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 248c57532b6cf..e69666453480c 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -698,6 +698,58 @@ isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } +# Works identically with `callJStatic(...)` but throws a pretty formatted exception. +handledCallJStatic <- function(cls, method, ...) 
{ + result <- tryCatch(callJStatic(cls, method, ...), + error = function(e) { + captureJVMException(e, method) + }) + result +} + +# Works identically with `callJMethod(...)` but throws a pretty formatted exception. +handledCallJMethod <- function(obj, method, ...) { + result <- tryCatch(callJMethod(obj, method, ...), + error = function(e) { + captureJVMException(e, method) + }) + result +} + +captureJVMException <- function(e, method) { + rawmsg <- as.character(e) + if (any(grep("^Error in .*?: ", rawmsg))) { + # If the exception message starts with "Error in ...", this is possibly + # "Error in invokeJava(...)". Here, it replaces the characters to + # `paste("Error in", method, ":")` in order to identify which function + # was called in JVM side. + stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]] + rmsg <- paste("Error in", method, ":") + stacktrace <- paste(rmsg[1], stacktrace[2]) + } else { + # Otherwise, do not convert the error message just in case. + stacktrace <- rawmsg + } + + if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else { + stop(stacktrace, call. = FALSE) + } +} + # rbind a list of rows with raw (binary) columns # # @param inputData a list of rows, with each row a list diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9d874a0988716..f5ab601f274fe 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", { expect_equal(ver, version) }) +test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + df <- read.df(jsonPath, "json") + # This tests if the exception is thrown from JVM not from SparkR side. + # It makes sure that we can omit path argument in write.df API and then it calls + # DataFrameWriter.save() without path. + expect_error(write.df(df, source = "csv"), + "Error in save : illegal argument - 'path' is not specified") + + # Arguments checking in R side. + expect_error(write.df(df, "data.tmp", source = c(1, 2)), + paste("source should be character, NULL or omitted. It is the datasource specified", + "in 'spark.sql.sources.default' configuration by default.")) + expect_error(write.df(df, path = c(3)), + "path should be charactor, NULL or omitted.") + expect_error(write.df(df, mode = TRUE), + "mode should be charactor or omitted. It is 'error' by default.") +}) + +test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + # This tests if the exception is thrown from JVM not from SparkR side. + # It makes sure that we can omit path argument in read.df API and then it calls + # DataFrameWriter.load() without path. 
+ expect_error(read.df(source = "json"), + paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + "It must be specified manually")) + expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + + # Arguments checking in R side. + expect_error(read.df(path = c(3)), + "path should be charactor, NULL or omitted.") + expect_error(read.df(jsonPath, source = c(1, 2)), + paste("source should be character, NULL or omitted. It is the datasource specified", + "in 'spark.sql.sources.default' configuration by default.")) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 77f25292f3f29..69ed5549168b1 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -166,6 +166,16 @@ test_that("convertToJSaveMode", { 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint }) +test_that("captureJVMException", { + method <- "getSQLDataType" + expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, + "unknown"), + error = function(e) { + captureJVMException(e, method) + }), + "Error in getSQLDataType : illegal argument - Invalid type unknown") +}) + test_that("hashCode", { expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) From 89516c1c4a167249b0c82f60a62edb45ede3bd2c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 4 Oct 2016 23:48:26 -0700 Subject: [PATCH 196/306] [SPARK-17258][SQL] Parse scientific decimal literals as decimals ## What changes were proposed in this pull request? Currently Spark SQL parses regular decimal literals (e.g. `10.00`) as decimals and scientific decimal literals (e.g. `10.0e10`) as doubles. The difference between the two confuses most users. This PR unifies the parsing behavior and also parses scientific decimal literals as decimals. This implications in tests are limited to a single Hive compatibility test. ## How was this patch tested? Updated tests in `ExpressionParserSuite` and `SQLQueryTestSuite`. Author: Herman van Hovell Closes #14828 from hvanhovell/SPARK-17258. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +----- .../sql/catalyst/parser/AstBuilder.scala | 8 ------- .../parser/ExpressionParserSuite.scala | 24 +++++++++---------- .../resources/sql-tests/inputs/literals.sql | 8 ++++--- .../sql-tests/results/arithmetic.sql.out | 2 +- .../sql-tests/results/literals.sql.out | 24 ++++++++++++------- .../execution/HiveCompatibilitySuite.scala | 4 +++- 7 files changed, 38 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index c336a0c8eab7a..87719d9ee2bc4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -653,7 +653,6 @@ quotedIdentifier number : MINUS? DECIMAL_VALUE #decimalLiteral - | MINUS? SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral | MINUS? INTEGER_VALUE #integerLiteral | MINUS? BIGINT_LITERAL #bigIntLiteral | MINUS? SMALLINT_LITERAL #smallIntLiteral @@ -944,12 +943,8 @@ INTEGER_VALUE ; DECIMAL_VALUE - : DECIMAL_DIGITS {isValidDecimal()}? - ; - -SCIENTIFIC_DECIMAL_VALUE : DIGIT+ EXPONENT - | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + | DECIMAL_DIGITS EXPONENT? {isValidDecimal()}? 
; DOUBLE_LITERAL diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd0c70a49150d..bf3f30279a6fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1282,14 +1282,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - /** - * Create a double literal for a number denoted in scientific notation. - */ - override def visitScientificDecimalLiteral( - ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(ctx.getText.toDouble) - } - /** * Create a decimal literal for a regular decimal number. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 3718ac5f1e77b..0fb1138478a9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -352,6 +352,10 @@ class ExpressionParserSuite extends PlanTest { } test("literals") { + def testDecimal(value: String): Unit = { + assertEqual(value, Literal(BigDecimal(value).underlying)) + } + // NULL assertEqual("null", Literal(null)) @@ -362,20 +366,18 @@ class ExpressionParserSuite extends PlanTest { // Integral should have the narrowest possible type assertEqual("787324", Literal(787324)) assertEqual("7873247234798249234", Literal(7873247234798249234L)) - assertEqual("78732472347982492793712334", - Literal(BigDecimal("78732472347982492793712334").underlying())) + testDecimal("78732472347982492793712334") // Decimal - assertEqual("7873247234798249279371.2334", - Literal(BigDecimal("7873247234798249279371.2334").underlying())) + testDecimal("7873247234798249279371.2334") // Scientific Decimal - assertEqual("9.0e1", 90d) - assertEqual(".9e+2", 90d) - assertEqual("0.9e+2", 90d) - assertEqual("900e-1", 90d) - assertEqual("900.0E-1", 90d) - assertEqual("9.e+1", 90d) + testDecimal("9.0e1") + testDecimal(".9e+2") + testDecimal("0.9e+2") + testDecimal("900e-1") + testDecimal("900.0E-1") + testDecimal("9.e+1") intercept(".e3") // Tiny Int Literal @@ -395,8 +397,6 @@ class ExpressionParserSuite extends PlanTest { assertEqual("10.0D", Literal(10.0D)) intercept("-1.8E308D", s"does not fit in range") intercept("1.8E308D", s"does not fit in range") - // TODO we need to figure out if we should throw an exception here! - assertEqual("1E309", Literal(Double.PositiveInfinity)) // BigDecimal Literal assertEqual("90912830918230182310293801923652346786BD", diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 40dceb19cfc5b..37b4b7606d12b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -50,14 +50,14 @@ select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1; select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5; -- negative double select .e3; --- inf and -inf +-- very large decimals (overflowing double). 
select 1E309, -1E309; -- decimal parsing select 0.3, -0.8, .5, -.18, 0.1111, .1111; --- super large scientific notation numbers should still be valid doubles -select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10; +-- super large scientific notation double literals should still be valid doubles +select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d; -- string select "Hello Peter!", 'hello lee!'; @@ -103,3 +103,5 @@ select x'2379ACFe'; -- invalid hexadecimal binary literal select X'XuZ'; +-- Hive literal_double test. +SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out index 6abe048af477d..ce42c016a7100 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -29,7 +29,7 @@ struct<-5.2:decimal(2,1)> -- !query 3 select +6.8e0 -- !query 3 schema -struct<6.8:double> +struct<6.8:decimal(2,1)> -- !query 3 output 6.8 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index e2d8daef9868f..95d4413148f64 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 42 +-- Number of queries: 43 -- !query 0 @@ -167,17 +167,17 @@ select 1234567890123456789012345678901234567890.0 -- !query 17 select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 -- !query 17 schema -struct<1.0:double,1.2:double,1.0E10:double,150000.0:double,0.1:double,0.1:double,10000.0:double,90.0:double,90.0:double,90.0:double,90.0:double> +struct<1.0:double,1.2:double,1E+10:decimal(1,-10),1.5E+5:decimal(2,-4),0.1:double,0.1:double,1E+4:decimal(1,-4),9E+1:decimal(1,-1),9E+1:decimal(1,-1),90.0:decimal(3,1),9E+1:decimal(1,-1)> -- !query 17 output -1.0 1.2 1.0E10 150000.0 0.1 0.1 10000.0 90.0 90.0 90.0 90.0 +1.0 1.2 10000000000 150000 0.1 0.1 10000 90 90 90 90 -- !query 18 select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5 -- !query 18 schema -struct<-1.0:double,-1.2:double,-1.0E10:double,-150000.0:double,-0.1:double,-0.1:double,-10000.0:double> +struct<-1.0:double,-1.2:double,-1E+10:decimal(1,-10),-1.5E+5:decimal(2,-4),-0.1:double,-0.1:double,-1E+4:decimal(1,-4)> -- !query 18 output --1.0 -1.2 -1.0E10 -150000.0 -0.1 -0.1 -10000.0 +-1.0 -1.2 -10000000000 -150000 -0.1 -0.1 -10000 -- !query 19 @@ -197,9 +197,9 @@ select .e3 -- !query 20 select 1E309, -1E309 -- !query 20 schema -struct +struct<1E+309:decimal(1,-309),-1E+309:decimal(1,-309)> -- !query 20 output -Infinity -Infinity +1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 -1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 -- !query 21 @@ 
-211,7 +211,7 @@ struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0. -- !query 22 -select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10 +select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d -- !query 22 schema struct<1.2345678901234568E48:double,1.2345678901234568E48:double> -- !query 22 output @@ -408,3 +408,11 @@ contains illegal character for hexBinary: 0XuZ(line 1, pos 7) == SQL == select X'XuZ' -------^^^ + + +-- !query 42 +SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 +-- !query 42 schema +struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> +-- !query 42 output +3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bebcb8f8016b1..f5d10de8cd2bf 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -555,6 +555,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "varchar_2", "varchar_join1", + // This test assumes we parse scientific decimals as doubles (we parse them as decimals) + "literal_double", + // These tests are duplicates of joinXYZ "auto_join0", "auto_join1", @@ -832,7 +835,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "leftsemijoin_mr", "limit_pushdown_negative", "lineage1", - "literal_double", "literal_ints", "literal_string", "load_dyn_part1", From 6a05eb24d043aa93390f353850d56efa6124e063 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Oct 2016 10:52:43 -0700 Subject: [PATCH 197/306] [SPARK-17328][SQL] Fix NPE with EXPLAIN DESCRIBE TABLE ## What changes were proposed in this pull request? This PR fixes the following NPE scenario in two ways. **Reported Error Scenario** ```scala scala> sql("EXPLAIN DESCRIBE TABLE x").show(truncate = false) INFO SparkSqlParser: Parsing command: EXPLAIN DESCRIBE TABLE x java.lang.NullPointerException ``` - **DESCRIBE**: Extend `DESCRIBE` syntax to accept `TABLE`. - **EXPLAIN**: Prevent NPE in case of the parsing failure of target statement, e.g., `EXPLAIN DESCRIBE TABLES x`. ## How was this patch tested? Pass the Jenkins test with a new test case. Author: Dongjoon Hyun Closes #15357 from dongjoon-hyun/SPARK-17328. 
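For readers skimming the patch, a minimal spark-shell style sketch of the behavior this change targets. The table definition is taken from the new `describe.sql` test; the comments describe the intended outcome and are illustrative rather than copied output.

```scala
// Before this patch, EXPLAIN received a null plan when the inner statement failed to parse
// (DESC did not accept the TABLE keyword), which surfaced as a NullPointerException.
sql("CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)")
sql("EXPLAIN DESCRIBE TABLE t").show(truncate = false)  // now parses: DESC/DESCRIBE accept TABLE
sql("EXPLAIN DESCRIBE TABLES x")                        // now fails with a parse error, not an NPE
```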
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 4 +- .../resources/sql-tests/inputs/describe.sql | 4 ++ .../sql-tests/results/describe.sql.out | 58 ++++++++++++++----- .../sql/execution/SparkSqlParserSuite.scala | 18 +++++- 5 files changed, 68 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 87719d9ee2bc4..6a94def65f360 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -136,7 +136,7 @@ statement | SHOW CREATE TABLE tableIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase - | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable | REFRESH TABLE tableIdentifier #refreshTable | REFRESH .*? #refreshResource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 7f1e23e665eb1..085bb9fc3c6cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -265,7 +265,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } val statement = plan(ctx.statement) - if (isExplainableStatement(statement)) { + if (statement == null) { + null // This is enough since ParseException will raise later. + } else if (isExplainableStatement(statement)) { ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) } else { ExplainCommand(OneRowRelation) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 3f0ae902e0529..84503d0b12a8e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -2,8 +2,12 @@ CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING); ALTER TABLE t ADD PARTITION (c='Us', d=1); +DESCRIBE t; + DESC t; +DESC TABLE t; + -- Ignore these because there exist timestamp results, e.g., `Create Table`. 
-- DESC EXTENDED t; -- DESC FORMATTED t; diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 37bf303f1bfe4..b448d60c7685d 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 10 -- !query 0 @@ -19,7 +19,7 @@ struct<> -- !query 2 -DESC t +DESCRIBE t -- !query 2 schema struct -- !query 2 output @@ -34,7 +34,7 @@ d string -- !query 3 -DESC t PARTITION (c='Us', d=1) +DESC t -- !query 3 schema struct -- !query 3 output @@ -49,30 +49,60 @@ d string -- !query 4 -DESC t PARTITION (c='Us', d=2) +DESC TABLE t -- !query 4 schema -struct<> +struct -- !query 4 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 5 +DESC t PARTITION (c='Us', d=1) +-- !query 5 schema +struct +-- !query 5 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 6 +DESC t PARTITION (c='Us', d=2) +-- !query 6 schema +struct<> +-- !query 6 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 5 +-- !query 7 DESC t PARTITION (c='Us') --- !query 5 schema +-- !query 7 schema struct<> --- !query 5 output +-- !query 7 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 6 +-- !query 8 DESC t PARTITION (c='Us', d) --- !query 6 schema +-- !query 8 schema struct<> --- !query 6 output +-- !query 8 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -82,9 +112,9 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 7 +-- !query 9 DROP TABLE t --- !query 7 schema +-- !query 9 schema struct<> --- !query 7 output +-- !query 9 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 8161c08b2cb48..6712d32924890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, ShowFunctionsCommand} +import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, + ShowFunctionsCommand} import org.apache.spark.sql.internal.SQLConf /** @@ -72,4 +73,17 @@ class SparkSqlParserSuite extends PlanTest { DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true)) } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { + assertEqual("describe table t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = false)) + assertEqual("describe table extended 
t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, isExtended = true, isFormatted = false)) + assertEqual("describe table formatted t", + DescribeTableCommand( + TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = true)) + + intercept("explain describe tables x", "Unsupported SQL statement") + } } From 9df54f5325c2942bb77008ff1810e2fb5f6d848b Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 5 Oct 2016 18:28:21 +0000 Subject: [PATCH 198/306] [SPARK-17239][ML][DOC] Update user guide for multiclass logistic regression ## What changes were proposed in this pull request? Updates user guide to reflect that LogisticRegression now supports multiclass. Also adds new examples to show multiclass training. ## How was this patch tested? Ran locally using spark-submit, run-example, and copy/paste from user guide into shells. Generated docs and verified correct output. Author: sethah Closes #15349 from sethah/SPARK-17239. --- docs/ml-classification-regression.md | 65 +++++++++++++++++-- ...gisticRegressionWithElasticNetExample.java | 14 ++++ ...gisticRegressionWithElasticNetExample.java | 55 ++++++++++++++++ .../logistic_regression_with_elastic_net.py | 10 +++ ...ss_logistic_regression_with_elastic_net.py | 48 ++++++++++++++ ...isticRegressionWithElasticNetExample.scala | 13 ++++ ...isticRegressionWithElasticNetExample.scala | 57 ++++++++++++++++ 7 files changed, 255 insertions(+), 7 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java create mode 100644 examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 7c2437eacde3f..bb2e404330cc0 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -34,17 +34,22 @@ discussing specific classes of algorithms, such as linear methods, trees, and en ## Logistic regression -Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. -For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). +Logistic regression is a popular method to predict a categorical response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcomes. +In `spark.ml` logistic regression can be used to predict a binary outcome by using binomial logistic regression, or it can be used to predict a multiclass outcome by using multinomial logistic regression. Use the `family` +parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant. - > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + > Multinomial logistic regression can be used for binary classification by setting the `family` param to "multinomial". It will produce two sets of coefficients and two intercepts. 
> When fitting LogisticRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM. +### Binomial logistic regression + +For more background and more details about the implementation of binomial logistic regression, refer to the documentation of [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + **Example** -The following example shows how to train a logistic regression model -with elastic net regularization. `elasticNetParam` corresponds to +The following example shows how to train binomial and multinomial logistic regression +models for binary classification with elastic net regularization. `elasticNetParam` corresponds to $\alpha$ and `regParam` corresponds to $\lambda$.
      @@ -92,8 +97,8 @@ provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). Currently, only binary classification is supported and the summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +Support for multiclass model summaries will be added in the future. Continuing the earlier example: @@ -107,6 +112,52 @@ Logistic regression model summary is not yet supported in Python.
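As a hedged illustration of the cast described above (not part of this patch, and assuming `lrModel` is the binomial model fitted earlier in the guide):

```scala
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary

// The generic training summary can be narrowed to the binary-specific one.
val trainingSummary = lrModel.summary
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionTrainingSummary]

// Per-iteration objective values and the area under the ROC curve.
binarySummary.objectiveHistory.foreach(println)
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
```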
      +### Multinomial logistic regression + +Multiclass classification is supported via multinomial logistic (softmax) regression. In multinomial logistic regression, +the algorithm produces $K$ sets of coefficients, or a matrix of dimension $K \times J$ where $K$ is the number of outcome +classes and $J$ is the number of features. If the algorithm is fit with an intercept term then a length $K$ vector of +intercepts is available. + + > Multinomial coefficients are available as `coefficientMatrix` and intercepts are available as `interceptVector`. + + > `coefficients` and `intercept` methods on a logistic regression model trained with multinomial family are not supported. Use `coefficientMatrix` and `interceptVector` instead. + +The conditional probabilities of the outcome classes $k \in \{1, 2, ..., K\}$ are modeled using the softmax function. + +`\[ + P(Y=k|\mathbf{X}, \boldsymbol{\beta}_k, \beta_{0k}) = \frac{e^{\boldsymbol{\beta}_k \cdot \mathbf{X} + \beta_{0k}}}{\sum_{k'=0}^{K-1} e^{\boldsymbol{\beta}_{k'} \cdot \mathbf{X} + \beta_{0k'}}} +\]` + +We minimize the weighted negative log-likelihood, using a multinomial response model, with elastic-net penalty to control for overfitting. + +`\[ +\min_{\beta, \beta_0} -\left[\sum_{i=1}^L w_i \cdot \log P(Y = y_i|\mathbf{x}_i)\right] + \lambda \left[\frac{1}{2}\left(1 - \alpha\right)||\boldsymbol{\beta}||_2^2 + \alpha ||\boldsymbol{\beta}||_1\right] +\]` + +For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model). + +**Example** + +The following example shows how to train a multiclass logistic regression +model with elastic net regularization. + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/multiclass_logistic_regression_with_elastic_net.py %}
+</div>
+
+</div>
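To make the softmax and loss formulas above concrete, here is a tiny standalone sketch. It is not part of the guide or of this patch; the numbers are made up, and `beta(k)`/`beta0(k)` simply play the role of one row of `coefficientMatrix` and one entry of `interceptVector`.

```scala
// Illustrative only: 3 classes, 2 features, hand-picked coefficients.
val x = Array(1.0, -0.5)                                              // feature vector
val beta = Array(Array(0.1, 0.2), Array(-0.3, 0.4), Array(0.0, 0.0))  // K x J coefficients
val beta0 = Array(0.5, -0.1, 0.0)                                     // K intercepts

// margin_k = beta_k . x + beta0_k, then softmax over the margins.
val margins = beta.zip(beta0).map { case (b, b0) =>
  b.zip(x).map { case (bi, xi) => bi * xi }.sum + b0
}
val expMargins = margins.map(math.exp)
val probabilities = expMargins.map(_ / expMargins.sum)                // P(Y = k | x), sums to 1
```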
      + ## Decision tree classifier diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index 6101c79fb0c98..b8fb5972ea418 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -48,6 +48,20 @@ public static void main(String[] args) { // Print the coefficients and intercept for logistic regression System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + + // We can also use the multinomial family for binary classification + LogisticRegression mlr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + .setFamily("multinomial"); + + // Fit the model + LogisticRegressionModel mlrModel = mlr.fit(training); + + // Print the coefficients and intercepts for logistic regression with multinomial family + System.out.println("Multinomial coefficients: " + + lrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..da410cba2b3f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off$ + +public class JavaMulticlassLogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaMulticlassLogisticRegressionWithElasticNetExample") + .getOrCreate(); + + // $example on$ + // Load training data + Dataset training = spark.read().format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for multinomial logistic regression + System.out.println("Coefficients: \n" + + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py index 33d0689f75cd5..d095fbd373408 100644 --- a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -40,6 +40,16 @@ # Print the coefficients and intercept for logistic regression print("Coefficients: " + str(lrModel.coefficients)) print("Intercept: " + str(lrModel.intercept)) + + # We can also use the multinomial family for binary classification + mlr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8, family="multinomial") + + # Fit the model + mlrModel = mlr.fit(training) + + # Print the coefficients and intercepts for logistic regression with multinomial family + print("Multinomial coefficients: " + str(mlrModel.coefficientMatrix)) + print("Multinomial intercepts: " + str(mlrModel.interceptVector)) # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py new file mode 100644 index 0000000000000..bb9cd82d6ba27 --- /dev/null +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("MulticlassLogisticRegressionWithElasticNet") \ + .getOrCreate() + + # $example on$ + # Load training data + training = spark \ + .read \ + .format("libsvm") \ + .load("data/mllib/sample_multiclass_classification_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for multinomial logistic regression + print("Coefficients: \n" + str(lrModel.coefficientMatrix)) + print("Intercept: " + str(lrModel.interceptVector)) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala index 616263b8e9f48..18471049087d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -45,6 +45,19 @@ object LogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for logistic regression println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + + // We can also use the multinomial family for binary classification + val mlr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + .setFamily("multinomial") + + val mlrModel = mlr.fit(training) + + // Print the coefficients and intercepts for logistic regression with multinomial family + println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}") + println(s"Multinomial intercepts: ${mlrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..42f0ace7a353d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +// $example off$ +import org.apache.spark.sql.SparkSession + +object MulticlassLogisticRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("MulticlassLogisticRegressionWithElasticNetExample") + .getOrCreate() + + // $example on$ + // Load training data + val training = spark + .read + .format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for multinomial logistic regression + println(s"Coefficients: \n${lrModel.coefficientMatrix}") + println(s"Intercepts: ${lrModel.interceptVector}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 221b418b1c9db7b04c600b6300d18b034a4f444e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 14:54:55 -0700 Subject: [PATCH 199/306] [SPARK-17778][TESTS] Mock SparkContext to reduce memory usage of BlockManagerSuite ## What changes were proposed in this pull request? Mock SparkContext to reduce memory usage of BlockManagerSuite ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15350 from zsxwing/SPARK-17778. --- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1652fcdb964da..705c355234425 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -107,7 +107,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - sc = new SparkContext("local", "test", conf) + // Mock SparkContext to reduce the memory usage of tests. It's fine since the only reason we + // need to create a SparkContext is to initialize LiveListenerBus. + sc = mock(classOf[SparkContext]) + when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus(sc))), conf, true) From 5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 5 Oct 2016 16:05:30 -0700 Subject: [PATCH 200/306] [SPARK-17758][SQL] Last returns wrong result in case of empty partition ## What changes were proposed in this pull request? The result of the `Last` function can be wrong when the last partition processed is empty. It can return `null` instead of the expected value. For example, this can happen when we process partitions in the following order: ``` - Partition 1 [Row1, Row2] - Partition 2 [Row3] - Partition 3 [] ``` In this case the `Last` function will currently return a null, instead of the value of `Row3`. This PR fixes this by adding a `valueSet` flag to the `Last` function. ## How was this patch tested? We only used end to end tests for `DeclarativeAggregateFunction`s. I have added an evaluator for these functions so we can tests them in catalyst. 
I have added a `LastTestSuite` to test the `Last` aggregate function. Author: Herman van Hovell Closes #15348 from hvanhovell/SPARK-17758. --- .../catalyst/expressions/aggregate/Last.scala | 27 ++--- .../DeclarativeAggregateEvaluator.scala | 61 ++++++++++ .../expressions/aggregate/LastTestSuite.scala | 109 ++++++++++++++++++ 3 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index af8840305805f..8579f7292d3ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat private lazy val last = AttributeReference("last", child.dataType)() - override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil override lazy val initialValues: Seq[Literal] = Seq( - /* last = */ Literal.create(null, child.dataType) + /* last = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child) + /* last = */ If(IsNull(child), last, child), + /* valueSet = */ Or(valueSet, IsNotNull(child)) ) } else { Seq( - /* last = */ child + /* last = */ child, + /* valueSet = */ Literal.create(true, BooleanType) ) } } override lazy val mergeExpressions: Seq[Expression] = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - } else { - Seq( - /* last = */ last.right - ) - } + // Prefer the right hand expression if it has been set. + Seq( + /* last = */ If(valueSet.right, last.right, last.left), + /* valueSet = */ Or(valueSet.right, valueSet.left) + ) } override lazy val evaluateExpression: AttributeReference = last diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala new file mode 100644 index 0000000000000..614f24db0aafb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection + +/** + * Evaluator for a [[DeclarativeAggregate]]. + */ +case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { + + lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + + lazy val updater = GenerateSafeProjection.generate( + function.updateExpressions, + function.aggBufferAttributes ++ input) + + lazy val merger = GenerateSafeProjection.generate( + function.mergeExpressions, + function.aggBufferAttributes ++ function.inputAggBufferAttributes) + + lazy val evaluator = GenerateSafeProjection.generate( + function.evaluateExpression :: Nil, + function.aggBufferAttributes) + + def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + + def update(values: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = values.foldLeft(initialize()) { (buffer, input) => + updater(joiner(buffer, input)) + } + buffer.copy() + } + + def merge(buffers: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = buffers.foldLeft(initialize()) { (left, right) => + merger(joiner(left, right)) + } + buffer.copy() + } + + def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala new file mode 100644 index 0000000000000..ba36bc074e154 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.IntegerType + +class LastTestSuite extends SparkFunSuite { + val input = AttributeReference("input", IntegerType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(null, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(1), + InternalRow(9), + InternalRow(-1)) + assert(result === InternalRow(-1, true)) + } + + test("update - ignore nulls") { + val result1 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(9), + InternalRow(null)) + assert(result1 === InternalRow(9, true)) + + val result2 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(null)) + assert(result2 === InternalRow(null, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(null, false)) + + // Single merge + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.merge(p1) === p1) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + assert(evaluator.merge(p1, p2) === p2) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === p2) + assert(evaluator.merge(p2, p1, p0) === p1) + } + + test("merge - ignore nulls") { + // Multi merges + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + assert(evaluatorIgnoreNulls.merge(p1, p2) === p1) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.eval(p1) === InternalRow(-99)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(10)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(-99)) + } + + test("eval - ignore nulls") { + // Update - Merge - Eval + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + val m1 = evaluatorIgnoreNulls.merge(p1, p2) + assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1)) + } +} From 9293734d35eb3d6e4fd4ebb86f54dd5d3a35e6db Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 16:45:45 -0700 Subject: [PATCH 201/306] [SPARK-17346][SQL] Add Kafka source for Structured Streaming ## What changes were proposed in this pull request? This PR adds a new project ` external/kafka-0-10-sql` for Structured Streaming Kafka source. 
It's based on the design doc: https://docs.google.com/document/d/19t2rWe51x7tq2e5AOfrsM9qb8_m7BRuv9fel9i0PqR8/edit?usp=sharing tdas did most of work and part of them was inspired by koeninger's work. ### Introduction The Kafka source is a structured streaming data source to poll data from Kafka. The schema of reading data is as follows: Column | Type ---- | ---- key | binary value | binary topic | string partition | int offset | long timestamp | long timestampType | int The source can deal with deleting topics. However, the user should make sure there is no Spark job processing the data when deleting a topic. ### Configuration The user can use `DataStreamReader.option` to set the following configurations. Kafka Source's options | value | default | meaning ------ | ------- | ------ | ----- startingOffset | ["earliest", "latest"] | "latest" | The start point when a query is started, either "earliest" which is from the earliest offset, or "latest" which is just from the latest offset. Note: This only applies when a new Streaming query is started, and that resuming will always pick up from where the query left off. failOnDataLost | [true, false] | true | Whether to fail the query when it's possible that data is lost (e.g., topics are deleted, or offsets are out of range). This may be a false alarm. You can disable it when it doesn't work as you expected. subscribe | A comma-separated list of topics | (none) | The topic list to subscribe. Only one of "subscribe" and "subscribeParttern" options can be specified for Kafka source. subscribePattern | Java regex string | (none) | The pattern used to subscribe the topic. Only one of "subscribe" and "subscribeParttern" options can be specified for Kafka source. kafka.consumer.poll.timeoutMs | long | 512 | The timeout in milliseconds to poll data from Kafka in executors fetchOffset.numRetries | int | 3 | Number of times to retry before giving up fatch Kafka latest offsets. fetchOffset.retryIntervalMs | long | 10 | milliseconds to wait before retrying to fetch Kafka offsets Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, `stream.option("kafka.bootstrap.servers", "host:port")` ### Usage * Subscribe to 1 topic ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribe", "topic1") .load() ``` * Subscribe to multiple topics ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribe", "topic1,topic2") .load() ``` * Subscribe to a pattern ```Scala spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host:port") .option("subscribePattern", "topic.*") .load() ``` ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Author: Tathagata Das Author: Shixiong Zhu Author: cody koeninger Closes #15102 from zsxwing/kafka-source. 
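To round out the usage snippets above, here is a minimal end-to-end sketch of starting a query on this source. It is not part of the patch; the broker address and topic name are placeholders, and the console sink is chosen only for illustration.

```scala
// Build a streaming Dataset from Kafka and continuously print the decoded records.
val query = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host:port")   // placeholder broker list
  .option("subscribe", "topic1")                    // placeholder topic
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .writeStream
  .format("console")                                // any supported sink works here
  .start()

query.awaitTermination()
```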
--- .../spark/util/UninterruptibleThread.scala | 7 - dev/run-tests.py | 2 +- dev/sparktestsupport/modules.py | 12 + .../structured-streaming-kafka-integration.md | 239 ++++++++++ .../structured-streaming-programming-guide.md | 7 +- external/kafka-0-10-sql/pom.xml | 82 ++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 152 +++++++ .../spark/sql/kafka010/KafkaSource.scala | 399 ++++++++++++++++ .../sql/kafka010/KafkaSourceOffset.scala | 54 +++ .../sql/kafka010/KafkaSourceProvider.scala | 282 ++++++++++++ .../spark/sql/kafka010/KafkaSourceRDD.scala | 148 ++++++ .../spark/sql/kafka010/package-info.java | 21 + .../src/test/resources/log4j.properties | 28 ++ .../sql/kafka010/KafkaSourceOffsetSuite.scala | 39 ++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 424 ++++++++++++++++++ .../spark/sql/kafka010/KafkaTestUtils.scala | 339 ++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 6 +- .../execution/streaming/StreamExecution.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 40 +- 21 files changed, 2268 insertions(+), 23 deletions(-) create mode 100644 docs/structured-streaming-kafka-integration.md create mode 100644 external/kafka-0-10-sql/pom.xml create mode 100644 external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java create mode 100644 external/kafka-0-10-sql/src/test/resources/log4j.properties create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 4dcf95177aa78..f0b68f0cb7e29 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -89,13 +89,6 @@ private[spark] class UninterruptibleThread(name: String) extends Thread(name) { } } - /** - * Tests whether `interrupt()` has been called. - */ - override def isInterrupted: Boolean = { - super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread } - } - /** * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be * interrupted until it enters into the interruptible status. 
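For context on the class touched above, a small sketch of how `UninterruptibleThread` is meant to be used. This is illustrative only: the class is Spark-internal (`private[spark]`), and `runUninterruptibly` must be invoked from the thread itself.

```scala
import org.apache.spark.util.UninterruptibleThread

// Interrupts delivered while inside runUninterruptibly are deferred until the block
// exits, so the critical section is never broken by Thread.interrupt().
val worker = new UninterruptibleThread("example-worker") {
  override def run(): Unit = {
    runUninterruptibly {
      // e.g. a blocking call that must not be interrupted half-way
    }
  }
}
worker.start()
worker.interrupt()   // takes effect only after the uninterruptible block finishes
worker.join()
```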
diff --git a/dev/run-tests.py b/dev/run-tests.py index ae4b5306fc5cf..5d661f5f1a1c5 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', + ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 050cdf043757f..5f14683d9a52f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -158,6 +158,18 @@ def __hash__(self): ) +sql_kafka = Module( + name="sql-kafka-0-10", + dependencies=[sql], + source_file_regexes=[ + "external/kafka-0-10-sql", + ], + sbt_test_goals=[ + "sql-kafka-0-10/test", + ] +) + + sketch = Module( name="sketch", dependencies=[tags], diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md new file mode 100644 index 0000000000000..668489addf82c --- /dev/null +++ b/docs/structured-streaming-kafka-integration.md @@ -0,0 +1,239 @@ +--- +layout: global +title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) +--- + +Structured Streaming integration for Kafka 0.10 to poll data from Kafka. + +### Linking +For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: + + groupId = org.apache.spark + artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +For Python applications, you need to add this above library and its dependencies when deploying your +application. See the [Deploying](#deploying) subsection below. + +### Creating a Kafka Source Stream + +
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+    // Subscribe to 1 topic
+    val ds1 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1")
+      .load()
+    ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+
+    // Subscribe to multiple topics
+    val ds2 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1,topic2")
+      .load()
+    ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+
+    // Subscribe to a pattern
+    val ds3 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribePattern", "topic.*")
+      .load()
+    ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+
+</div>
+<div data-lang="java" markdown="1">
+
+    // Subscribe to 1 topic
+    Dataset<Row> ds1 = spark
+      .readStream()
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1")
+      .load();
+    ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+
+    // Subscribe to multiple topics
+    Dataset<Row> ds2 = spark
+      .readStream()
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1,topic2")
+      .load();
+    ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+
+    // Subscribe to a pattern
+    Dataset<Row> ds3 = spark
+      .readStream()
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribePattern", "topic.*")
+      .load();
+    ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+
+</div>
+<div data-lang="python" markdown="1">
+
+    # Subscribe to 1 topic
+    ds1 = spark \
+      .readStream \
+      .format("kafka") \
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+      .option("subscribe", "topic1") \
+      .load()
+    ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+
+    # Subscribe to multiple topics
+    ds2 = spark \
+      .readStream \
+      .format("kafka") \
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+      .option("subscribe", "topic1,topic2") \
+      .load()
+    ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+
+    # Subscribe to a pattern
+    ds3 = spark \
+      .readStream \
+      .format("kafka") \
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+      .option("subscribePattern", "topic.*") \
+      .load()
+    ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+
+</div>
+</div>
+
+Each row in the source has the following schema:
+
+| Column        | Type   |
+| ------------- | ------ |
+| key           | binary |
+| value         | binary |
+| topic         | string |
+| partition     | int    |
+| offset        | long   |
+| timestamp     | long   |
+| timestampType | int    |
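+
+For example, assuming the same `spark` session and placeholder broker/topic names as in the
+snippets above, a minimal sketch of a query that deserializes the binary key and value and prints
+each record together with its partition and offset to the console sink could look like this:
+
+    val records = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1")
+      .load()
+
+    // key and value arrive as binary; deserialize them explicitly with DataFrame operations
+    val query = records
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset")
+      .writeStream
+      .format("console")
+      .start()
+
+    query.awaitTermination()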
+
+The following options must be set for the Kafka source.
+
+| Option                  | Value                               | Meaning |
+| ----------------------- | ----------------------------------- | ------- |
+| subscribe               | A comma-separated list of topics    | The list of topics to subscribe to. Only one of "subscribe" and "subscribePattern" can be specified for the Kafka source. |
+| subscribePattern        | Java regex string                   | The pattern used to subscribe to topics. Only one of "subscribe" and "subscribePattern" can be specified for the Kafka source. |
+| kafka.bootstrap.servers | A comma-separated list of host:port | The Kafka "bootstrap.servers" configuration. |
+
+The following configurations are optional:
+
+| Option                      | Value                  | Default  | Meaning |
+| --------------------------- | ---------------------- | -------- | ------- |
+| startingOffset              | ["earliest", "latest"] | "latest" | The start point when a query is started: "earliest" starts from the earliest offsets, "latest" starts from the latest offsets. Note: this only applies when a new streaming query is started; resuming always picks up from where the query left off. |
+| failOnDataLoss              | [true, false]          | true     | Whether to fail the query when it is possible that data has been lost (e.g., topics are deleted or offsets are out of range). This may be a false alarm; you can disable it if it does not behave as expected. |
+| kafkaConsumer.pollTimeoutMs | long                   | 512      | The timeout in milliseconds to poll data from Kafka in executors. |
+| fetchOffset.numRetries      | int                    | 3        | Number of times to retry before giving up fetching the latest Kafka offsets. |
+| fetchOffset.retryIntervalMs | long                   | 10       | Milliseconds to wait before retrying to fetch Kafka offsets. |
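+
+For example, a source that re-reads a topic from its beginning and only logs a warning on possible
+data loss could be configured roughly as follows (broker address and topic name are placeholders,
+as above):
+
+    val ds = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+      .option("subscribe", "topic1")
+      // Start from the earliest offsets the first time this query runs.
+      .option("startingOffset", "earliest")
+      // Log a warning instead of failing the query when data may have been lost.
+      .option("failOnDataLoss", "false")
+      // Poll timeout (in ms) used by the executors when fetching records.
+      .option("kafkaConsumer.pollTimeoutMs", "1024")
+      .load()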
      + +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). + +Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +- **group.id**: Kafka source will create a unique group id for each query automatically. +- **auto.offset.reset**: Set the source option `startingOffset` to `earliest` or `latest` to specify + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than rely on the kafka Consumer to do it. This will ensure that no data is missed when when new + topics/partitions are dynamically subscribed. Note that `startingOffset` only applies when a new + Streaming query is started, and that resuming will always pick up from where the query left off. +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use + DataFrame operations to explicitly deserialize the keys. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. + Use DataFrame operations to explicitly deserialize the values. +- **enable.auto.commit**: Kafka source doesn't commit any offset. +- **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to + use ConsumerInterceptor as it may break the query. + +### Deploying + +As with any Spark applications, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` +and its dependencies can be directly added to `spark-submit` using `--packages`, such as, + + ./bin/spark-submit --packages org.apache.spark:spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +See [Application Submission Guide](submitting-applications.html) for more details about submitting +applications with external dependencies. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2e6df94823d38..173fd6e8c73b9 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -418,10 +418,15 @@ Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/ [Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. In Spark 2.0, there are a few built-in sources. +[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. + +#### Data Sources +In Spark 2.0, there are a few built-in sources. - **File source** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. 
Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations. + - **Kafka source** - Poll data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-integration.html) for more details. + - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees. Here are some examples. diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml new file mode 100644 index 0000000000000..b96445a11f858 --- /dev/null +++ b/external/kafka-0-10-sql/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sql-kafka-0-10_2.11 + + sql-kafka-0-10 + + jar + Kafka 0.10 Source for Structured Streaming + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.kafka + kafka-clients + 0.10.0.1 + + + org.apache.kafka + kafka_${scala.binary.version} + 0.10.0.1 + test + + + net.sf.jopt-simple + jopt-simple + 3.2 + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..2f9e9fc0396d5 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.kafka010.KafkaSourceProvider diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala new file mode 100644 index 0000000000000..3b5a96534f9b6 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging + + +/** + * Consumer of single topicpartition, intended for cached reuse. + * Underlying consumer is not threadsafe, so neither is this, + * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + */ +private[kafka010] case class CachedKafkaConsumer private( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object]) extends Logging { + + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + private val consumer = { + val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = -2L + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + if (offset != nextOffsetInFetchedData) { + logInfo(s"Initial fetch for $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + + if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + var record = fetchedData.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + record = fetchedData.next() + assert(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + } + + nextOffsetInFetchedData = offset + 1 + record + } + + private def close(): Unit = consumer.close() + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $groupId $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(pollTimeoutMs: Long): Unit = { + val p = consumer.poll(pollTimeoutMs) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + fetchedData = r.iterator + } +} + +private[kafka010] object CachedKafkaConsumer extends Logging { + + private case class CacheKey(groupId: String, topicPartition: TopicPartition) + + private lazy val cache = { + val conf = SparkEnv.get.conf + val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) + new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { + if (this.size > capacity) { + logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case e: SparkException => + logError(s"Error closing 
earliest Kafka consumer for ${entry.getKey}", e) + } + true + } else { + false + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + */ + def getOrCreate( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val topicPartition = new TopicPartition(topic, partition) + val key = CacheKey(groupId, topicPartition) + + // If this is reattempt at running the task, then invalidate cache and start with + // a new consumer + if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + cache.remove(key) + new CachedKafkaConsumer(topicPartition, kafkaParams) + } else { + if (!cache.containsKey(key)) { + cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + } + cache.get(key) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala new file mode 100644 index 0000000000000..1be70db87497e --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer} +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[Source]] that uses Kafka's own [[KafkaConsumer]] API to reads data from Kafka. The design + * for this source is as follows. + * + * - The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. 
+ * + * - The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read + * by this source. These strategies directly correspond to the different consumption options + * in . This class is designed to return a configured [[KafkaConsumer]] that is used by the + * [[KafkaSource]] to query for the offsets. See the docs on + * [[org.apache.spark.sql.kafka010.KafkaSource.ConsumerStrategy]] for more details. + * + * - The [[KafkaSource]] written to do the following. + * + * - As soon as the source is created, the pre-configured KafkaConsumer returned by the + * [[ConsumerStrategy]] is used to query the initial offsets that this source should + * start reading from. This used to create the first batch. + * + * - `getOffset()` uses the KafkaConsumer to query the latest available offsets, which are + * returned as a [[KafkaSourceOffset]]. + * + * - `getBatch()` returns a DF that reads from the 'start offset' until the 'end offset' in + * for each partition. The end offset is excluded to be consistent with the semantics of + * [[KafkaSourceOffset]] and `KafkaConsumer.position()`. + * + * - The DF returned is based on [[KafkaSourceRDD]] which is constructed such that the + * data from Kafka topic + partition is consistently read by the same executors across + * batches, and cached KafkaConsumers in the executors can be reused efficiently. See the + * docs on [[KafkaSourceRDD]] for more details. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using KafkaSource maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] case class KafkaSource( + sqlContext: SQLContext, + consumerStrategy: ConsumerStrategy, + executorKafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + failOnDataLoss: Boolean) + extends Source with Logging { + + private val sc = sqlContext.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + private val maxOffsetFetchAttempts = + sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt + + private val offsetFetchAttemptIntervalMs = + sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + + /** + * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the + * offsets and never commits them. + */ + private val consumer = consumerStrategy.createConsumer() + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = { + val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = false)) + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + override def schema: StructType = KafkaSource.kafkaSchema + + /** Returns the maximum available offset for this source. 
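+   * (It does this by seeking the driver-side consumer to the end of every assigned partition and
+   * recording the resulting positions as a [[KafkaSourceOffset]].)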
*/ + override def getOffset: Option[Offset] = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + val offset = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = true)) + logDebug(s"GetOffset: ${offset.partitionToOffsets.toSeq.map(_.toString).sorted}") + Some(offset) + } + + /** + * Returns the data that is between the offsets + * [`start.get.partitionToOffsets`, `end.partitionToOffsets`), i.e. end.partitionToOffsets is + * exclusive. + */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + logInfo(s"GetBatch called with start = $start, end = $end") + val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + val fromPartitionOffsets = start match { + case Some(prevBatchEndOffset) => + KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) + case None => + initialPartitionOffsets + } + + // Find the new partitions, and get their earliest offsets + val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet) + val newPartitionOffsets = if (newPartitions.nonEmpty) { + fetchNewPartitionEarliestOffsets(newPartitions.toSeq) + } else { + Map.empty[TopicPartition, Long] + } + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = untilPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || fromPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList(sc) + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val offsetRanges = topicPartitions.map { tp => + val fromOffset = fromPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse(tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = untilPartitionOffsets(tp) + val preferredLoc = if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. 
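+          // For example, with 3 executors in the sorted list, a partition whose hashCode is 7 is
+          // always assigned to sortedExecutors(floorMod(7, 3)) = sortedExecutors(1).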
+ Some(sortedExecutors(floorMod(tp.hashCode, numExecutors))) + } else None + KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, preferredLoc) + }.filter { range => + if (range.untilOffset < range.fromOffset) { + reportDataLoss(s"Partition ${range.topicPartition}'s offset was changed from " + + s"${range.fromOffset} to ${range.untilOffset}, some data may have been missed") + false + } else { + true + } + }.toArray + + // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + val rdd = new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => + Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) + } + + logInfo("GetBatch generating RDD of offset range: " + + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) + sqlContext.createDataFrame(rdd, schema) + } + + /** Stop this source and free any resources it has allocated. */ + override def stop(): Unit = synchronized { + consumer.close() + } + + override def toString(): String = s"KafkaSource[$consumerStrategy]" + + /** + * Fetch the offset of a partition, either seek to the latest offsets or use the current offsets + * in the consumer. + */ + private def fetchPartitionOffsets( + seekToEnd: Boolean): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"Partitioned assigned to consumer: $partitions") + + // Get the current or latest offset of each partition + if (seekToEnd) { + consumer.seekToEnd(partitions) + logDebug("Seeked to the end") + } + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got offsets for partition : $partitionOffsets") + partitionOffsets + } + + /** + * Fetch the earliest offsets for newly discovered partitions. The return result may not contain + * some partitions if they are deleted. + */ + private def fetchNewPartitionEarliestOffsets( + newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + logDebug(s"\tPartitioned assigned to consumer: $partitions") + + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionToOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in `partitions`. + // So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got offsets for new partitions: $partitionToOffsets") + partitionToOffsets + } + + /** + * Helper function that does multiple retries on the a body of code that returns offsets. + * Retries are needed to handle transient failures. For e.g. race conditions between getting + * assignment and getting position while topics/partitions are deleted can cause NPEs. + * + * This method also makes sure `body` won't be interrupted to workaround a potential issue in + * `KafkaConsumer.poll`. 
(KAFKA-1894) + */ + private def withRetriesWithoutInterrupt( + body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + synchronized { + var result: Option[Map[TopicPartition, Long]] = None + var attempt = 1 + var lastException: Throwable = null + while (result.isEmpty && attempt <= maxOffsetFetchAttempts + && !Thread.currentThread().isInterrupted) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query + // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it. + // + // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may + // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the + // issue. + ut.runUninterruptibly { + try { + result = Some(body) + } catch { + case NonFatal(e) => + lastException = e + logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e) + attempt += 1 + Thread.sleep(offsetFetchAttemptIntervalMs) + } + } + case _ => + throw new IllegalStateException( + "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread") + } + } + if (Thread.interrupted()) { + throw new InterruptedException() + } + if (result.isEmpty) { + assert(attempt > maxOffsetFetchAttempts) + assert(lastException != null) + throw lastException + } + result.get + } + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + + ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.") + } else { + logWarning(message) + } + } +} + + +/** Companion object for the [[KafkaSource]]. 
*/ +private[kafka010] object KafkaSource { + + def kafkaSchema: StructType = StructType(Seq( + StructField("key", BinaryType), + StructField("value", BinaryType), + StructField("topic", StringType), + StructField("partition", IntegerType), + StructField("offset", LongType), + StructField("timestamp", LongType), + StructField("timestampType", IntegerType) + )) + + sealed trait ConsumerStrategy { + def createConsumer(): Consumer[Array[Byte], Array[Byte]] + } + + case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe(topics.asJava) + consumer + } + + override def toString: String = s"Subscribe[${topics.mkString(", ")}]" + } + + case class SubscribePatternStrategy( + topicPattern: String, kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe( + ju.regex.Pattern.compile(topicPattern), + new NoOpConsumerRebalanceListener()) + consumer + } + + override def toString: String = s"SubscribePattern[$topicPattern]" + } + + private def getSortedExecutorList(sc: SparkContext): Array[String] = { + val bm = sc.env.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + private def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { a.executorId > b.executorId } else { a.host > b.host } + } + + private def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala new file mode 100644 index 0000000000000..b5ade982515f0 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.execution.streaming.Offset + +/** + * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and + * their offsets. 
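+ * (For example, partitions 0 and 1 of topic "t" at offsets 5 and 7 render via `toString` as
+ * [(t-0,5), (t-1,7)].)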
+ */ +private[kafka010] +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { + override def toString(): String = { + partitionToOffsets.toSeq.sortBy(_._1.toString).mkString("[", ", ", "]") + } +} + +/** Companion object of the [[KafkaSourceOffset]] */ +private[kafka010] object KafkaSourceOffset { + + def getPartitionOffsets(offset: Offset): Map[TopicPartition, Long] = { + offset match { + case o: KafkaSourceOffset => o.partitionToOffsets + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to KafkaSourceOffset") + } + } + + /** + * Returns [[KafkaSourceOffset]] from a variable sequence of (topic, partitionId, offset) + * tuples. + */ + def apply(offsetTuples: (String, Int, Long)*): KafkaSourceOffset = { + KafkaSourceOffset(offsetTuples.map { case(t, p, o) => (new TopicPartition(t, p), o) }.toMap) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala new file mode 100644 index 0000000000000..1b0a2fe955d03 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.UUID + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.ByteArrayDeserializer + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.StructType + +/** + * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * IllegalArgumentException when the Kafka Dataset is created, so that it can catch + * missing options even before the query is started. + */ +private[kafka010] class KafkaSourceProvider extends StreamSourceProvider + with DataSourceRegister with Logging { + + import KafkaSourceProvider._ + + /** + * Returns the name and schema of the source. In addition, it also verifies whether the options + * are correct and sufficient to create the [[KafkaSource]] when the query is started. 
+ */ + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one") + validateOptions(parameters) + ("kafka", KafkaSource.kafkaSchema) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + validateOptions(parameters) + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val deserClassName = classOf[ByteArrayDeserializer].getName + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val autoOffsetResetValue = caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { + case Some(value) => value.trim() // same values as those supported by auto.offset.reset + case None => "latest" + } + + val kafkaParamsForStrategy = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // So that consumers in Kafka source do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") + + // So that consumers can start from earliest or latest + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetResetValue) + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val kafkaParamsForExecutors = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. 
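+        // ("none" makes the executor-side consumers fail fast if they are ever asked to start from
+        // an offset that the driver did not explicitly resolve, instead of silently resetting.)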
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("subscribe", value) => + SubscribeStrategy( + value.split(",").map(_.trim()).filter(_.nonEmpty), + kafkaParamsForStrategy) + case ("subscribepattern", value) => + SubscribePatternStrategy( + value.trim(), + kafkaParamsForStrategy) + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + val failOnDataLoss = + caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean + + new KafkaSource( + sqlContext, + strategy, + kafkaParamsForExecutors, + parameters, + metadataPath, + failOnDataLoss) + } + + private def validateOptions(parameters: Map[String, String]): Unit = { + + // Validate source options + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedStrategies = + caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq + if (specifiedStrategies.isEmpty) { + throw new IllegalArgumentException( + "One of the following options must be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.") + } else if (specifiedStrategies.size > 1) { + throw new IllegalArgumentException( + "Only one of the following options can be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". 
See the docs for more details.") + } + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("subscribe", value) => + val topics = value.split(",").map(_.trim).filter(_.nonEmpty) + if (topics.isEmpty) { + throw new IllegalArgumentException( + "No topics to subscribe to as specified value for option " + + s"'subscribe' is '$value'") + } + case ("subscribepattern", value) => + val pattern = caseInsensitiveParams("subscribepattern").trim() + if (pattern.isEmpty) { + throw new IllegalArgumentException( + "Pattern to subscribe is empty as specified value for option " + + s"'subscribePattern' is '$value'") + } + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { + case Some(pos) if !STARTING_OFFSET_OPTION_VALUES.contains(pos.trim.toLowerCase) => + throw new IllegalArgumentException( + s"Illegal value '$pos' for option '$STARTING_OFFSET_OPTION_KEY', " + + s"acceptable values are: ${STARTING_OFFSET_OPTION_VALUES.mkString(", ")}") + case _ => + } + + // Validate user-specified Kafka options + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + + s"user-specified consumer groups is not used to track offsets.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { + throw new IllegalArgumentException( + s""" + |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported. + |Instead set the source option '$STARTING_OFFSET_OPTION_KEY' to 'earliest' or 'latest' to + |specify where to start. Structured Streaming manages which offsets are consumed + |internally, rather than relying on the kafkaConsumer to do it. This will ensure that no + |data is missed when when new topics/partitions are dynamically subscribed. Note that + |'$STARTING_OFFSET_OPTION_KEY' only applies when a new Streaming query is started, and + |that resuming will always pick up from where the query left off. See the docs for more + |details. + """.stripMargin) + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations " + + "to explicitly deserialize the keys.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are deserialized as byte arrays with ByteArrayDeserializer. 
Use DataFrame " + + "operations to explicitly deserialize the values.") + } + + val otherUnsupportedConfigs = Seq( + ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, // committing correctly requires new APIs in Source + ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe + + otherUnsupportedConfigs.foreach { c => + if (caseInsensitiveParams.contains(s"kafka.$c")) { + throw new IllegalArgumentException(s"Kafka option '$c' is not supported") + } + } + + if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) { + throw new IllegalArgumentException( + s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " + + s"configuring Kafka consumer") + } + } + + override def shortName(): String = "kafka" + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.get(key).getOrElse("")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logInfo(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } +} + +private[kafka010] object KafkaSourceProvider { + private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern") + private val STARTING_OFFSET_OPTION_KEY = "startingoffset" + private val STARTING_OFFSET_OPTION_VALUES = Set("earliest", "latest") + private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala new file mode 100644 index 0000000000000..496af7e39abab --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** Offset range that one partition of the KafkaSourceRDD has to read */ +private[kafka010] case class KafkaSourceRDDOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + def topic: String = topicPartition.topic + def partition: Int = topicPartition.partition + def size: Long = untilOffset - fromOffset +} + + +/** Partition of the KafkaSourceRDD */ +private[kafka010] case class KafkaSourceRDDPartition( + index: Int, offsetRange: KafkaSourceRDDOffsetRange) extends Partition + + +/** + * An RDD that reads data from Kafka based on offset ranges across multiple partitions. + * Additionally, it allows preferred locations to be set for each topic + partition, so that + * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition + * and cached KafkaConsuemrs (see [[CachedKafkaConsumer]] can be used read data efficiently. + * + * @param sc the [[SparkContext]] + * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors + * @param offsetRanges Offset ranges that define the Kafka data belonging to this RDD + */ +private[kafka010] class KafkaSourceRDD( + sc: SparkContext, + executorKafkaParams: ju.Map[String, Object], + offsetRanges: Seq[KafkaSourceRDDOffsetRange], + pollTimeoutMs: Long) + extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { + + override def persist(newLevel: StorageLevel): this.type = { + logError("Kafka ConsumerRecord is not serializable. 
" + + "Use .map to extract fields before calling .persist or .window") + super.persist(newLevel) + } + + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray + } + + override def count(): Long = offsetRanges.map(_.size).sum + + override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val nonEmptyPartitions = + this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) + + if (num < 1 || nonEmptyPartitions.isEmpty) { + return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.offsetRange.size) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ) + res.foreach(buf ++= _) + buf.toArray + } + + override def compute( + thePart: Partition, + context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val range = thePart.asInstanceOf[KafkaSourceRDDPartition].offsetRange + assert( + range.fromOffset <= range.untilOffset, + s"Beginning offset ${range.fromOffset} is after the ending offset ${range.untilOffset} " + + s"for topic ${range.topic} partition ${range.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged") + if (range.fromOffset == range.untilOffset) { + logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + + s"skipping ${range.topic} ${range.partition}") + Iterator.empty + + } else { + + val consumer = CachedKafkaConsumer.getOrCreate( + range.topic, range.partition, executorKafkaParams) + var requestOffset = range.fromOffset + + logDebug(s"Creating iterator for $range") + + new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { + override def hasNext(): Boolean = requestOffset < range.untilOffset + override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(hasNext(), "Can't call next() once untilOffset has been reached") + val r = consumer.get(requestOffset, pollTimeoutMs) + requestOffset += 1 + r + } + } + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java new file mode 100644 index 0000000000000..596f775c56dbc --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Structured Streaming Data Source for Kafka 0.10 + */ +package org.apache.spark.sql.kafka010; diff --git a/external/kafka-0-10-sql/src/test/resources/log4j.properties b/external/kafka-0-10-sql/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala new file mode 100644 index 0000000000000..7056a41b1751e --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.sql.streaming.OffsetSuite + +class KafkaSourceOffsetSuite extends OffsetSuite { + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala new file mode 100644 index 0000000000000..64bf503058027 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.kafka.clients.producer.RecordMetadata +import org.scalatest.BeforeAndAfter +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext + + +abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + protected def makeSureGetOffsetCalled = AssertOnQuery { q => + // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // we don't know which data should be fetched when `startingOffset` is latest. + q.processAllAvailable() + true + } + + /** + * Add data to Kafka. + * + * `topicAction` can be used to run actions for each topic before inserting data. 
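+   * (For example, creating a topic with a given number of partitions, or deleting and re-creating
+   * it; the default action is a no-op.)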
+ */ + case class AddKafkaData(topics: Set[String], data: Int*) + (implicit ensureDataInMultiplePartition: Boolean = false, + concurrent: Boolean = false, + message: String = "", + topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { + + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + if (query.get.isActive) { + // Make sure no Spark job is running when deleting a topic + query.get.processAllAvailable() + } + + val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap + val newTopics = topics.diff(existingTopics.keySet) + for (newTopic <- newTopics) { + topicAction(newTopic, None) + } + for (existingTopicPartitions <- existingTopics) { + topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) + } + + // Read all topics again in case some topics are delete. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active kafka source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } + if (sources.isEmpty) { + throw new Exception( + "Could not find Kafka source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the Kafka source in the StreamExecution logical plan as there" + + "are multiple Kafka sources:\n\t" + sources.mkString("\n\t")) + } + val kafkaSource = sources.head + val topic = topics.toSeq(Random.nextInt(topics.size)) + val sentMetadata = testUtils.sendMessages(topic, data.map { _.toString }.toArray) + + def metadataToStr(m: (String, RecordMetadata)): String = { + s"Sent ${m._1} to partition ${m._2.partition()}, offset ${m._2.offset()}" + } + // Verify that the test data gets inserted into multiple partitions + if (ensureDataInMultiplePartition) { + require( + sentMetadata.groupBy(_._2.partition).size > 1, + s"Added data does not test multiple partitions: ${sentMetadata.map(metadataToStr)}") + } + + val offset = KafkaSourceOffset(testUtils.getLatestOffsets(topics)) + logInfo(s"Added data, expected offset $offset") + (kafkaSource, offset) + } + + override def toString: String = + s"AddKafkaData(topics = $topics, data = $data, message = $message)" + } +} + + +class KafkaSourceSuite extends KafkaSourceTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + test("cannot stop Kafka stream") { + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 5) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"topic-.*") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + StopStream + ) + } + + test("subscribing topic by name from latest offsets") { + val topic = newTopic() + testFromLatestOffsets(topic, "subscribe" -> topic) + } + + test("subscribing topic by name from earliest offsets") { + val topic = newTopic() + testFromEarliestOffsets(topic, "subscribe" -> topic) + } + + test("subscribing topic by pattern from latest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + 
testFromLatestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern from earliest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromEarliestOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("bad source options") { + def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + options.foreach { case (k, v) => reader.option(k, v) } + reader.load() + } + expectedMsgs.foreach { m => + assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + + // No strategy specified + testBadOptions()("options must be specified", "subscribe", "subscribePattern") + + // Multiple strategies specified + testBadOptions("subscribe" -> "t", "subscribePattern" -> "t.*")( + "only one", "options can be specified") + + testBadOptions("subscribe" -> "")("no topics to subscribe") + testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") + } + + test("unsupported kafka configs") { + def testUnsupportedConfig(key: String, value: String = "someValue"): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + .option("subscribe", "topic") + .option("kafka.bootstrap.servers", "somehost") + .option(s"$key", value) + reader.load() + } + assert(ex.getMessage.toLowerCase.contains("not supported")) + } + + testUnsupportedConfig("kafka.group.id") + testUnsupportedConfig("kafka.auto.offset.reset") + testUnsupportedConfig("kafka.enable.auto.commit") + testUnsupportedConfig("kafka.interceptor.classes") + testUnsupportedConfig("kafka.key.deserializer") + testUnsupportedConfig("kafka.value.deserializer") + + testUnsupportedConfig("kafka.auto.offset.reset", "none") + testUnsupportedConfig("kafka.auto.offset.reset", "someValue") + testUnsupportedConfig("kafka.auto.offset.reset", "earliest") + testUnsupportedConfig("kafka.auto.offset.reset", "latest") + } + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def testFromLatestOffsets(topic: String, options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffset", s"latest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) 
+ .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4), // Should get the data back on recovery + StopStream, + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), // Should get the added data + AddKafkaData(Set(topic), 7, 8), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + testUtils.addPartitions(topic, 10) + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } + + private def testFromEarliestOffsets(topic: String, options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark.readStream + reader + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("startingOffset", s"earliest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + AddKafkaData(Set(topic), 7, 8), + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + testUtils.addPartitions(topic, 10) + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } +} + + +class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { + + import testImplicits._ + + val topicId = new AtomicInteger(1) + + @volatile var topics: Seq[String] = (1 to 5).map(_ => newStressTopic) + + def newStressTopic: String = s"stress${topicId.getAndIncrement()}" + + private def nextInt(start: Int, end: Int): Int = { + start + Random.nextInt(start + end - 1) + } + + after { + for (topic <- testUtils.getAllTopicsAndPartitionSize().toMap.keys) { + testUtils.deleteTopic(topic) + } + } + + test("stress test with multiple topics and partitions") { + topics.foreach { topic => + testUtils.createTopic(topic, partitions = nextInt(1, 6)) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + } + + // Create Kafka source that reads from latest offset + val kafka = + spark.readStream + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "stress.*") + .option("failOnDataLoss", "false") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + + runStressTest( + mapped, + Seq(makeSureGetOffsetCalled), + 
(d, running) => { + Random.nextInt(5) match { + case 0 => // Add a new topic + topics = topics ++ Seq(newStressTopic) + AddKafkaData(topics.toSet, d: _*)(message = s"Add topic $newStressTopic", + topicAction = (topic, partition) => { + if (partition.isEmpty) { + testUtils.createTopic(topic, partitions = nextInt(1, 6)) + } + }) + case 1 if running => + // Only delete a topic when the query is running. Otherwise, we may lost data and + // cannot check the correctness. + val deletedTopic = topics(Random.nextInt(topics.size)) + if (deletedTopic != topics.head) { + topics = topics.filterNot(_ == deletedTopic) + } + AddKafkaData(topics.toSet, d: _*)(message = s"Delete topic $deletedTopic", + topicAction = (topic, partition) => { + // Never remove the first topic to make sure we have at least one topic + if (topic == deletedTopic && deletedTopic != topics.head) { + testUtils.deleteTopic(deletedTopic) + } + }) + case 2 => // Add new partitions + AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn", + topicAction = (topic, partition) => { + testUtils.addPartitions(topic, partition.get + nextInt(1, 6)) + }) + case _ => // Just add new data + AddKafkaData(topics.toSet, d: _*) + } + }, + iterations = 50) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala new file mode 100644 index 0000000000000..3eb8a737ba4c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.language.postfixOps +import scala.util.Random + +import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.common.TopicAndPartition +import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.utils.ZkUtils +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * This is a helper class for Kafka test suites. 
This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 60000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkUtils: ZkUtils = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 0 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkUtils = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkUtils).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkUtils = ZkUtils(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, false) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) + server = new KafkaServer(brokerConf) + server.startup() + brokerPort = server.boundPort() + (server, brokerPort) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + def getAllTopicsAndPartitionSize(): Seq[(String, Int)] = { + zkUtils.getPartitionsForTopics(zkUtils.getAllTopics()).mapValues(_.size).toSeq + } + + /** Create a Kafka 
topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + createTopic(topic, 1) + } + + /** Delete a Kafka topic and wait until it is propagated to the whole cluster */ + def deleteTopic(topic: String): Unit = { + val partitions = zkUtils.getPartitionsForTopics(Seq(topic))(topic).size + AdminUtils.deleteTopic(zkUtils, topic) + verifyTopicDeletion(zkUtils, topic, partitions, List(this.server)) + } + + /** Add new paritions to a Kafka topic */ + def addPartitions(topic: String, partitions: Int): Unit = { + AdminUtils.addPartitions(zkUtils, topic, partitions) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = { + producer = new KafkaProducer[String, String](producerConfiguration) + val offsets = try { + messages.map { m => + val metadata = + producer.send(new ProducerRecord[String, String](topic, m)).get(10, TimeUnit.SECONDS) + logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") + (m, metadata) + } + } finally { + if (producer != null) { + producer.close() + producer = null + } + } + offsets + } + + def getLatestOffsets(topics: Set[String]): Map[TopicPartition, Long] = { + val kc = new KafkaConsumer[String, String](consumerConfiguration) + logInfo("Created consumer to get latest offsets") + kc.subscribe(topics.asJavaCollection) + kc.poll(0) + val partitions = kc.assignment() + kc.pause(partitions) + kc.seekToEnd(partitions) + val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap + kc.close() + logInfo("Closed consumer to get latest offsets") + offsets + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("advertised.host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("value.serializer", classOf[StringSerializer].getName) + props.put("key.serializer", classOf[StringSerializer].getName) + // wait for all in-sync replicas to ack sends + props.put("acks", "all") + props + } + + private def consumerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("group.id", "group-KafkaTestUtils-" + Random.nextInt) + props.put("value.deserializer", classOf[StringDeserializer].getName) + props.put("key.deserializer", classOf[StringDeserializer].getName) + props.put("enable.auto.commit", "false") + props + } + + private def 
verifyTopicDeletion( + zkUtils: ZkUtils, + topic: String, + numPartitions: Int, + servers: Seq[KafkaServer]) { + import ZkUtils._ + val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + def isDeleted(): Boolean = { + // wait until admin path for delete topic is deleted, signaling completion of topic deletion + val deletePath = !zkUtils.pathExists(getDeleteTopicPath(topic)) + val topicPath = !zkUtils.pathExists(getTopicPath(topic)) + // ensure that the topic-partition has been deleted from all brokers' replica managers + val replicaManager = servers.forall(server => topicAndPartitions.forall(tp => + server.replicaManager.getPartition(tp.topic, tp.partition) == None)) + // ensure that logs from all replicas are deleted if delete topic is marked successful + val logManager = servers.forall(server => topicAndPartitions.forall(tp => + server.getLogManager().getLog(tp).isEmpty)) + // ensure that topic is removed from all cleaner offsets + val cleaner = servers.forall(server => topicAndPartitions.forall { tp => + val checkpoints = server.getLogManager().logDirs.map { logDir => + new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + } + checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) + }) + deletePath && topicPath && replicaManager && logManager && cleaner + } + eventually(timeout(10.seconds)) { + assert(isDeleted, s"$topic not deleted after timeout") + } + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + zkUtils.getLeaderForPartition(topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(timeout(10.seconds)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/pom.xml b/pom.xml index 8408f4b1fa5ed..37976b0359ad4 100644 --- a/pom.xml +++ b/pom.xml @@ -111,6 +111,7 @@ external/kafka-0-8-assembly external/kafka-0-10 external/kafka-0-10-assembly + external/kafka-0-10-sql diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 8e47e7f13d367..88d5dc9b02dd9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -39,8 +39,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq( @@ -353,7 +353,7 @@ object SparkBuild 
extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags + unsafe, tags, sqlKafka010 ).contains(x) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 9825f19b86a55..b3a0d6ad0bd4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -116,7 +116,7 @@ class StreamExecution( * [[HDFSMetadataLog]]. See SPARK-14131 for more details. */ val microBatchThread = - new UninterruptibleThread(s"stream execution thread for $name") { + new StreamExecutionThread(s"stream execution thread for $name") { override def run(): Unit = { // To fix call site like "run at :0", we bridge the call site from the caller // thread to this micro batch thread @@ -530,3 +530,9 @@ object StreamExecution { def nextId: Long = _nextId.getAndIncrement() } + +/** + * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread + * and will use `classOf[StreamExecutionThread]` to check. + */ +abstract class StreamExecutionThread(name: String) extends UninterruptibleThread(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index aa6515bc7a909..09140a1d6e76b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -50,11 +50,11 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * * {{{ * val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map(_ + 1) - - testStream(mapped)( - AddData(inputData, 1, 2, 3), - CheckAnswer(2, 3, 4)) + * val mapped = inputData.toDS().map(_ + 1) + * + * testStream(mapped)( + * AddData(inputData, 1, 2, 3), + * CheckAnswer(2, 3, 4)) * }}} * * Note that while we do sleep to allow the other thread to progress without spinning, @@ -477,21 +477,41 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param addData an add data action that adds the given numbers to the stream, encoding them + * as needed + * @param iterations the iteration number + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + runStressTest(ds, Seq.empty, (data, running) => addData(data), iterations) + } + /** * Creates a stress test that randomly starts/stops/adds data/checks the result. * - * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. - * @param addData and add data action that adds the given numbers to the stream, encoding them + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param prepareActions actions need to run before starting the stress test. 
+ * @param addData an add data action that adds the given numbers to the stream, encoding them * as needed + * @param iterations the iteration number */ def runStressTest( ds: Dataset[Int], - addData: Seq[Int] => StreamAction, - iterations: Int = 100): Unit = { + prepareActions: Seq[StreamAction], + addData: (Seq[Int], Boolean) => StreamAction, + iterations: Int): Unit = { implicit val intEncoder = ExpressionEncoder[Int]() var dataPos = 0 var running = true val actions = new ArrayBuffer[StreamAction]() + actions ++= prepareActions def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } @@ -499,7 +519,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val numItems = Random.nextInt(10) val data = dataPos until (dataPos + numItems) dataPos += numItems - actions += addData(data) + actions += addData(data, running) } (1 to iterations).foreach { i => From b678e465afa417780b54db0fbbaa311621311f15 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Oct 2016 18:11:31 -0700 Subject: [PATCH 202/306] [SPARK-17346][SQL][TEST-MAVEN] Generate the sql test jar to fix the maven build ## What changes were proposed in this pull request? Generate the sql test jar to fix the maven build ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15368 from zsxwing/sql-test-jar. --- external/kafka-0-10-sql/pom.xml | 14 ++++++++++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 + sql/core/pom.xml | 27 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index b96445a11f858..ebff5fd07a9b9 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -41,6 +41,20 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 64bf503058027..6c03070398fca 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -151,6 +151,7 @@ class KafkaSourceSuite extends KafkaSourceTest { val mapped = kafka.map(kv => kv._2.toInt + 1) testStream(mapped)( + makeSureGetOffsetCalled, StopStream ) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 84de1d4a6e2d1..7da77158ff07e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -132,6 +132,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + + org.codehaus.mojo build-helper-maven-plugin From 7aeb20be7e999523784aca7be1a7c9c99dec125e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 5 Oct 2016 23:03:09 -0700 Subject: [PATCH 203/306] [MINOR][ML] Avoid 2D array flatten in NB training. ## What changes were proposed in this pull request? Avoid 2D array flatten in ```NaiveBayes``` training, since flatten method might be expensive (It will create another array and copy data there). ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15359 from yanboliang/nb-theta. 
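To illustrate the allocation pattern at issue, here is a minimal, self-contained sketch (toy sizes and values, not the `NaiveBayes` code itself): building a 2-D array and then flattening it allocates and copies the data a second time, whereas writing into a row-major 1-D array indexed by `i * numFeatures + j` — the layout the patch below switches to — fills the final buffer directly.

```scala
object FlattenSketch {
  def main(args: Array[String]): Unit = {
    val numLabels = 3
    val numFeatures = 4

    // Before: fill a 2-D array, then flatten it -- `flatten` allocates a fresh
    // numLabels * numFeatures array and copies every element into it.
    val theta2d = Array.fill[Double](numLabels, numFeatures)(0.0)
    for (i <- 0 until numLabels; j <- 0 until numFeatures) {
      theta2d(i)(j) = i * 10 + j
    }
    val flattened = theta2d.flatten

    // After: write straight into a row-major 1-D array, with no intermediate copy.
    val theta = new Array[Double](numLabels * numFeatures)
    for (i <- 0 until numLabels; j <- 0 until numFeatures) {
      theta(i * numFeatures + j) = i * 10 + j
    }

    // Both layouts hold the same values; only the extra allocation differs.
    assert(java.util.Arrays.equals(flattened, theta))
  }
}
```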
--- .../org/apache/spark/ml/classification/NaiveBayes.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 6775745167b08..e565a6fd3ece2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum - val piArray = Array.fill[Double](numLabels)(0.0) - val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + val piArray = new Array[Double](numLabels) + val thetaArray = new Array[Double](numLabels * numFeatures) val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) @@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") ( } var j = 0 while (j < numFeatures) { - thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom j += 1 } i += 1 } val pi = Vectors.dense(piArray) - val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) new NaiveBayesModel(uid, pi, theta) } From 5e9f32dd87e58e909a579eaa310e67d31c3b6573 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 6 Oct 2016 09:58:58 +0100 Subject: [PATCH 204/306] [BUILD] Closing some stale PRs ## What changes were proposed in this pull request? This PR proposes to close some stale PRs and ones suggested to be closed by committer(s) or obviously inappropriate PRs (e.g. branch to branch). Closes #13458 Closes #15278 Closes #15294 Closes #15339 Closes #15283 ## How was this patch tested? N/A Author: hyukjinkwon Closes #15356 from HyukjinKwon/closing-prs. From 92b7e5728025b1bb6ed3aab5f1753c946a73568c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 6 Oct 2016 09:42:30 -0700 Subject: [PATCH 205/306] [SPARK-17750][SQL] Fix CREATE VIEW with INTERVAL arithmetic. ## What changes were proposed in this pull request? Currently, Spark raises `RuntimeException` when creating a view with timestamp with INTERVAL arithmetic like the following. The root cause is the arithmetic expression, `TimeAdd`, was transformed into `timeadd` function as a VIEW definition. This PR fixes the SQL definition of `TimeAdd` and `TimeSub` expressions. ```scala scala> sql("CREATE TABLE dates (ts TIMESTAMP)") scala> sql("CREATE VIEW view1 AS SELECT ts + INTERVAL 1 DAY FROM dates") java.lang.RuntimeException: Failed to analyze the canonicalized SQL: ... ``` ## How was this patch tested? Pass Jenkins with a new testcase. Author: Dongjoon Hyun Closes #15318 from dongjoon-hyun/SPARK-17750. 
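A minimal sketch of the rendering issue described above, using toy expression classes rather than Spark's `Expression` API: emitting the default function-call form (`timeadd(...)`) yields SQL that fails when the canonicalized view definition is re-analyzed, while an infix `sql` override — the approach taken in the patch below — renders the original arithmetic.

```scala
// Toy expression tree; the class and method names here are illustrative only.
sealed trait Expr { def sql: String }
case class Col(name: String) extends Expr { override def sql: String = s"`$name`" }
case class DayInterval(days: Int) extends Expr { override def sql: String = s"interval $days days" }

case class TimeAddLike(left: Expr, right: Expr) extends Expr {
  // Default-style rendering: a function-call form that, per the description above,
  // cannot be re-analyzed from the canonicalized view SQL.
  def functionCallSql: String = s"timeadd(${left.sql}, ${right.sql})"
  // Fixed rendering: emit the arithmetic back in infix form.
  override def sql: String = s"${left.sql} + ${right.sql}"
}

object IntervalSqlSketch {
  def main(args: Array[String]): Unit = {
    val expr = TimeAddLike(Col("ts"), DayInterval(1))
    println(expr.functionCallSql) // timeadd(`ts`, interval 1 days)
    println(expr.sql)             // `ts` + interval 1 days
  }
}
```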
--- .../expressions/datetimeExpressions.scala | 2 ++ .../resources/sqlgen/interval_arithmetic.sql | 8 ++++++++ .../catalyst/ExpressionSQLBuilderSuite.scala | 18 +++++++++++++++++- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 16 ++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 04c17bdaf2989..7ab68a13e09cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -682,6 +682,7 @@ case class TimeAdd(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left + $right" + override def sql: String = s"${left.sql} + ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType @@ -762,6 +763,7 @@ case class TimeSub(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left - $right" + override def sql: String = s"${left.sql} - ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType diff --git a/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql new file mode 100644 index 0000000000000..31d00348769f5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select ts + interval 1 day, ts + interval 2 days, + ts - interval 1 day, ts - interval 2 days, + ts + interval '1' day, ts + interval '2' days, + ts - interval '1' day, ts - interval '2' days +from dates +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_2` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_3` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_4` AS `CAST(ts - interval 2 days AS TIMESTAMP)`, `gen_attr_5` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_6` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_7` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_8` AS `CAST(ts - interval 2 days AS TIMESTAMP)` FROM (SELECT CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_0`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_2`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_3`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_4`, CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_5`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_6`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_7`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_8` FROM (SELECT `ts` AS `gen_attr_1` FROM `default`.`dates`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index ce5efe853ca4f..149ce1e195111 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd, + TimeSub, WindowSpecDefinition} +import org.apache.spark.unsafe.types.CalendarInterval class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { @@ -119,4 +121,18 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" ) } + + test("interval arithmetic") { + val interval = Literal(new CalendarInterval(0, CalendarInterval.MICROS_PER_DAY)) + + checkSQL( + TimeAdd('a, interval), + "`a` + interval 1 days" + ) + + checkSQL( + TimeSub('a, interval), + "`a` - interval 1 days" + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 7fa5c29dc5b8f..9ac1e86fc82cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1145,4 +1145,20 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { """.stripMargin, "inline_tables") } + + test("SPARK-17750 - interval arithmetic") { + withTable("dates") { + sql("create table dates (ts timestamp)") + checkSQL( + """ + |select ts + interval 1 day, ts + interval 2 days, + | ts - interval 1 day, ts - interval 2 days, + | ts + interval '1' day, ts + interval 
'2' days, + | ts - interval '1' day, ts - interval '2' days + |from dates + """.stripMargin, + "interval_arithmetic" + ) + } + } } From 79accf45ace5549caa0cbab02f94fc87bedb5587 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Oct 2016 10:33:45 -0700 Subject: [PATCH 206/306] [SPARK-17798][SQL] Remove redundant Experimental annotations in sql.streaming ## What changes were proposed in this pull request? I was looking through API annotations to catch mislabeled APIs, and realized DataStreamReader and DataStreamWriter classes are already annotated as Experimental, and as a result there is no need to annotate each method within them. ## How was this patch tested? N/A Author: Reynold Xin Closes #15373 from rxin/SPARK-17798. --- .../sql/streaming/DataStreamReader.scala | 28 ------------------ .../sql/streaming/DataStreamWriter.scala | 29 ------------------- .../streaming/StreamingQueryListener.scala | 4 +-- 3 files changed, 1 insertion(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index d437c16a25b01..864a9cd3eb89d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -35,89 +35,73 @@ import org.apache.spark.sql.types.StructType @Experimental final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** - * :: Experimental :: * Specifies the input data source format. * * @since 2.0.0 */ - @Experimental def format(source: String): DataStreamReader = { this.source = source this } /** - * :: Experimental :: * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema * automatically from data. By specifying the schema here, the underlying data source can * skip the schema inference step, and thus speed up data loading. * * @since 2.0.0 */ - @Experimental def schema(schema: StructType): DataStreamReader = { this.userSpecifiedSchema = Option(schema) this } /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: String): DataStreamReader = { this.extraOptions += (key -> value) this } /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Long): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * Adds an input option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Double): DataStreamReader = option(key, value.toString) /** - * :: Experimental :: * (Scala-specific) Adds input options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: scala.collection.Map[String, String]): DataStreamReader = { this.extraOptions ++= options this } /** - * :: Experimental :: * Adds input options for the underlying data source. 
* * @since 2.0.0 */ - @Experimental def options(options: java.util.Map[String, String]): DataStreamReader = { this.options(options.asScala) this @@ -125,13 +109,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** - * :: Experimental :: * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path * (e.g. external key-value stores). * * @since 2.0.0 */ - @Experimental def load(): DataFrame = { val dataSource = DataSource( @@ -143,18 +125,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * :: Experimental :: * Loads input in as a [[DataFrame]], for data streams that read from some path. * * @since 2.0.0 */ - @Experimental def load(path: String): DataFrame = { option("path", path).load() } /** - * :: Experimental :: * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. If you know the @@ -198,11 +177,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def json(path: String): DataFrame = format("json").load(path) /** - * :: Experimental :: * Loads a CSV file stream and returns the result as a [[DataFrame]]. * * This function will go through the input once to determine the input schema if `inferSchema` @@ -262,11 +239,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def csv(path: String): DataFrame = format("csv").load(path) /** - * :: Experimental :: * Loads a Parquet file stream, returning the result as a [[DataFrame]]. * * You can set the following Parquet-specific option(s) for reading Parquet files: @@ -281,13 +256,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def parquet(path: String): DataFrame = { format("parquet").load(path) } /** - * :: Experimental :: * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * @@ -308,7 +281,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 2.0.0 */ - @Experimental def text(path: String): DataFrame = format("text").load(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f70c7d08a691c..b959444b49298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -37,7 +37,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() /** - * :: Experimental :: * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be * written to the sink @@ -46,15 +45,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { this.outputMode = outputMode this } - /** - * :: Experimental :: * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. 
* - `append`: only the new rows in the streaming DataFrame/Dataset will be written to * the sink @@ -63,7 +59,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def outputMode(outputMode: String): DataStreamWriter[T] = { this.outputMode = outputMode.toLowerCase match { case "append" => @@ -78,7 +73,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run * the query as fast as possible. * @@ -100,7 +94,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def trigger(trigger: Trigger): DataStreamWriter[T] = { this.trigger = trigger this @@ -108,25 +101,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** - * :: Experimental :: * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. * This name must be unique among all the currently active queries in the associated SQLContext. * * @since 2.0.0 */ - @Experimental def queryName(queryName: String): DataStreamWriter[T] = { this.extraOptions += ("queryName" -> queryName) this } /** - * :: Experimental :: * Specifies the underlying output data source. Built-in options include "parquet" for now. * * @since 2.0.0 */ - @Experimental def format(source: String): DataStreamWriter[T] = { this.source = source this @@ -156,90 +145,74 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: String): DataStreamWriter[T] = { this.extraOptions += (key -> value) this } /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * Adds an output option for the underlying data source. * * @since 2.0.0 */ - @Experimental def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) /** - * :: Experimental :: * (Scala-specific) Adds output options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { this.extraOptions ++= options this } /** - * :: Experimental :: * Adds output options for the underlying data source. * * @since 2.0.0 */ - @Experimental def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { this.options(options.asScala) this } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with * the stream. * * @since 2.0.0 */ - @Experimental def start(path: String): StreamingQuery = { option("path", path).start() } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with * the stream. 
* * @since 2.0.0 */ - @Experimental def start(): StreamingQuery = { if (source == "memory") { assertNotPartitioned("memory") @@ -297,7 +270,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * :: Experimental :: * Starts the execution of the streaming query, which will continually send results to the given * [[ForeachWriter]] as as new data arrives. The [[ForeachWriter]] can be used to send the data * generated by the [[DataFrame]]/[[Dataset]] to an external system. @@ -343,7 +315,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.0.0 */ - @Experimental def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { this.source = "foreach" this.foreachWriter = if (writer != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index db606abb8ce43..8a8855d85a4c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -35,7 +35,7 @@ abstract class StreamingQueryListener { /** * Called when a query is started. * @note This is called synchronously with - * [[org.apache.spark.sql.DataStreamWriter `DataStreamWriter.start()`]], + * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], * that is, `onQueryStart` will be called on all listeners before * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please * don't block this method as it will block your query. @@ -101,8 +101,6 @@ object StreamingQueryListener { * @param queryInfo Information about the status of the query. * @param exception The exception message of the [[StreamingQuery]] if the query was terminated * with an exception. Otherwise, it will be `None`. - * @param stackTrace The stack trace of the exception if the query was terminated with an - * exception. It will be empty if there was no error. * @since 2.0.0 */ @Experimental From 9a48e60e6319d85f2c3be3a3c608dab135e18a73 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 6 Oct 2016 12:51:12 -0700 Subject: [PATCH 207/306] [SPARK-17780][SQL] Report Throwable to user in StreamExecution ## What changes were proposed in this pull request? When using an incompatible source for structured streaming, it may throw NoClassDefFoundError. It's better to just catch Throwable and report it to the user since the streaming thread is dying. ## How was this patch tested? `test("NoClassDefFoundError from an incompatible source")` Author: Shixiong Zhu Closes #15352 from zsxwing/SPARK-17780. 
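The catch-report-rethrow pattern this change introduces can be sketched in isolation as follows (the helper names `runBatches` and `reportToUser` are placeholders, not Spark's): every `Throwable`, including fatal errors such as `NoClassDefFoundError`, is surfaced to the user, but fatal ones are rethrown so a `Thread.UncaughtExceptionHandler` can still observe them.

```scala
import scala.util.control.NonFatal

object FatalErrorSketch {
  // Placeholder for wrapping the error into a user-visible report,
  // analogous to a StreamingQueryException.
  private def reportToUser(e: Throwable): Unit = println(s"query terminated with: $e")

  def runBatches(body: => Unit): Unit = {
    try {
      body
    } catch {
      case e: Throwable =>
        // Report every failure so the user learns why the streaming thread died ...
        reportToUser(e)
        // ... then rethrow fatal errors so an UncaughtExceptionHandler can still handle them.
        if (!NonFatal(e)) {
          throw e
        }
    } finally {
      // Placeholder for state transition / listener notification cleanup.
      println("state = TERMINATED")
    }
  }

  def main(args: Array[String]): Unit = {
    runBatches(throw new IllegalStateException("boom")) // non-fatal: reported and swallowed
    // runBatches(throw new NoClassDefFoundError("X"))  // fatal: reported, then rethrown
  }
}
```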
--- .../execution/streaming/StreamExecution.scala | 7 ++++- .../spark/sql/streaming/StreamSuite.scala | 31 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 3 +- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b3a0d6ad0bd4c..333239f875bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -207,13 +207,18 @@ class StreamExecution( }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() - case NonFatal(e) => + case e: Throwable => streamDeathCause = new StreamingQueryException( this, s"Query $name terminated with exception: ${e.getMessage}", e, Some(committedOffsets.toCompositeOffset(sources))) logError(s"Query $name terminated with error", e) + // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to + // handle them + if (!NonFatal(e)) { + throw e + } } finally { state = TERMINATED sparkSession.streams.notifyQueryTermination(StreamExecution.this) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 1caafb9d74440..cdbad901dba8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.streaming +import scala.reflect.ClassTag +import scala.util.control.ControlThrowable + import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.ManualClock @@ -236,6 +238,33 @@ class StreamSuite extends StreamTest { } } + testQuietly("fatal errors from a source should be sent to the user") { + for (e <- Seq( + new VirtualMachineError {}, + new ThreadDeath, + new LinkageError, + new ControlThrowable {} + )) { + val source = new Source { + override def getOffset: Option[Offset] = { + throw e + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + throw e + } + + override def schema: StructType = StructType(Array(StructField("value", IntegerType))) + + override def stop(): Unit = {} + } + val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source)) + testStream(df)( + ExpectFailure()(ClassTag(e.getClass)) + ) + } + } + test("output mode API in Scala") { val o1 = OutputMode.Append assert(o1 === InternalOutputModes.Append) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 09140a1d6e76b..fa13d385cce75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,7 +167,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Signals that a failure is expected and should not kill the test. 
*/ case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] - override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + override def toString(): String = s"ExpectFailure[${causeClass.getName}]" } /** Assert that a body is true */ @@ -322,7 +322,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { streamDeathCause = e - testThread.interrupt() } }) From 49d11d49983fbe270f4df4fb1e34b5fbe854c5ec Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 6 Oct 2016 14:28:49 -0700 Subject: [PATCH 208/306] [SPARK-17803][TESTS] Upgrade docker-client dependency [SPARK-17803: Docker integration tests don't run with "Docker for Mac"](https://issues.apache.org/jira/browse/SPARK-17803) ## What changes were proposed in this pull request? This PR upgrades the [docker-client](https://mvnrepository.com/artifact/com.spotify/docker-client) dependency from [3.6.6](https://mvnrepository.com/artifact/com.spotify/docker-client/3.6.6) to [5.0.2](https://mvnrepository.com/artifact/com.spotify/docker-client/5.0.2) to enable _Docker for Mac_ users to run the `docker-integration-tests` out of the box. The very latest docker-client version is [6.0.0](https://mvnrepository.com/artifact/com.spotify/docker-client/6.0.0) but that has one additional dependency and no usage yet. ## How was this patch tested? The code change was tested on Mac OS X Yosemite with both _Docker Toolbox_ as well as _Docker for Mac_ and on Linux Ubuntu 14.04. ``` $ build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -Phive -Phive-thriftserver -DskipTests clean package $ build/mvn -Pdocker-integration-tests -Pscala-2.11 -pl :spark-docker-integration-tests_2.11 clean compile test ``` Author: Christian Kadner Closes #15378 from ckadner/SPARK-17803_Docker_for_Mac. --- .../org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala | 1 + pom.xml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index c36f4d5f95482..609696bc8a2c7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.util.control.NonFatal import com.spotify.docker.client._ +import com.spotify.docker.client.exceptions.ImageNotFoundException import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually diff --git a/pom.xml b/pom.xml index 37976b0359ad4..7d13c51b2a596 100644 --- a/pom.xml +++ b/pom.xml @@ -744,7 +744,7 @@ com.spotify docker-client - 3.6.6 + 5.0.2 test From 3713bb199142c5e06e2e527c99650f02f41f47b1 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 6 Oct 2016 21:10:17 -0700 Subject: [PATCH 209/306] [SPARK-17792][ML] L-BFGS solver for linear regression does not accept general numeric label column types ## What changes were proposed in this pull request? 
Before, we computed `instances` in LinearRegression in two spots, even though they did the same thing. One of them did not cast the label column to `DoubleType`. This patch consolidates the computation and always casts the label column to `DoubleType`. ## How was this patch tested? Added a unit test to check all solvers. This test failed before this patch. Author: sethah Closes #15364 from sethah/linreg_numeric_type. --- .../spark/ml/regression/LinearRegression.scala | 17 ++++++----------- .../ml/regression/LinearRegressionSuite.scala | 8 +++++--- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 536c58f998080..025ed20c75a04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -188,17 +188,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select( + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. 
(SPARK-10668) - val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), $(standardization), true) @@ -221,12 +222,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 5ae371b489aa5..1c94ec67d79d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -1015,12 +1015,14 @@ class LinearRegressionSuite } test("should support all NumericType labels and not support other types") { - val lr = new LinearRegression().setMaxIter(1) - MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, spark, isClassification = false) { (expected, actual) => + for (solver <- Seq("auto", "l-bfgs", "normal")) { + val lr = new LinearRegression().setMaxIter(1).setSolver(solver) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } + } } } From bcaa799cb01289f73e9f48526e94653a07628983 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 7 Oct 2016 00:27:55 -0700 Subject: [PATCH 210/306] [SPARK-17805][PYSPARK] Fix in sqlContext.read.text when pass in list of paths ## What changes were proposed in this pull request? If given a list of paths, `pyspark.sql.readwriter.text` will attempt to use an undefined variable `paths`. This change checks if the param `paths` is a basestring and then converts it to a list, so that the same variable `paths` can be used for both cases ## How was this patch tested? Added unit test for reading list of files Author: Bryan Cutler Closes #15379 from BryanCutler/sql-readtext-paths-SPARK-17805. 
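As a rough illustration of the user-facing behavior (the file paths below are the ones used by the new unit test, and an active `spark` session is assumed):

```
# Hedged PySpark sketch: read.text now accepts either a single path or a
# list of paths; both return one DataFrame with a single `value` column.
single = spark.read.text("python/test_support/sql/text-test.txt")

# Before this fix the list form failed, because the internal variable holding
# the paths was only assigned in the single-string branch.
combined = spark.read.text(["python/test_support/sql/text-test.txt",
                            "python/test_support/sql/text-test.txt"])
assert combined.count() == 2 * single.count()
```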
--- python/pyspark/sql/readwriter.py | 4 ++-- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3ad6f80de9fdf..91c2b17049fa1 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -289,8 +289,8 @@ def text(self, paths): [Row(value=u'hello'), Row(value=u'this')] """ if isinstance(paths, basestring): - path = [paths] - return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path))) + paths = [paths] + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @since(2.0) def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c2171c277cac3..a9e455565a6cd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1702,6 +1702,12 @@ def test_cache(self): "does_not_exist", lambda: spark.catalog.uncacheTable("does_not_exist")) + def test_read_text_file_list(self): + df = self.spark.read.text(['python/test_support/sql/text-test.txt', + 'python/test_support/sql/text-test.txt']) + count = df.count() + self.assertEquals(count, 4) + class HiveSparkSubmitTests(SparkSubmitTests): From 18bf9d2b2d7eae0574102d4f15ac27dc71dea18a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 7 Oct 2016 11:46:39 +0100 Subject: [PATCH 211/306] [SPARK-17782][STREAMING][BUILD] Add Kafka 0.10 project to build modules ## What changes were proposed in this pull request? This PR adds the Kafka 0.10 subproject to the build infrastructure. This makes sure Kafka 0.10 tests are only triggers when it or of its dependencies change. Author: Herman van Hovell Closes #15355 from hvanhovell/SPARK-17782. --- dev/sparktestsupport/modules.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 5f14683d9a52f..b34ab51f3b996 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -241,6 +241,17 @@ def __hash__(self): ] ) +streaming_kafka_0_10 = Module( + name="streaming-kafka-0-10", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka-0-10", + "external/kafka-0-10-assembly", + ], + sbt_test_goals=[ + "streaming-kafka-0-10/test", + ] +) streaming_flume_sink = Module( name="streaming-flume-sink", From 24097d84743d3e792e395410139e8d486b75a3ef Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Fri, 7 Oct 2016 11:47:37 +0100 Subject: [PATCH 212/306] =?UTF-8?q?[SPARK-17795][WEB=20UI]=20Sorting=20on?= =?UTF-8?q?=20stage=20or=20job=20tables=20doesn=E2=80=99t=20reload=20page?= =?UTF-8?q?=20on=20that=20table?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Added anchor on table header id to sorting links on job and stage tables. This make the page reload after a sort load the page at the sorted table. This only changes page load behavior so no UI changes ## How was this patch tested? manually tested and dev/run-tests Author: Alex Bozarth Closes #15369 from ajbozarth/spark17795. 
--- .../apache/spark/ui/jobs/AllJobsPage.scala | 20 ++++++++++++------- .../apache/spark/ui/jobs/AllStagesPage.scala | 12 +++++------ .../org/apache/spark/ui/jobs/JobPage.scala | 17 +++++++++++----- .../org/apache/spark/ui/jobs/PoolPage.scala | 2 +- .../org/apache/spark/ui/jobs/StageTable.scala | 14 +++++++++---- 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index c04964ec66479..19bb41a1417c7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -216,6 +216,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { private def jobsTable( request: HttpServletRequest, + tableHeaderId: String, jobTag: String, jobs: Seq[JobUIData]): Seq[Node] = { val allParameters = request.getParameterMap.asScala.toMap @@ -256,6 +257,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { try { new JobPagedTable( jobs, + tableHeaderId, jobTag, UIUtils.prependBaseUri(parent.basePath), "jobs", // subPath @@ -288,9 +290,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val completedJobs = listener.completedJobs.reverse.toSeq val failedJobs = listener.failedJobs.reverse.toSeq - val activeJobsTable = jobsTable(request, "activeJob", activeJobs) - val completedJobsTable = jobsTable(request, "completedJob", completedJobs) - val failedJobsTable = jobsTable(request, "failedJob", failedJobs) + val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs) + val completedJobsTable = jobsTable(request, "completed", "completedJob", completedJobs) + val failedJobsTable = jobsTable(request, "failed", "failedJob", failedJobs) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty @@ -486,6 +488,7 @@ private[ui] class JobDataSource( } private[ui] class JobPagedTable( data: Seq[JobUIData], + tableHeaderId: String, jobTag: String, basePath: String, subPath: String, @@ -528,12 +531,13 @@ private[ui] class JobPagedTable( s"&$pageNumberFormField=$page" + s"&$jobTag.sort=$encodedSortColumn" + s"&$jobTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$tableHeaderId" } override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc" + s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId" } override def headers: Seq[Node] = { @@ -557,7 +561,8 @@ private[ui] class JobPagedTable( parameterPath + s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" + s"&$jobTag.desc=${!desc}" + - s"&$jobTag.pageSize=$pageSize") + s"&$jobTag.pageSize=$pageSize" + + s"#$tableHeaderId") val arrow = if (desc) "▾" else "▴" // UP or DOWN @@ -572,7 +577,8 @@ private[ui] class JobPagedTable( val headerLink = Unparsed( parameterPath + s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&$jobTag.pageSize=$pageSize") + s"&$jobTag.pageSize=$pageSize" + + s"#$tableHeaderId") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index cba8f82dd77a6..fe6ca1099e6b0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -41,19 +41,19 @@ 
private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val subPath = "stages" val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", parent.basePath, subPath, + new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, subPath, parent.progressListener, parent.isFairScheduler, killEnabled = parent.killEnabled, isFailedStage = false) val pendingStagesTable = - new StageTableBase(request, pendingStages, "pendingStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, + new StageTableBase(request, pendingStages, "pending", "pendingStage", parent.basePath, + subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val completedStagesTable = - new StageTableBase(request, completedStages, "completedStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, + new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, + subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val failedStagesTable = - new StageTableBase(request, failedStages, "failedStage", parent.basePath, subPath, + new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.progressListener, parent.isFairScheduler, killEnabled = false, isFailedStage = true) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 2f7f8976a8899..0ff9e5e9411ca 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -230,20 +230,27 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val basePath = "jobs/job" + val pendingOrSkippedTableId = + if (isComplete) { + "pending" + } else { + "skipped" + } + val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", parent.basePath, + new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = parent.killEnabled, isFailedStage = false) val pendingOrSkippedStagesTable = - new StageTableBase(request, pendingOrSkippedStages, "pendingStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, + new StageTableBase(request, pendingOrSkippedStages, pendingOrSkippedTableId, "pendingStage", + parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val completedStagesTable = - new StageTableBase(request, completedStages, "completedStage", parent.basePath, + new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val failedStagesTable = - new StageTableBase(request, failedStages, "failedStage", parent.basePath, + new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, killEnabled = false, isFailedStage = true) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index f9cb717918592..8ee70d27cc09f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -44,7 +44,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { } val shouldShowActiveStages = activeStages.nonEmpty val activeStagesTable = - new StageTableBase(request, activeStages, "activeStage", parent.basePath, "stages/pool", + new StageTableBase(request, activeStages, "", "activeStage", parent.basePath, "stages/pool", parent.progressListener, parent.isFairScheduler, parent.killEnabled, isFailedStage = false) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2a04e8fc7d007..40a6762c281ce 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -34,6 +34,7 @@ import org.apache.spark.util.Utils private[ui] class StageTableBase( request: HttpServletRequest, stages: Seq[StageInfo], + tableHeaderID: String, stageTag: String, basePath: String, subPath: String, @@ -77,6 +78,7 @@ private[ui] class StageTableBase( val toNodeSeq = try { new StagePagedTable( stages, + tableHeaderID, stageTag, basePath, subPath, @@ -131,6 +133,7 @@ private[ui] class MissingStageTableRowData( /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( stages: Seq[StageInfo], + tableHeaderId: String, stageTag: String, basePath: String, subPath: String, @@ -173,12 +176,13 @@ private[ui] class StagePagedTable( s"&$pageNumberFormField=$page" + s"&$stageTag.sort=$encodedSortColumn" + s"&$stageTag.desc=$desc" + - s"&$pageSizeFormField=$pageSize" + s"&$pageSizeFormField=$pageSize" + + s"#$tableHeaderId" } override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc" + s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId" } override def headers: Seq[Node] = { @@ -226,7 +230,8 @@ private[ui] class StagePagedTable( parameterPath + s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" + s"&$stageTag.desc=${!desc}" + - s"&$stageTag.pageSize=$pageSize") + s"&$stageTag.pageSize=$pageSize") + + s"#$tableHeaderId" val arrow = if (desc) "▾" else "▴" // UP or DOWN @@ -241,7 +246,8 @@ private[ui] class StagePagedTable( val headerLink = Unparsed( parameterPath + s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&$stageTag.pageSize=$pageSize") + s"&$stageTag.pageSize=$pageSize") + + s"#$tableHeaderId" From 2b01d3c701c58f07fa42afd570523dd161384882 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Oct 2016 11:49:34 +0100 Subject: [PATCH 213/306] [SPARK-16960][SQL] Deprecate approxCountDistinct, toDegrees and toRadians according to FunctionRegistry ## What changes were proposed in this pull request? It seems `approxCountDistinct`, `toDegrees` and `toRadians` are also missed while matching the names to the ones in `FunctionRegistry`. 
(please see [approx_count_distinct](https://github.com/apache/spark/blob/5c2ae79bfcf448d8dc9217efafa1409997c739de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L244), [degrees](https://github.com/apache/spark/blob/5c2ae79bfcf448d8dc9217efafa1409997c739de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L203) and [radians](https://github.com/apache/spark/blob/5c2ae79bfcf448d8dc9217efafa1409997c739de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L222) in `FunctionRegistry`). I took a scan between `functions.scala` and `FunctionRegistry` and it seems these are all left. For `countDistinct` and `sumDistinct`, they are not registered in `FunctionRegistry`. This PR deprecates `approxCountDistinct`, `toDegrees` and `toRadians` and introduces `approx_count_distinct`, `degrees` and `radians`. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Author: Hyukjin Kwon Closes #14538 from HyukjinKwon/SPARK-16588-followup. --- python/pyspark/sql/functions.py | 33 +++++-- .../org/apache/spark/sql/functions.scala | 91 +++++++++++++++---- .../spark/sql/DataFrameWindowSuite.scala | 2 +- .../spark/sql/MathExpressionsSuite.scala | 12 +-- 4 files changed, 105 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 45d6bf944b702..7fa3fd2de7ddf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -112,11 +112,8 @@ def _(): 'sinh': 'Computes the hyperbolic sine of the given value.', 'tan': 'Computes the tangent of the given value.', 'tanh': 'Computes the hyperbolic tangent of the given value.', - 'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', - + 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.', + 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.', 'bitwiseNOT': 'Computes bitwise not.', } @@ -135,7 +132,15 @@ def _(): 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + - ' eliminated.' + ' eliminated.', +} + +_functions_2_1 = { + # unary math functions + 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.', } # math functions that take two arguments as input @@ -182,21 +187,31 @@ def _(): globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) +for _name, _doc in _functions_2_1.items(): + globals()[_name] = since(2.1)(_create_function(_name, _doc)) del _name, _doc @since(1.3) def approxCountDistinct(col, rsd=None): + """ + .. note:: Deprecated in 2.1, use approx_count_distinct instead. + """ + return approx_count_distinct(col, rsd) + + +@since(2.1) +def approx_count_distinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. 
- >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() [Row(c=2)] """ sc = SparkContext._active_spark_context if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col)) else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + jc = sc._jvm.functions.approx_count_distinct(_to_java_column(col), rsd) return Column(jc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3bc1c5b90031d..40f82d895d43b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -182,13 +182,43 @@ object functions { // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(e: Column): Column = approx_count_distinct(e) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(columnName: String): Column = approx_count_distinct(columnName) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(e: Column, rsd: Double): Column = approx_count_distinct(e, rsd) + + /** + * @group agg_funcs + * @since 1.3.0 + */ + @deprecated("Use approx_count_distinct", "2.1.0") + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approx_count_distinct(Column(columnName), rsd) + } + /** * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(e: Column): Column = withAggregateFunction { + def approx_count_distinct(e: Column): Column = withAggregateFunction { HyperLogLogPlusPlus(e.expr) } @@ -196,9 +226,9 @@ object functions { * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + def approx_count_distinct(columnName: String): Column = approx_count_distinct(column(columnName)) /** * Aggregate function: returns the approximate number of distinct items in a group. 
@@ -206,9 +236,9 @@ object functions { * @param rsd maximum estimation error allowed (default = 0.05) * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + def approx_count_distinct(e: Column, rsd: Double): Column = withAggregateFunction { HyperLogLogPlusPlus(e.expr, rsd, 0, 0) } @@ -218,10 +248,10 @@ object functions { * @param rsd maximum estimation error allowed (default = 0.05) * * @group agg_funcs - * @since 1.3.0 + * @since 2.1.0 */ - def approxCountDistinct(columnName: String, rsd: Double): Column = { - approxCountDistinct(Column(columnName), rsd) + def approx_count_distinct(columnName: String, rsd: Double): Column = { + approx_count_distinct(Column(columnName), rsd) } /** @@ -1949,37 +1979,65 @@ object functions { */ def tanh(columnName: String): Column = tanh(Column(columnName)) + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use degrees", "2.1.0") + def toDegrees(e: Column): Column = degrees(e) + + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use degrees", "2.1.0") + def toDegrees(columnName: String): Column = degrees(Column(columnName)) + /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } + def degrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @group math_funcs + * @since 2.1.0 + */ + def degrees(columnName: String): Column = degrees(Column(columnName)) + + /** + * @group math_funcs + * @since 1.4.0 + */ + @deprecated("Use radians", "2.1.0") + def toRadians(e: Column): Column = radians(e) + + /** * @group math_funcs * @since 1.4.0 */ - def toDegrees(columnName: String): Column = toDegrees(Column(columnName)) + @deprecated("Use radians", "2.1.0") + def toRadians(columnName: String): Column = radians(Column(columnName)) /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } + def radians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. 
* * @group math_funcs - * @since 1.4.0 + * @since 2.1.0 */ - def toRadians(columnName: String): Column = toRadians(Column(columnName)) + def radians(columnName: String): Column = radians(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions @@ -3096,5 +3154,4 @@ object functions { def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index c6f8c3ad3fc93..c2b47cae8f4c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -228,7 +228,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { $"key", var_pop($"value").over(window), var_samp($"value").over(window), - approxCountDistinct($"value").over(window)), + approx_count_distinct($"value").over(window)), Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 0de7f2321f398..6944c6f848179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -148,19 +148,19 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction(tanh, math.tanh) } - test("toDegrees") { - testOneToOneMathFunction(toDegrees, math.toDegrees) + test("degrees") { + testOneToOneMathFunction(degrees, math.toDegrees) checkAnswer( sql("SELECT degrees(0), degrees(1), degrees(1.5)"), - Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + Seq((1, 2)).toDF().select(degrees(lit(0)), degrees(lit(1)), degrees(lit(1.5))) ) } - test("toRadians") { - testOneToOneMathFunction(toRadians, math.toRadians) + test("radians") { + testOneToOneMathFunction(radians, math.toRadians) checkAnswer( sql("SELECT radians(0), radians(1), radians(1.5)"), - Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + Seq((1, 2)).toDF().select(radians(lit(0)), radians(lit(1)), radians(lit(1.5))) ) } From e56614cba99bfdf5fa8a6c617fdd56eca2b34694 Mon Sep 17 00:00:00 2001 From: Brian Cho Date: Fri, 7 Oct 2016 11:37:18 -0400 Subject: [PATCH 214/306] [SPARK-16827] Stop reporting spill metrics as shuffle metrics ## What changes were proposed in this pull request? Fix a bug where spill metrics were being reported as shuffle metrics. Eventually these spill metrics should be reported (SPARK-3577), but separate from shuffle metrics. The fix itself basically reverts the line to what it was in 1.6. ## How was this patch tested? Tested on a job that was reporting shuffle writes even for the final stage, when no shuffle writes should take place. After the change the job no longer shows these writes. Before: ![screen shot 2016-10-03 at 6 39 59 pm](https://cloud.githubusercontent.com/assets/1514239/19085897/dbf59a92-8a20-11e6-9f68-a978860c0d74.png) After: screen shot 2016-10-03 at 11 44 44 pm Author: Brian Cho Closes #15347 from dafrista/shuffle-metrics. 
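A rough way to observe the symptom described above from outside a job (not part of this patch; the UI address, application id and REST field names are assumptions based on the public status REST API):

```
# Hedged sketch: list per-stage shuffle write bytes via the Spark UI REST API.
# Host/port and application id are placeholders.
import json
import urllib.request

url = "http://localhost:4040/api/v1/applications/app-20161003183900-0000/stages"
with urllib.request.urlopen(url) as resp:
    stages = json.load(resp)

# Before the fix, stages that only spilled to disk (no shuffle) could still
# report non-zero shuffle write bytes here; after the fix they report zero.
for stage in stages:
    print(stage.get("stageId"), stage.get("shuffleWriteBytes", 0))
```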
--- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 428ff72e71a43..7835017910232 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -145,7 +145,9 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + // The spill metrics are stored in a new ShuffleWriteMetrics, and then discarded (this fixes SPARK-16827). + // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). + this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( From dd16b52cf785ae06026bd00e8e6bedfffa791f5d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Oct 2016 10:24:42 -0700 Subject: [PATCH 215/306] [SPARK-17800] Introduce InterfaceStability annotation ## What changes were proposed in this pull request? This patch introduces three new annotations under InterfaceStability: - Stable - Evolving - Unstable This is inspired by Hadoop's InterfaceStability, and the first step towards switching over to a new API stability annotation framework. ## How was this patch tested? N/A Author: Reynold Xin Closes #15374 from rxin/SPARK-17800. --- .../spark/annotation/InterfaceStability.java | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java new file mode 100644 index 0000000000000..323098f69c6e1 --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.Documented; + +/** + * Annotation to inform users of how much to rely on a particular package, + * class or method not changing over time. + */ +public class InterfaceStability { + + /** + * Stable APIs that retain source and binary compatibility within a major release. 
+ * These interfaces can change from one major release to another major release + * (e.g. from 1.0 to 2.0). + */ + @Documented + public @interface Stable {}; + + /** + * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. + * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). + */ + @Documented + public @interface Evolving {}; + + /** + * Unstable APIs, with no guarantee on stability. + * Classes that are unannotated are considered Unstable. + */ + @Documented + public @interface Unstable {}; +} From cff560755244dd4ccb998e0c56e81d2620cd4cff Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 7 Oct 2016 10:31:41 -0700 Subject: [PATCH 216/306] [SPARK-17707][WEBUI] Web UI prevents spark-submit application to be finished ## What changes were proposed in this pull request? This expands calls to Jetty's simple `ServerConnector` constructor to explicitly specify a `ScheduledExecutorScheduler` that makes daemon threads. It should otherwise result in exactly the same configuration, because the other args are copied from the constructor that is currently called. (I'm not sure we should change the Hive Thriftserver impl, but I did anyway.) This also adds `sc.stop()` to the quick start guide example. ## How was this patch tested? Existing tests; _pending_ at least manual verification of the fix. Author: Sean Owen Closes #15381 from srowen/SPARK-17707. --- .../deploy/rest/RestSubmissionServer.scala | 14 +++++++++--- .../org/apache/spark/ui/JettyUtils.scala | 14 +++++++++--- docs/quick-start.md | 7 +++++- .../cli/thrift/ThriftHttpCLIService.java | 22 +++++++++++++++++-- 4 files changed, 48 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index fa55d470842b3..b30c980e95a9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -22,9 +22,9 @@ import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.{Server, ServerConnector} +import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector} import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -83,7 +83,15 @@ private[spark] abstract class RestSubmissionServer( threadPool.setDaemon(true) val server = new Server(threadPool) - val connector = new ServerConnector(server) + val connector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) server.addConnector(connector) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 24f3f757157f3..35c3c8d00f99b 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -27,12 +27,12 @@ import scala.xml.Node import 
org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.server.{Request, Server, ServerConnector} +import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} @@ -294,7 +294,15 @@ private[spark] object JettyUtils extends Logging { val server = new Server(pool) val connectors = new ArrayBuffer[ServerConnector] // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new ServerConnector(server) + val httpConnector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) httpConnector.setPort(currentPort) connectors += httpConnector diff --git a/docs/quick-start.md b/docs/quick-start.md index 2eab8d19aa4c6..cb9a378199562 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -240,7 +240,8 @@ object SimpleApp { val logData = sc.textFile(logFile, 2).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() - println("Lines with a: %s, Lines with b: %s".format(numAs, numBs)) + println(s"Lines with a: $numAs, Lines with b: $numBs") + sc.stop() } } {% endhighlight %} @@ -328,6 +329,8 @@ public class SimpleApp { }).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); + + sc.stop() } } {% endhighlight %} @@ -407,6 +410,8 @@ numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) + +sc.stop() {% endhighlight %} diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 37e4845cceb9e..341a7fdbb59b8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -37,11 +37,15 @@ import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.TServlet; +import org.eclipse.jetty.server.AbstractConnectionFactory; +import org.eclipse.jetty.server.ConnectionFactory; +import org.eclipse.jetty.server.HttpConnectionFactory; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.ExecutorThreadPool; +import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; public class ThriftHttpCLIService extends ThriftCLIService { @@ -70,7 +74,8 @@ public void run() { httpServer = new org.eclipse.jetty.server.Server(threadPool); // Connector configs - ServerConnector connector = new ServerConnector(httpServer); + + ConnectionFactory[] connectionFactories; boolean 
useSsl = hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_USE_SSL); String schemeName = useSsl ? "https" : "http"; // Change connector if SSL is used @@ -90,8 +95,21 @@ public void run() { Arrays.toString(sslContextFactory.getExcludeProtocols())); sslContextFactory.setKeyStorePath(keyStorePath); sslContextFactory.setKeyStorePassword(keyStorePassword); - connector = new ServerConnector(httpServer, sslContextFactory); + connectionFactories = AbstractConnectionFactory.getFactories( + sslContextFactory, new HttpConnectionFactory()); + } else { + connectionFactories = new ConnectionFactory[] { new HttpConnectionFactory() }; } + ServerConnector connector = new ServerConnector( + httpServer, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("HiveServer2-HttpHandler-JettyScheduler", true), + null, + -1, + -1, + connectionFactories); + connector.setPort(portNum); // Linux:yes, Windows:no connector.setReuseAddress(!Shell.WINDOWS); From aa3a6841ebaf45efb5d3930a93869948bdd0d2b6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Oct 2016 10:52:32 -0700 Subject: [PATCH 217/306] [SPARK-14525][SQL][FOLLOWUP] Clean up JdbcRelationProvider ## What changes were proposed in this pull request? This PR proposes cleaning up the confusing part in `createRelation` as discussed in https://github.com/apache/spark/pull/12601/files#r80627940 Also, this PR proposes the changes below: - Add documentation for `batchsize` and `isolationLevel`. - Move property names into `JDBCOptions` so that they can be managed in a single place. which were, `fetchsize`, `batchsize`, `isolationLevel` and `driver`. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #15263 from HyukjinKwon/SPARK-14525. 
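For reference, a rough sketch of how these save modes and options reach `createRelation` from PySpark (the connection details are placeholders, `df` is assumed to be an existing DataFrame, and a matching JDBC driver jar must be on the classpath):

```
# Hedged PySpark sketch of writing through the JDBC provider.
(df.write
    .format("jdbc")
    .option("url", "jdbc:postgresql://db-host:5432/testdb")
    .option("dbtable", "public.events")
    .option("user", "spark")
    .option("password", "secret")
    .option("batchsize", "1000")                 # rows per JDBC batch insert
    .option("isolationLevel", "READ_COMMITTED")  # transaction isolation for the writes
    .option("truncate", "true")                  # with overwrite: TRUNCATE instead of DROP/CREATE
    .mode("overwrite")                           # or append / ignore / error (default)
    .save())
```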
--- .../jdbc/JdbcRelationProvider.scala | 82 ++++++++----------- .../datasources/jdbc/JdbcUtils.scala | 29 ++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 13 +++ 4 files changed, 74 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index ae04af2479c8d..3a8a197ef5241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -22,6 +22,7 @@ import java.util.Properties import scala.collection.JavaConverters.mapAsJavaMapConverter import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends CreatableRelationProvider @@ -50,67 +51,52 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } - /* - * The following structure applies to this code: - * | tableExists | !tableExists - *------------------------------------------------------------------------------------ - * Ignore | BaseRelation | CreateTable, saveTable, BaseRelation - * ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation - * Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation - * | saveTable, BaseRelation | - * Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation - * - * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate - */ override def createRelation( sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val jdbcOptions = new JDBCOptions(parameters) - val url = jdbcOptions.url - val table = jdbcOptions.table - + df: DataFrame): BaseRelation = { + val options = new JDBCOptions(parameters) + val url = options.url + val table = options.table + val createTableOptions = options.createTableOptions + val isTruncate = options.isTruncate val props = new Properties() props.putAll(parameters.asJava) - val conn = JdbcUtils.createConnectionFactory(url, props)() + val conn = JdbcUtils.createConnectionFactory(url, props)() try { val tableExists = JdbcUtils.tableExists(conn, url, table) + if (tableExists) { + mode match { + case SaveMode.Overwrite => + if (isTruncate && isCascadingTruncateTable(url).contains(false)) { + // In this case, we should truncate table and then load. 
+ truncateTable(conn, table) + saveTable(df, url, table, props) + } else { + // Otherwise, do not truncate the table, instead drop and recreate it + dropTable(conn, table) + createTable(df.schema, url, table, createTableOptions, conn) + saveTable(df, url, table, props) + } - val (doCreate, doSave) = (mode, tableExists) match { - case (SaveMode.Ignore, true) => (false, false) - case (SaveMode.ErrorIfExists, true) => throw new AnalysisException( - s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.") - case (SaveMode.Overwrite, true) => - if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) { - JdbcUtils.truncateTable(conn, table) - (false, true) - } else { - JdbcUtils.dropTable(conn, table) - (true, true) - } - case (SaveMode.Append, true) => (false, true) - case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," + - " for handling existing tables.") - case (_, false) => (true, true) - } + case SaveMode.Append => + saveTable(df, url, table, props) + + case SaveMode.ErrorIfExists => + throw new AnalysisException( + s"Table or view '$table' already exists. SaveMode: ErrorIfExists.") - if (doCreate) { - val schema = JdbcUtils.schemaString(data, url) - // To allow certain options to append when create a new table, which can be - // table_options or partition_options. - // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" - val createtblOptions = jdbcOptions.createTableOptions - val sql = s"CREATE TABLE $table ($schema) $createtblOptions" - val statement = conn.createStatement - try { - statement.executeUpdate(sql) - } finally { - statement.close() + case SaveMode.Ignore => + // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected + // to not save the contents of the DataFrame and to not change the existing data. + // Therefore, it is okay to do nothing here and then just return the relation below. } + } else { + createTable(df.schema, url, table, createTableOptions, conn) + saveTable(df, url, table, props) } - if (doSave) JdbcUtils.saveTable(data, url, table, props) } finally { conn.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 3db1d1f109fb7..66f2bada2e3d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -552,7 +552,7 @@ object JdbcUtils extends Logging { isolationLevel: Int): Iterator[Byte] = { require(batchSize >= 1, s"Invalid value `${batchSize.toString}` for parameter " + - s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.") + s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") val conn = getConnection() var committed = false @@ -657,10 +657,10 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. 
*/ - def schemaString(df: DataFrame, url: String): String = { + def schemaString(schema: StructType, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => + schema.fields foreach { field => val name = dialect.quoteIdentifier(field.name) val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" @@ -697,4 +697,27 @@ object JdbcUtils extends Logging { getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) } + + /** + * Creates a table with a given schema. + */ + def createTable( + schema: StructType, + url: String, + table: String, + createTableOptions: String, + conn: Connection): Unit = { + val strSchema = schemaString(schema, url) + // Create the table if the table does not exist. + // To allow certain options to append when create a new table, which can be + // table_options or partition_options. + // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" + val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 10f15ca280689..7cc3989b791ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -788,7 +788,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 506971362f867..62b29db4d5524 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -132,6 +132,19 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { } } + test("CREATE with ignore") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + + df2.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + } + test("CREATE with overwrite") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) From bb1aaf28eca6d9ae9af664ac3ad35cafdfc01a3b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 7 Oct 2016 11:16:24 -0700 Subject: [PATCH 218/306] [SPARK-16411][SQL][STREAMING] Add textFile to Structured Streaming. 
## What changes were proposed in this pull request? Adds the textFile API which exists in DataFrameReader and serves same purpose. ## How was this patch tested? Added corresponding testcase. Author: Prashant Sharma Closes #14087 from ScrapCodes/textFile. --- .../sql/streaming/DataStreamReader.scala | 33 ++++++++++++++++++- .../sql/streaming/FileStreamSourceSuite.scala | 18 ++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 864a9cd3eb89d..87b73062180e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType @@ -283,6 +283,37 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo */ def text(path: String): DataFrame = format("text").load(path) + /** + * Loads text file(s) and returns a [[Dataset]] of String. The underlying schema of the Dataset + * contains a single string column named "value". + * + * If the directory structure of the text files contains partitioning information, those are + * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. + * + * Each line in the text file is a new element in the resulting Dataset. For example: + * {{{ + * // Scala: + * spark.readStream.textFile("/path/to/spark/README.md") + * + * // Java: + * spark.readStream().textFile("/path/to/spark/README.md") + * }}} + * + * You can set the following text-specific options to deal with text files: + *
+   * <ul>
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
+   * </ul>
      + * + * @param path input path + * @since 2.1.0 + */ + def textFile(path: String): Dataset[String] = { + if (userSpecifiedSchema.nonEmpty) { + throw new AnalysisException("User specified schema not supported with `textFile`") + } + text(path).select("value").as[String](sparkSession.implicits.newStringEncoder) + } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 3157afe5a56c0..7f9c981a4e9c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -342,6 +342,24 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("read from textfile") { + withTempDirs { case (src, tmp) => + val textStream = spark.readStream.textFile(src.getCanonicalPath) + val filtered = textStream.filter(_.contains("keep")) + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), + StartStream(), + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + } + } + test("SPARK-17165 should not track the list of seen files indefinitely") { // This test works by: // 1. Create a file From 9d8ae853ecc5600f5c2f69565b96d5c46a8c0048 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Oct 2016 11:34:49 -0700 Subject: [PATCH 219/306] [SPARK-17665][SPARKR] Support options/mode all for read/write APIs and options in other types ## What changes were proposed in this pull request? This PR includes the changes below: - Support `mode`/`options` in `read.parquet`, `write.parquet`, `read.orc`, `write.orc`, `read.text`, `write.text`, `read.json` and `write.json` APIs - Support other types (logical, numeric and string) as options for `write.df`, `read.df`, `read.parquet`, `write.parquet`, `read.orc`, `write.orc`, `read.text`, `write.text`, `read.json` and `write.json` ## How was this patch tested? Unit tests in `test_sparkSQL.R`/ `utils.R`. Author: hyukjinkwon Closes #15239 from HyukjinKwon/SPARK-17665. --- R/pkg/R/DataFrame.R | 43 +++++++++---- R/pkg/R/SQLContext.R | 23 +++++-- R/pkg/R/generics.R | 10 +-- R/pkg/R/utils.R | 22 +++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 75 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_utils.R | 9 +++ 6 files changed, 160 insertions(+), 22 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 75861d5de7092..801d2ed4e7500 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -55,6 +55,19 @@ setMethod("initialize", "SparkDataFrame", function(.Object, sdf, isCached) { .Object }) +#' Set options/mode and then return the write object +#' @noRd +setWriteOptions <- function(write, path = NULL, mode = "error", ...) { + options <- varargsToStrEnv(...) 
+ if (!is.null(path)) { + options[["path"]] <- path + } + jmode <- convertToJSaveMode(mode) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) + write +} + #' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached @@ -727,6 +740,8 @@ setMethod("toJSON", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @rdname write.json @@ -743,8 +758,9 @@ setMethod("toJSON", #' @note write.json since 1.6.0 setMethod("write.json", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "json", path)) }) @@ -755,6 +771,8 @@ setMethod("write.json", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases write.orc,SparkDataFrame,character-method @@ -771,8 +789,9 @@ setMethod("write.json", #' @note write.orc since 2.0.0 setMethod("write.orc", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "orc", path)) }) @@ -783,6 +802,8 @@ setMethod("write.orc", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @rdname write.parquet @@ -800,8 +821,9 @@ setMethod("write.orc", #' @note write.parquet since 1.6.0 setMethod("write.parquet", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "parquet", path)) }) @@ -825,6 +847,8 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases write.text,SparkDataFrame,character-method @@ -841,8 +865,9 @@ setMethod("saveAsParquetFile", #' @note write.text since 2.0.0 setMethod("write.text", signature(x = "SparkDataFrame", path = "character"), - function(x, path) { + function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") + write <- setWriteOptions(write, mode = mode, ...) invisible(callJMethod(write, "text", path)) }) @@ -2637,15 +2662,9 @@ setMethod("write.df", if (is.null(source)) { source <- getDefaultSqlSource() } - jmode <- convertToJSaveMode(mode) - options <- varargsToEnv(...) 
- if (!is.null(path)) { - options[["path"]] <- path - } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) - write <- callJMethod(write, "mode", jmode) - write <- callJMethod(write, "options", options) + write <- setWriteOptions(write, path = path, mode = mode, ...) write <- handledCallJMethod(write, "save") }) @@ -2701,7 +2720,7 @@ setMethod("saveAsTable", source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index baa87824beb91..0d6a229e63455 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -328,6 +328,7 @@ setMethod("toDF", signature(x = "RDD"), #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json #' @export @@ -341,11 +342,13 @@ setMethod("toDF", signature(x = "RDD"), #' @name read.json #' @method read.json default #' @note read.json since 1.6.0 -read.json.default <- function(path) { +read.json.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } @@ -405,16 +408,19 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' Loads an ORC file, returning the result as a SparkDataFrame. #' #' @param path Path of file to read. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc #' @export #' @name read.orc #' @note read.orc since 2.0.0 -read.orc <- function(path) { +read.orc <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the ORC file path path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "orc", path) dataFrame(sdf) } @@ -430,11 +436,13 @@ read.orc <- function(path) { #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 -read.parquet.default <- function(path) { +read.parquet.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -467,6 +475,7 @@ parquetFile <- function(x, ...) { #' Each line in the text file is a new row in the resulting SparkDataFrame. #' #' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text #' @export @@ -479,11 +488,13 @@ parquetFile <- function(x, ...) 
{ #' @name read.text #' @method read.text default #' @note read.text since 1.6.1 -read.text.default <- function(path) { +read.text.default <- function(path, ...) { sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "options", options) sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } @@ -779,7 +790,7 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string "in 'spark.sql.sources.default' configuration by default.") } sparkSession <- getSparkSession() - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } @@ -842,7 +853,7 @@ loadDF <- function(x = NULL, ...) { #' @note createExternalTable since 1.4.0 createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { sparkSession <- getSparkSession() - options <- varargsToEnv(...) + options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 90a02e2778310..810aea9017743 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -651,15 +651,17 @@ setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { #' @rdname write.json #' @export -setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) +setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc #' @export -setGeneric("write.orc", function(x, path) { standardGeneric("write.orc") }) +setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet #' @export -setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) +setGeneric("write.parquet", function(x, path, ...) { + standardGeneric("write.parquet") +}) #' @rdname write.parquet #' @export @@ -667,7 +669,7 @@ setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParqu #' @rdname write.text #' @export -setGeneric("write.text", function(x, path) { standardGeneric("write.text") }) +setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema #' @export diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index e69666453480c..fa8bb0f79ce80 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -334,6 +334,28 @@ varargsToEnv <- function(...) { env } +# Utility function to capture the varargs into environment object but all values are converted +# into string. +varargsToStrEnv <- function(...) { + pairs <- list(...) + env <- new.env() + for (name in names(pairs)) { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". 
Supported types are logical, numeric, character and NULL.")) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } + env +} + getStorageLevel <- function(newLevel = c("DISK_ONLY", "DISK_ONLY_2", "MEMORY_AND_DISK", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f5ab601f274fe..6d8cfad5c1f93 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -256,6 +256,23 @@ test_that("read/write csv as DataFrame", { unlink(csvPath2) }) +test_that("Support other types for options", { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + csvDf <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expected <- read.df(csvPath, "csv", header = TRUE, inferSchema = TRUE) + expect_equal(collect(csvDf), collect(expected)) + + expect_error(read.df(csvPath, "csv", header = TRUE, maxColumns = 3)) + unlink(csvPath) +}) + test_that("convert NAs to null type in DataFrames", { rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) df <- createDataFrame(rdd, list("a", "b")) @@ -497,6 +514,19 @@ test_that("read/write json files", { unlink(jsonPath3) }) +test_that("read/write json files - compression option", { + df <- read.df(jsonPath, "json") + + jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") + write.json(df, jsonPath, compression = "gzip") + jsonDF <- read.json(jsonPath) + expect_is(jsonDF, "SparkDataFrame") + expect_equal(count(jsonDF), count(df)) + expect_true(length(list.files(jsonPath, pattern = ".gz")) > 0) + + unlink(jsonPath) +}) + test_that("jsonRDD() on a RDD with json string", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) @@ -1786,6 +1816,21 @@ test_that("read/write ORC files", { unsetHiveContext() }) +test_that("read/write ORC files - compression option", { + setHiveContext(sc) + df <- read.df(jsonPath, "json") + + orcPath2 <- tempfile(pattern = "orcPath2", fileext = ".orc") + write.orc(df, orcPath2, compression = "ZLIB") + orcDF <- read.orc(orcPath2) + expect_is(orcDF, "SparkDataFrame") + expect_equal(count(orcDF), count(df)) + expect_true(length(list.files(orcPath2, pattern = ".zlib.orc")) > 0) + + unlink(orcPath2) + unsetHiveContext() +}) + test_that("read/write Parquet files", { df <- read.df(jsonPath, "json") # Test write.df and read.df @@ -1817,6 +1862,23 @@ test_that("read/write Parquet files", { unlink(parquetPath4) }) +test_that("read/write Parquet files - compression option/mode", { + df <- read.df(jsonPath, "json") + tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") + + # Test write.df and read.df + write.parquet(df, tempPath, compression = "GZIP") + df2 <- read.parquet(tempPath) + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + expect_true(length(list.files(tempPath, pattern = ".gz.parquet")) > 0) + + write.parquet(df, tempPath, mode = "overwrite") + df3 <- read.parquet(tempPath) + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) +}) + test_that("read/write text files", { # Test write.df and read.df df <- read.df(jsonPath, "text") @@ -1838,6 +1900,19 @@ test_that("read/write text files", { 
unlink(textPath2) }) +test_that("read/write text files - compression option", { + df <- read.df(jsonPath, "text") + + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + write.text(df, textPath, compression = "GZIP") + textDF <- read.text(textPath) + expect_is(textDF, "SparkDataFrame") + expect_equal(count(textDF), count(df)) + expect_true(length(list.files(textPath, pattern = ".gz")) > 0) + + unlink(textPath) +}) + test_that("describe() and summarize() on a DataFrame", { df <- read.json(jsonPath) stats <- describe(df, "age") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 69ed5549168b1..a20254e9b3fa9 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -217,4 +217,13 @@ test_that("rbindRaws", { }) +test_that("varargsToStrEnv", { + strenv <- varargsToStrEnv(a = 1, b = 1.1, c = TRUE, d = "abcd") + env <- varargsToEnv(a = "1", b = "1.1", c = "true", d = "abcd") + expect_equal(strenv, env) + expect_error(varargsToStrEnv(a = list(1, "a")), + paste0("Unsupported type for a : list. Supported types are logical, ", + "numeric, character and NULL.")) +}) + sparkR.session.stop() From 2badb58cdd7833465202197c4c52db5aa3d4c6e7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Oct 2016 13:45:00 -0700 Subject: [PATCH 220/306] [SPARK-15621][SQL] Support spilling for Python UDF ## What changes were proposed in this pull request? When execute a Python UDF, we buffer the input row into as queue, then pull them out to join with the result from Python UDF. In the case that Python UDF is slow or the input row is too wide, we could ran out of memory because of the queue. Since we can't flush all the buffers (sockets) between JVM and Python process from JVM side, we can't limit the rows in the queue, otherwise it could deadlock. This PR will manage the memory used by the queue, spill that into disk when there is no enough memory (also release the memory and disk space as soon as possible). ## How was this patch tested? Added unit tests. Also manually ran a workload with large input row and slow python UDF (with large broadcast) like this: ``` b = range(1<<24) add = udf(lambda x: x + len(b), IntegerType()) df = sqlContext.range(1, 1<<26, 1, 4) print df.select(df.id, lit("adf"*10000).alias("s"), add(df.id).alias("add")).groupBy(length("s")).sum().collect() ``` It ran out of memory (hang because of full GC) before the patch, ran smoothly after the patch. Author: Davies Liu Closes #15089 from davies/spill_udf. 
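For orientation before the diff, here is a minimal sketch of how the new queue is meant to be driven: every input row is pushed into the queue before it is shipped to Python, and rows are pulled back in the same order to pair with the UDF output. The sketch is illustrative only and not part of this patch: the `RowQueueSketch` object and `roundTrip` helper are invented names, it assumes it compiles inside the `org.apache.spark.sql.execution.python` package (the queue classes are `private[python]`), and the caller tracks how many rows were added because `remove()` may only be called after a matching `add()`.

```
// Illustrative sketch mirroring the add/remove pattern used by BatchEvalPythonExec
// and RowQueueSuite in this patch; names below are invented for illustration.
package org.apache.spark.sql.execution.python

import java.io.File

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.util.Utils

object RowQueueSketch {
  def roundTrip(rows: Seq[UnsafeRow], numFields: Int): Seq[UnsafeRow] = {
    // The hybrid queue keeps rows in memory pages and spills to local disk
    // when no more pages can be allocated.
    val queue = HybridRowQueue(
      TaskContext.get().taskMemoryManager(),
      new File(Utils.getLocalDir(SparkEnv.get.conf)),
      numFields)
    try {
      rows.foreach(queue.add)               // push rows in arrival order
      rows.map(_ => queue.remove().copy())  // rows come back FIFO; copy because the
                                            // queue reuses its result row buffer
    } finally {
      queue.close()                         // frees memory pages and deletes spill files
    }
  }
}
```

Note that the real operator never needs the copy: `BatchEvalPythonExec` consumes each dequeued row immediately, joining it with the Python result and projecting to an `UnsafeRow`, so the shared result-row buffer is safe there.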
--- .../python/BatchEvalPythonExec.scala | 36 ++- .../spark/sql/execution/python/RowQueue.scala | 280 ++++++++++++++++++ .../sql/execution/python/RowQueueSuite.scala | 127 ++++++++ 3 files changed, 436 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index d9bf4d3ccf698..f9d20ad090056 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.execution.python +import java.io.File + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils /** @@ -37,9 +40,25 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * Python evaluation works by sending the necessary (projected) input data via a socket to an * external Python process, and combine the result from the Python process with the original row. * - * For each row we send to Python, we also put it in a queue. For each output row from Python, + * For each row we send to Python, we also put it in a queue first. For each output row from Python, * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. + * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory. + * + * Here is a diagram to show how this works: + * + * Downstream (for parent) + * / \ + * / socket (output of UDF) + * / \ + * RowQueue Python + * \ / + * \ socket (input of UDF) + * \ / + * upstream (from child) + * + * The rows sent to and received from Python are packed into batches (100 rows) and serialized, + * there should be always some rows buffered in the socket or Python process, so the pulling from + * RowQueue ALWAYS happened after pushing into it. */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends SparkPlan { @@ -70,7 +89,11 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi // The queue used to buffer input rows so we can drain it to // combine input with output from Python. 
- val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + TaskContext.get().addTaskCompletionListener({ ctx => + queue.close() + }) val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip @@ -98,7 +121,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi // For each row, add it to the queue. val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { inputRow => - queue.add(inputRow) + queue.add(inputRow.asInstanceOf[UnsafeRow]) val row = projection(inputRow) if (needConversion) { EvaluatePython.toJava(row, schema) @@ -132,7 +155,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } val resultProj = UnsafeProjection.create(output, output) - outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala @@ -144,7 +166,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] } - resultProj(joined(queue.poll(), row)) + resultProj(joined(queue.remove(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala new file mode 100644 index 0000000000000..422a3f862d96f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -0,0 +1,280 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import com.google.common.io.Closeables + +import org.apache.spark.SparkException +import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.MemoryBlock + +/** + * A RowQueue is an FIFO queue for UnsafeRow. + * + * This RowQueue is ONLY designed and used for Python UDF, which has only one writer and only one + * reader, the reader ALWAYS ran behind the writer. See the doc of class [[BatchEvalPythonExec]] + * on how it works. + */ +private[python] trait RowQueue { + + /** + * Add a row to the end of it, returns true iff the row has been added to the queue. + */ + def add(row: UnsafeRow): Boolean + + /** + * Retrieve and remove the first row, returns null if it's empty. + * + * It can only be called after add is called, otherwise it will fail (NPE). 
+ */ + def remove(): UnsafeRow + + /** + * Cleanup all the resources. + */ + def close(): Unit +} + +/** + * A RowQueue that is based on in-memory page. UnsafeRows are appended into it until it's full. + * Another thread could read from it at the same time (behind the writer). + * + * The format of UnsafeRow in page: + * [4 bytes to hold length of record (N)] [N bytes to hold record] [...] + * + * -1 length means end of page. + */ +private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int) + extends RowQueue { + private val base: AnyRef = page.getBaseObject + private val endOfPage: Long = page.getBaseOffset + page.size + // the first location where a new row would be written + private var writeOffset = page.getBaseOffset + // points to the start of the next row to read + private var readOffset = page.getBaseOffset + private val resultRow = new UnsafeRow(numFields) + + def add(row: UnsafeRow): Boolean = synchronized { + val size = row.getSizeInBytes + if (writeOffset + 4 + size > endOfPage) { + // if there is not enough space in this page to hold the new record + if (writeOffset + 4 <= endOfPage) { + // if there's extra space at the end of the page, store a special "end-of-page" length (-1) + Platform.putInt(base, writeOffset, -1) + } + false + } else { + Platform.putInt(base, writeOffset, size) + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size) + writeOffset += 4 + size + true + } + } + + def remove(): UnsafeRow = synchronized { + assert(readOffset <= writeOffset, "reader should not go beyond writer") + if (readOffset + 4 > endOfPage || Platform.getInt(base, readOffset) < 0) { + null + } else { + val size = Platform.getInt(base, readOffset) + resultRow.pointTo(base, readOffset + 4, size) + readOffset += 4 + size + resultRow + } + } +} + +/** + * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any + * reader has begun reading from the queue. + */ +private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue { + private var out = new DataOutputStream( + new BufferedOutputStream(new FileOutputStream(file.toString))) + private var unreadBytes = 0L + + private var in: DataInputStream = _ + private val resultRow = new UnsafeRow(fields) + + def add(row: UnsafeRow): Boolean = synchronized { + if (out == null) { + // Another thread is reading, stop writing this one + return false + } + out.writeInt(row.getSizeInBytes) + out.write(row.getBytes) + unreadBytes += 4 + row.getSizeInBytes + true + } + + def remove(): UnsafeRow = synchronized { + if (out != null) { + out.close() + out = null + in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString))) + } + + if (unreadBytes > 0) { + val size = in.readInt() + val bytes = new Array[Byte](size) + in.readFully(bytes) + unreadBytes -= 4 + size + resultRow.pointTo(bytes, size) + resultRow + } else { + null + } + } + + def close(): Unit = synchronized { + Closeables.close(out, true) + out = null + Closeables.close(in, true) + in = null + if (file.exists()) { + file.delete() + } + } +} + +/** + * A RowQueue that has a list of RowQueues, which could be in memory or disk. + * + * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same + * time. 
+ */ +private[python] case class HybridRowQueue( + memManager: TaskMemoryManager, + tempDir: File, + numFields: Int) + extends MemoryConsumer(memManager) with RowQueue { + + // Each buffer should have at least one row + private var queues = new java.util.LinkedList[RowQueue]() + + private var writing: RowQueue = _ + private var reading: RowQueue = _ + + // exposed for testing + private[python] def numQueues(): Int = queues.size() + + def spill(size: Long, trigger: MemoryConsumer): Long = { + if (trigger == this) { + // When it's triggered by itself, it should write upcoming rows into disk instead of copying + // the rows already in the queue. + return 0L + } + var released = 0L + synchronized { + // poll out all the buffers and add them back in the same order to make sure that the rows + // are in correct order. + val newQueues = new java.util.LinkedList[RowQueue]() + while (!queues.isEmpty) { + val queue = queues.remove() + val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) { + val diskQueue = createDiskQueue() + var row = queue.remove() + while (row != null) { + diskQueue.add(row) + row = queue.remove() + } + released += queue.asInstanceOf[InMemoryRowQueue].page.size() + queue.close() + diskQueue + } else { + queue + } + newQueues.add(newQueue) + } + queues = newQueues + } + released + } + + private def createDiskQueue(): RowQueue = { + DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields) + } + + private def createNewQueue(required: Long): RowQueue = { + val page = try { + allocatePage(required) + } catch { + case _: OutOfMemoryError => + null + } + val buffer = if (page != null) { + new InMemoryRowQueue(page, numFields) { + override def close(): Unit = { + freePage(page) + } + } + } else { + createDiskQueue() + } + + synchronized { + queues.add(buffer) + } + buffer + } + + def add(row: UnsafeRow): Boolean = { + if (writing == null || !writing.add(row)) { + writing = createNewQueue(4 + row.getSizeInBytes) + if (!writing.add(row)) { + throw new SparkException(s"failed to push a row into $writing") + } + } + true + } + + def remove(): UnsafeRow = { + var row: UnsafeRow = null + if (reading != null) { + row = reading.remove() + } + if (row == null) { + if (reading != null) { + reading.close() + } + synchronized { + reading = queues.remove() + } + assert(reading != null, s"queue should not be empty") + row = reading.remove() + assert(row != null, s"$reading should have at least one row") + } + row + } + + def close(): Unit = { + if (reading != null) { + reading.close() + reading = null + } + synchronized { + while (!queues.isEmpty) { + queues.remove().close() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala new file mode 100644 index 0000000000000..ffda33cf906c5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Utils + +class RowQueueSuite extends SparkFunSuite { + + test("in-memory queue") { + val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val queue = new InMemoryRowQueue(page, 1) { + override def close() {} + } + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = page.size() / (4 + row.getSizeInBytes) + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(!queue.add(row), "should not add more") + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + assert(queue.remove() == null, "should be empty") + queue.close() + } + + test("disk queue") { + val dir = Utils.createTempDir().getCanonicalFile + dir.mkdirs() + val queue = DiskRowQueue(new File(dir, "buffer"), 1) + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = 1000 + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + val first = queue.remove() + assert(first != null, "first should not be null") + assert(first.getLong(0) == 0, "first should be 0") + assert(!queue.add(row), "should not add more") + i = 1 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + assert(queue.remove() == null, "should be empty") + queue.close() + } + + test("hybrid queue") { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(4<<10) + val taskM = new TaskMemoryManager(mem, 0) + val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1) + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + val n = (4<<10) / 16 * 3 + var i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(queue.numQueues() > 1, "should have more than one queue") + queue.spill(1<<20, null) + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + + // fill again and spill + i = 0 + while (i < n) { + row.setLong(0, i) + assert(queue.add(row), "fail to add") + i += 1 + } + assert(queue.numQueues() > 1, "should have more than one queue") + queue.spill(1<<20, null) + assert(queue.numQueues() > 1, "should have more than one queue") + i = 0 + while (i < n) { + val row = queue.remove() + assert(row != null, "fail to poll") + assert(row.getLong(0) == i, "does not match") + i += 1 + } + queue.close() + } +} From 97594c29b723f372a5c4c061760015bd78d01f50 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 7 Oct 2016 14:03:45 -0700 Subject: [PATCH 221/306] [SPARK-17761][SQL] Remove MutableRow 
## What changes were proposed in this pull request? In practice we cannot guarantee that an `InternalRow` is immutable. This makes the `MutableRow` almost redundant. This PR folds `MutableRow` into `InternalRow`. The code below illustrates the immutability issue with InternalRow: ```scala import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow val struct = new GenericMutableRow(1) val row = InternalRow(struct, 1) println(row) scala> [[null], 1] struct.setInt(0, 42) println(row) scala> [[42], 1] ``` This might be somewhat controversial, so feedback is appreciated. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #15333 from hvanhovell/SPARK-17761. --- .../apache/spark/ml/linalg/MatrixUDT.scala | 4 +- .../apache/spark/ml/linalg/VectorUDT.scala | 6 +- .../apache/spark/mllib/linalg/Matrices.scala | 4 +- .../apache/spark/mllib/linalg/Vectors.scala | 6 +- .../sql/catalyst/expressions/UnsafeRow.java | 2 +- .../spark/sql/catalyst/InternalRow.scala | 23 +++++- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/JoinedRow.scala | 16 +++++ .../sql/catalyst/expressions/Projection.scala | 4 +- ...bleRow.scala => SpecificInternalRow.scala} | 5 +- .../aggregate/HyperLogLogPlusPlus.scala | 6 +- .../expressions/aggregate/PivotFirst.scala | 10 +-- .../expressions/aggregate/collect.scala | 6 +- .../expressions/aggregate/interfaces.scala | 14 ++-- .../expressions/codegen/CodeGenerator.scala | 3 +- .../codegen/GenerateMutableProjection.scala | 8 +-- .../codegen/GenerateSafeProjection.scala | 8 +-- .../sql/catalyst/expressions/package.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 44 +----------- .../sql/catalyst/json/JacksonParser.scala | 4 +- .../sql/catalyst/ScalaReflectionSuite.scala | 4 +- .../expressions/CodeGenerationSuite.scala | 16 ++--- .../catalyst/expressions/MapDataSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 26 +++---- .../ApproximatePercentileSuite.scala | 9 +-- .../aggregate/HyperLogLogPlusPlusSuite.scala | 13 ++-- .../execution/vectorized/ColumnarBatch.java | 7 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../aggregate/AggregationIterator.scala | 26 +++---- .../SortBasedAggregationIterator.scala | 6 +- .../TungstenAggregationIterator.scala | 8 +-- .../spark/sql/execution/aggregate/udaf.scala | 38 +++++----- .../execution/columnar/ColumnAccessor.scala | 13 ++-- .../sql/execution/columnar/ColumnType.scala | 72 +++++++++---------- .../columnar/GenerateColumnAccessor.scala | 6 +- .../columnar/NullableColumnAccessor.scala | 4 +- .../CompressibleColumnAccessor.scala | 4 +- .../compression/CompressionScheme.scala | 3 +- .../compression/compressionSchemes.scala | 20 +++--- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/csv/CSVRelation.scala | 4 +- .../datasources/jdbc/JdbcUtils.scala | 34 ++++----- .../parquet/ParquetRowConverter.scala | 6 +- .../joins/BroadcastNestedLoopJoinExec.scala | 10 +-- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 4 +- .../python/BatchEvalPythonExec.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 4 +- .../execution/window/AggregateProcessor.scala | 4 +- .../sql/execution/window/WindowExec.scala | 12 ++-- .../window/WindowFunctionFrame.scala | 10 +-- .../scala/org/apache/spark/sql/RowSuite.scala | 6 +- 
.../sql/TypedImperativeAggregateSuite.scala | 6 +- .../execution/columnar/ColumnTypeSuite.scala | 4 +- .../columnar/ColumnarTestUtils.scala | 12 ++-- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../compression/BooleanBitSetSuite.scala | 4 +- .../CompressionSchemeBenchmark.scala | 4 +- .../compression/DictionaryEncodingSuite.scala | 4 +- .../compression/IntegralDeltaSuite.scala | 6 +- .../compression/RunLengthEncodingSuite.scala | 4 +- .../datasources/parquet/ParquetIOSuite.scala | 4 +- .../parquet/ParquetQuerySuite.scala | 4 +- .../spark/sql/hive/HiveInspectors.scala | 18 ++--- .../apache/spark/sql/hive/TableReader.scala | 38 +++++----- .../hive/execution/ScriptTransformation.scala | 2 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 6 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 2 +- 71 files changed, 343 insertions(+), 347 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{SpecificMutableRow.scala => SpecificInternalRow.scala} (98%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala index a1e53662f02a8..f4a8556c71f6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -46,7 +46,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } override def serialize(obj: Matrix): InternalRow = { - val row = new GenericMutableRow(7) + val row = new GenericInternalRow(7) obj match { case sm: SparseMatrix => row.setByte(0, 0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index 0b9b2ff5c5e26..917861309c573 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -42,14 +42,14 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def serialize(obj: Vector): InternalRow = { obj match { case SparseVector(size, indices, values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 6642999a2121f..542a69b3ef8cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -28,7 +28,7 @@ import 
com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.Since import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -189,7 +189,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } override def serialize(obj: Matrix): InternalRow = { - val row = new GenericMutableRow(7) + val row = new GenericInternalRow(7) obj match { case sm: SparseMatrix => row.setByte(0, 0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 91f065831c804..fbd217af74ecb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -214,14 +214,14 @@ class VectorUDT extends UserDefinedType[Vector] { override def serialize(obj: Vector): InternalRow = { obj match { case SparseVector(size, indices, values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => - val row = new GenericMutableRow(4) + val row = new GenericInternalRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9027652d57f14..c3f0abac244cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -59,7 +59,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. 
*/ -public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable { +public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { ////////////////////////////////////////////////////////////////////////////// // Static methods diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index eba95c5c8b908..f498e071b50a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Decimal, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -31,6 +31,27 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString + def setNullAt(i: Int): Unit + + def update(i: Int, value: Any): Unit + + // default implementation (slow) + def setBoolean(i: Int, value: Boolean): Unit = update(i, value) + def setByte(i: Int, value: Byte): Unit = update(i, value) + def setShort(i: Int, value: Short): Unit = update(i, value) + def setInt(i: Int, value: Int): Unit = update(i, value) + def setLong(i: Int, value: Long): Unit = update(i, value) + def setFloat(i: Int, value: Float): Unit = update(i, value) + def setDouble(i: Int, value: Double): Unit = update(i, value) + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } + /** * Make a copy of the current [[InternalRow]] object. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b96b744b4fa98..82e1a8a7cad96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -256,7 +256,7 @@ case class ExpressionEncoder[T]( private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @transient - private lazy val inputRow = new GenericMutableRow(1) + private lazy val inputRow = new GenericInternalRow(1) @transient private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 70fff51956255..1314c416510dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -403,7 +403,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? 
- val newRow = new GenericMutableRow(from.fields.length) + val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 while (i < row.numFields) { @@ -892,7 +892,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index ed894f6d6e10e..7770684a5b399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -123,6 +123,22 @@ class JoinedRow extends InternalRow { override def anyNull: Boolean = row1.anyNull || row2.anyNull + override def setNullAt(i: Int): Unit = { + if (i < row1.numFields) { + row1.setNullAt(i) + } else { + row2.setNullAt(i - row1.numFields) + } + } + + override def update(i: Int, value: Any): Unit = { + if (i < row1.numFields) { + row1.update(i, value) + } else { + row2.update(i - row1.numFields, value) + } + } + override def copy(): InternalRow = { val copy1 = row1.copy() val copy2 = row2.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c8d18667f7c4a..a81fa1ce3adcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -69,10 +69,10 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu }) private[this] val exprArray = expressions.toArray - private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) + private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) def currentValue: InternalRow = mutableRow - override def target(row: MutableRow): MutableProjection = { + override def target(row: InternalRow): MutableProjection = { mutableRow = row this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 61ca7272dfa61..74e0b4691d4cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.types._ /** * A parent class for mutable container objects that are reused when the values are changed, - * resulting in less garbage. These values are held by a [[SpecificMutableRow]]. + * resulting in less garbage. These values are held by a [[SpecificInternalRow]]. 
* * The following code was roughly used to generate these objects: * {{{ @@ -191,8 +191,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. */ -final class SpecificMutableRow(val values: Array[MutableValue]) - extends MutableRow with BaseGenericInternalRow { +final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 1d218da6db806..83c8d400c5d6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -155,7 +155,7 @@ case class HyperLogLogPlusPlus( aggBufferAttributes.map(_.newInstance()) /** Fill all words with zeros. */ - override def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: InternalRow): Unit = { var word = 0 while (word < numWords) { buffer.setLong(mutableAggBufferOffset + word, 0) @@ -168,7 +168,7 @@ case class HyperLogLogPlusPlus( * * Variable names in the HLL++ paper match variable names in the code. */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: InternalRow, input: InternalRow): Unit = { val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. @@ -200,7 +200,7 @@ case class HyperLogLogPlusPlus( * Merge the HLL buffers by iterating through the registers in both buffers and select the * maximum number of leading zeros for each register. */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 16c03c500ad08..087606077295f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -30,7 +30,7 @@ object PivotFirst { // Currently UnsafeRow does not support the generic update method (throws // UnsupportedOperationException), so we need to explicitly support each DataType. 
- private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = { + private val updateFunction: PartialFunction[DataType, (InternalRow, Int, Any) => Unit] = { case DoubleType => (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double]) case IntegerType => @@ -89,9 +89,9 @@ case class PivotFirst( val indexSize = pivotIndex.size - private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType) + private val updateRow: (InternalRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType) - override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = { + override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) if (pivotColValue != null) { // We ignore rows whose pivot column value is not in the list of pivot column values. @@ -105,7 +105,7 @@ case class PivotFirst( } } - override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = { + override def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit = { for (i <- 0 until indexSize) { if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) { val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType) @@ -114,7 +114,7 @@ case class PivotFirst( } } - override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match { + override def initialize(mutableAggBuffer: InternalRow): Unit = valueDataType match { case d: DecimalType => // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType. for (i <- 0 until indexSize) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 78a388d20630b..89eb864e94702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -60,11 +60,11 @@ abstract class Collect extends ImperativeAggregate { protected[this] val buffer: Growable[Any] with Iterable[Any] - override def initialize(b: MutableRow): Unit = { + override def initialize(b: InternalRow): Unit = { buffer.clear() } - override def update(b: MutableRow, input: InternalRow): Unit = { + override def update(b: InternalRow, input: InternalRow): Unit = { // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. 
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator val value = child.eval(input) @@ -73,7 +73,7 @@ abstract class Collect extends ImperativeAggregate { } } - override def merge(buffer: MutableRow, input: InternalRow): Unit = { + override def merge(buffer: InternalRow, input: InternalRow): Unit = { sys.error("Collect cannot be used in partial aggregations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b5c0844fbf310..f3fd58bc98ef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -307,14 +307,14 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def initialize(mutableAggBuffer: MutableRow): Unit + def initialize(mutableAggBuffer: InternalRow): Unit /** * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit + def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit /** * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate @@ -323,7 +323,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ - def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit + def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } /** @@ -504,16 +504,16 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */ def deserialize(storageFormat: Array[Byte]): T - final override def initialize(buffer: MutableRow): Unit = { + final override def initialize(buffer: InternalRow): Unit = { val bufferObject = createAggregationBuffer() buffer.update(mutableAggBufferOffset, bufferObject) } - final override def update(buffer: MutableRow, input: InternalRow): Unit = { + final override def update(buffer: InternalRow, input: InternalRow): Unit = { update(getBufferObject(buffer), input) } - final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = { val bufferObject = getBufferObject(buffer) // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset)) @@ -547,7 +547,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { * This is only called when doing Partial or PartialMerge mode aggregation, before the framework * shuffle out aggregate buffers. 
*/ - final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { + final def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = { buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 574943d3d21f0..6cab50ae1bf8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,7 +819,7 @@ class CodeAndComment(val body: String, val comment: collection.Map[String, Strin */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val genericMutableRowType: String = classOf[GenericMutableRow].getName + protected val genericMutableRowType: String = classOf[GenericInternalRow].getName /** * Generates a class for a given input expression. Called when there is not cached code @@ -889,7 +889,6 @@ object CodeGenerator extends Logging { classOf[UnsafeArrayData].getName, classOf[MapData].getName, classOf[UnsafeMapData].getName, - classOf[MutableRow].getName, classOf[Expression].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 13d61af1c9b40..5c4b56b0b224c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp abstract class BaseMutableProjection extends MutableProjection /** - * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * Generates byte code that produces a [[InternalRow]] object that can update itself based on a new * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. * It exposes a `target` method, which is used to set the row that will be updated. - * The internal [[MutableRow]] object created internally is used only when `target` is not used. + * The internal [[InternalRow]] object created internally is used only when `target` is not used. 
*/ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableProjection] { @@ -102,7 +102,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private Object[] references; - private MutableRow mutableRow; + private InternalRow mutableRow; ${ctx.declareMutableStates()} public SpecificMutableProjection(Object[] references) { @@ -113,7 +113,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.declareAddedFunctions()} - public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { + public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { mutableRow = row; return this; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1c98c9ed10705..2773e1a666212 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._ abstract class BaseProjection extends Projection {} /** - * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update + * Generates byte code that produces a [[InternalRow]] object (not an [[UnsafeRow]]) that can update * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]]. */ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] { @@ -164,12 +164,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { private Object[] references; - private MutableRow mutableRow; + private InternalRow mutableRow; ${ctx.declareMutableStates()} public SpecificSafeProjection(Object[] references) { this.references = references; - mutableRow = (MutableRow) references[references.length - 1]; + mutableRow = (InternalRow) references[references.length - 1]; ${ctx.initMutableStates()} } @@ -188,7 +188,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) - val resultRow = new SpecificMutableRow(expressions.map(_.dataType)) + val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index a6125c61e508a..1510a4796683c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -81,7 +81,7 @@ package object expressions { def currentValue: InternalRow /** Uses the given row to store the output of the projection. 
*/ - def target(row: MutableRow): MutableProjection + def target(row: InternalRow): MutableProjection } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 73dceb35ac50e..751b821e1b009 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -157,33 +157,6 @@ trait BaseGenericInternalRow extends InternalRow { } } -/** - * An extended interface to [[InternalRow]] that allows the values for each column to be updated. - * Setting a value through a primitive function implicitly marks that column as not null. - */ -abstract class MutableRow extends InternalRow { - def setNullAt(i: Int): Unit - - def update(i: Int, value: Any): Unit - - // default implementation (slow) - def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setByte(i: Int, value: Byte): Unit = { update(i, value) } - def setShort(i: Int, value: Short): Unit = { update(i, value) } - def setInt(i: Int, value: Int): Unit = { update(i, value) } - def setLong(i: Int, value: Long): Unit = { update(i, value) } - def setFloat(i: Int, value: Float): Unit = { update(i, value) } - def setDouble(i: Int, value: Double): Unit = { update(i, value) } - - /** - * Update the decimal column at `i`. - * - * Note: In order to support update decimal with precision > 18 in UnsafeRow, - * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). - */ - def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } -} - /** * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not @@ -230,24 +203,9 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def numFields: Int = values.length - override def copy(): GenericInternalRow = this -} - -class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { - /** No-arg constructor for serialization. 
*/ - protected def this() = this(null) - - def this(size: Int) = this(new Array[Any](size)) - - override protected def genericGet(ordinal: Int) = values(ordinal) - - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values - - override def numFields: Int = values.length - override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index f80e6373d2f89..e476cb11a3517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -105,7 +105,7 @@ class JacksonParser( } emptyRow } else { - val row = new GenericMutableRow(schema.length) + val row = new GenericInternalRow(schema.length) for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecord)) { require(schema(corruptIndex).dataType == StringType) row.update(corruptIndex, UTF8String.fromString(record)) @@ -363,7 +363,7 @@ class JacksonParser( parser: JsonParser, schema: StructType, fieldConverters: Seq[ValueConverter]): InternalRow = { - val row = new GenericMutableRow(schema.length) + val row = new GenericInternalRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { case Some(index) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 85563ddedc165..43b6afd9ad896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -94,7 +94,7 @@ object TestingUDT { .add("c", DoubleType, nullable = false) override def serialize(n: NestedStruct): Any = { - val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + val row = new SpecificInternalRow(sqlType.asInstanceOf[StructType].map(_.dataType)) row.setInt(0, n.a) row.setLong(1, n.b) row.setDouble(2, n.c) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5588b4429164c..0cb201e4dae3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -68,7 +68,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = 
GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { @@ -91,7 +91,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expression = CaseWhen((1 to cases).map(generateCase(_))) val plan = GenerateMutableProjection.generate(Seq(expression)) - val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) val actual = plan(input).toSeq(Seq(expression.dataType)) assert(actual(0) == cases) @@ -101,7 +101,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(new GenericArrayData(Seq.fill(length)(true))) if (!checkResult(actual, expected)) { @@ -116,7 +116,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { case (expr, i) => Seq(Literal(i), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)).map { + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map { case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m) } val expected = (0 until length).map((_, true)).toMap :: Nil @@ -130,7 +130,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1))))) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) if (!checkResult(actual, expected)) { @@ -145,7 +145,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { expr => Seq(Literal(expr.toString), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) if (!checkResult(actual, expected)) { @@ -158,7 +158,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val schema = StructType(Seq.fill(length)(StructField("int", IntegerType))) val expressions = Seq(CreateExternalRow(Seq.fill(length)(Literal(1)), schema)) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(Row.fromSeq(Seq.fill(length)(1))) if (!checkResult(actual, expected)) { @@ -174,7 +174,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create("PST", StringType)) } val plan = 
GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)( DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala index 0f1264c7c3269..25a675a90276d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala @@ -45,7 +45,7 @@ class MapDataSuite extends SparkFunSuite { // UnsafeMapData val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { row.update(0, map) val unsafeRow = unsafeConverter.apply(row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 90790dda753f8..cf3cbe270753e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -37,7 +37,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) row.setInt(2, 2) @@ -75,7 +75,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes(StandardCharsets.UTF_8)) @@ -94,7 +94,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) @@ -138,7 +138,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { - val r = new SpecificMutableRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes) for (i <- fieldTypes.indices) { r.setNullAt(i) } @@ -167,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // columns, then the serialized row representation should be identical to what we would get by // creating an entirely null row via the converter val rowWithNoNullColumns: InternalRow = { - val r = new SpecificMutableRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes) r.setNullAt(0) r.setBoolean(1, 
false) r.setByte(2, 20) @@ -243,11 +243,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("NaN canonicalization") { val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) - val row1 = new SpecificMutableRow(fieldTypes) + val row1 = new SpecificInternalRow(fieldTypes) row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) - val row2 = new SpecificMutableRow(fieldTypes) + val row2 = new SpecificInternalRow(fieldTypes) row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) @@ -263,7 +263,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) row.update(1, InternalRow(InternalRow(2L))) @@ -324,7 +324,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) row.update(1, createArray(createArray(3, 4))) @@ -359,7 +359,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val innerMap = createMap(5, 6)(7, 8) val map2 = createMap(9)(innerMap) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, map1) row.update(1, map2) @@ -400,7 +400,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) row.update(1, createArray(InternalRow(2L))) @@ -439,7 +439,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) row.update(1, createMap(3)(InternalRow(4L))) @@ -485,7 +485,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ) val converter = UnsafeProjection.create(fieldTypes) - val row = new GenericMutableRow(fieldTypes.length) + val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) row.update(1, createMap(3)(createArray(4))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 61298a1b72d77..8456e244609bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribu import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, 
DecimalLiteral, GenericMutableRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData @@ -144,7 +144,8 @@ class ApproximatePercentileSuite extends SparkFunSuite { .withNewInputAggBufferOffset(inputAggregationBufferOffset) .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) - val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1)) + val mutableAggBuffer = new GenericInternalRow( + new Array[Any](mutableAggregationBufferOffset + 1)) agg.initialize(mutableAggBuffer) val dataCount = 10 (1 to dataCount).foreach { data => @@ -154,7 +155,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { // Serialize the aggregation buffer val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) - val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized)) + val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) // Phase 2: final mode aggregation // Re-initialize the aggregation buffer @@ -311,7 +312,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { test("class ApproximatePercentile, null handling") { val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) - val buffer = new GenericMutableRow(new Array[Any](1)) + val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer assert(agg.eval(buffer) == null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index f5374229ca5cd..17f6b71bb270b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -22,28 +22,29 @@ import java.util.Random import scala.collection.mutable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow} import org.apache.spark.sql.types.{DataType, IntegerType} class HyperLogLogPlusPlusSuite extends SparkFunSuite { /** Create a HLL++ instance and an input and output buffer. 
*/ def createEstimator(rsd: Double, dt: DataType = IntegerType): - (HyperLogLogPlusPlus, MutableRow, MutableRow) = { - val input = new SpecificMutableRow(Seq(dt)) + (HyperLogLogPlusPlus, InternalRow, InternalRow) = { + val input = new SpecificInternalRow(Seq(dt)) val hll = new HyperLogLogPlusPlus(new BoundReference(0, dt, true), rsd) val buffer = createBuffer(hll) (hll, input, buffer) } - def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = { - val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType)) + def createBuffer(hll: HyperLogLogPlusPlus): InternalRow = { + val buffer = new SpecificInternalRow(hll.aggBufferAttributes.map(_.dataType)) hll.initialize(buffer) buffer } /** Evaluate the estimate. It should be within 3*SD's of the given true rsd. */ - def evaluateEstimate(hll: HyperLogLogPlusPlus, buffer: MutableRow, cardinality: Int): Unit = { + def evaluateEstimate(hll: HyperLogLogPlusPlus, buffer: InternalRow, cardinality: Int): Unit = { val estimate = hll.eval(buffer).asInstanceOf[Long].toDouble val error = math.abs((estimate / cardinality.toDouble) - 1.0d) assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. errors.") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 62abc2a821a3a..a6ce4c2edc232 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -21,8 +21,7 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; -import org.apache.spark.sql.catalyst.expressions.MutableRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -91,7 +90,7 @@ public void close() { * Adapter class to interop with existing components that expect internal row. A lot of * performance is lost with this translation. */ - public static final class Row extends MutableRow { + public static final class Row extends InternalRow { protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; @@ -129,7 +128,7 @@ public void markFiltered() { * Revisit this. This is expensive. This is currently only used in test paths. 
*/ public InternalRow copy() { - GenericMutableRow row = new GenericMutableRow(columns.length); + GenericInternalRow row = new GenericInternalRow(columns.length); for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 6c4248c60e893..d3a22228623e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -32,7 +32,7 @@ object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length - val mutableRow = new GenericMutableRow(numColumns) + val mutableRow = new GenericInternalRow(numColumns) val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) iterator.map { r => var i = 0 @@ -52,7 +52,7 @@ object RDDConversions { def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length - val mutableRow = new GenericMutableRow(numColumns) + val mutableRow = new GenericInternalRow(numColumns) val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) iterator.map { r => var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index f335912ba2c32..7c11fdb9792e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -153,7 +153,7 @@ abstract class AggregationIterator( protected def generateProcessRow( expressions: Seq[AggregateExpression], functions: Seq[AggregateFunction], - inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = { val joinedRow = new JoinedRow if (expressions.nonEmpty) { val mergeExpressions = functions.zipWithIndex.flatMap { @@ -168,9 +168,9 @@ abstract class AggregationIterator( case (ae: ImperativeAggregate, i) => expressions(i).mode match { case Partial | Complete => - (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) case PartialMerge | Final => - (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) + (buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row) } }.toArray // This projection is used to merge buffer values for all expression-based aggregates. @@ -178,7 +178,7 @@ abstract class AggregationIterator( val updateProjection = newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes) - (currentBuffer: MutableRow, row: InternalRow) => { + (currentBuffer: InternalRow, row: InternalRow) => { // Process all expression-based aggregate functions. updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) // Process all imperative aggregate functions. @@ -190,11 +190,11 @@ abstract class AggregationIterator( } } else { // Grouping only. 
- (currentBuffer: MutableRow, row: InternalRow) => {} + (currentBuffer: InternalRow, row: InternalRow) => {} } } - protected val processRow: (MutableRow, InternalRow) => Unit = + protected val processRow: (InternalRow, InternalRow) => Unit = generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) protected val groupingProjection: UnsafeProjection = @@ -202,7 +202,7 @@ abstract class AggregationIterator( protected val groupingAttributes = groupingExpressions.map(_.toAttribute) // Initializing the function used to generate the output row. - protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = { val joinedRow = new JoinedRow val modes = aggregateExpressions.map(_.mode).distinct val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) @@ -211,14 +211,14 @@ abstract class AggregationIterator( case ae: DeclarativeAggregate => ae.evaluateExpression case agg: AggregateFunction => NoOp } - val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val aggregateResult = new SpecificInternalRow(aggregateAttributes.map(_.dataType)) val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) expressionAggEvalProjection.target(aggregateResult) val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { // Generate results for all expression-based aggregate functions. expressionAggEvalProjection(currentBuffer) // Generate results for all imperative aggregate functions. @@ -244,7 +244,7 @@ abstract class AggregationIterator( } } - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { // Serializes the generic object stored in aggregation buffer var i = 0 while (i < typedImperativeAggregates.length) { @@ -256,17 +256,17 @@ abstract class AggregationIterator( } else { // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { resultProjection(currentGroupingKey) } } } - protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + protected val generateOutput: (UnsafeRow, InternalRow) => UnsafeRow = generateResultProjection() /** Initializes buffer values for all aggregate functions. 
*/ - protected def initializeBuffer(buffer: MutableRow): Unit = { + protected def initializeBuffer(buffer: InternalRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) var i = 0 while (i < allImperativeAggregateFunctions.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index c2b1ef0fe3c2c..bea2dce1a7657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -49,11 +49,11 @@ class SortBasedAggregationIterator( * Creates a new aggregation buffer and initializes buffer values * for all aggregate functions. */ - private def newBuffer: MutableRow = { + private def newBuffer: InternalRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val genericMutableBuffer = new GenericInternalRow(bufferRowSize) val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { @@ -84,7 +84,7 @@ class SortBasedAggregationIterator( private[this] var sortedInputHasNewGroup: Boolean = false // The aggregation buffer used by the sort-based aggregation. - private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer // This safe projection is used to turn the input row into safe row. This is necessary // because the input row may be produced by unsafe projection in child operator and all the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4e072a92cc772..2988161ee5e7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -118,7 +118,7 @@ class TungstenAggregationIterator( private def createNewAggregationBuffer(): UnsafeRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) - .apply(new GenericMutableRow(bufferSchema.length)) + .apply(new GenericInternalRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values @@ -127,7 +127,7 @@ class TungstenAggregationIterator( } // Creates a function used to generate output rows. 
- override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + override protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = { val modes = aggregateExpressions.map(_.mode).distinct if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection @@ -137,7 +137,7 @@ class TungstenAggregationIterator( val bufferSchema = StructType.fromAttributes(bufferAttributes) val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) } } else { @@ -300,7 +300,7 @@ class TungstenAggregationIterator( private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() // The function used to process rows in a group - private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + private[this] var sortBasedProcessRow: (InternalRow, InternalRow) => Unit = null // Processes rows in the current group. It will stop when it find a new group. private def processCurrentSortedGroup(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 586e1456ac69e..67760f334e406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -96,18 +96,18 @@ sealed trait BufferSetterGetterUtils { getters } - def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = { + def createSetters(schema: StructType): Array[((InternalRow, Int, Any) => Unit)] = { val dataTypes = schema.fields.map(_.dataType) - val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length) + val setters = new Array[(InternalRow, Int, Any) => Unit](dataTypes.length) var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { case NullType => - (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + (row: InternalRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) case b: BooleanType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setBoolean(ordinal, value.asInstanceOf[Boolean]) } else { @@ -115,7 +115,7 @@ sealed trait BufferSetterGetterUtils { } case ByteType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setByte(ordinal, value.asInstanceOf[Byte]) } else { @@ -123,7 +123,7 @@ sealed trait BufferSetterGetterUtils { } case 
ShortType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setShort(ordinal, value.asInstanceOf[Short]) } else { @@ -131,7 +131,7 @@ sealed trait BufferSetterGetterUtils { } case IntegerType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setInt(ordinal, value.asInstanceOf[Int]) } else { @@ -139,7 +139,7 @@ sealed trait BufferSetterGetterUtils { } case LongType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setLong(ordinal, value.asInstanceOf[Long]) } else { @@ -147,7 +147,7 @@ sealed trait BufferSetterGetterUtils { } case FloatType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setFloat(ordinal, value.asInstanceOf[Float]) } else { @@ -155,7 +155,7 @@ sealed trait BufferSetterGetterUtils { } case DoubleType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setDouble(ordinal, value.asInstanceOf[Double]) } else { @@ -164,13 +164,13 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => // To make it work with UnsafeRow, we cannot use setNullAt. // Please see the comment of UnsafeRow's setDecimal. row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) case DateType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setInt(ordinal, value.asInstanceOf[Int]) } else { @@ -178,7 +178,7 @@ sealed trait BufferSetterGetterUtils { } case TimestampType => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setLong(ordinal, value.asInstanceOf[Long]) } else { @@ -186,7 +186,7 @@ sealed trait BufferSetterGetterUtils { } case other => - (row: MutableRow, ordinal: Int, value: Any) => + (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.update(ordinal, value) } else { @@ -209,7 +209,7 @@ private[aggregate] class MutableAggregationBufferImpl( toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingBuffer: MutableRow) + var underlyingBuffer: InternalRow) extends MutableAggregationBuffer with BufferSetterGetterUtils { private[this] val offsets: Array[Int] = { @@ -413,13 +413,13 @@ case class ScalaUDAF( null) } - override def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer udaf.initialize(mutableAggregateBuffer) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: InternalRow, input: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer udaf.update( @@ -427,7 +427,7 @@ case class ScalaUDAF( inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer1 inputAggregateBuffer.underlyingInputBuffer = buffer2 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 7cde04b62619e..6241b79d9affc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -21,15 +21,16 @@ import java.nio.{ByteBuffer, ByteOrder} import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is * extracted from the buffer, instead of directly returning it, the value is set into some field of - * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods - * for primitive values provided by [[MutableRow]]. + * a [[InternalRow]]. In this way, boxing cost can be avoided by leveraging the setter methods + * for primitive values provided by [[InternalRow]]. */ private[columnar] trait ColumnAccessor { initialize() @@ -38,7 +39,7 @@ private[columnar] trait ColumnAccessor { def hasNext: Boolean - def extractTo(row: MutableRow, ordinal: Int): Unit + def extractTo(row: InternalRow, ordinal: Int): Unit protected def underlyingBuffer: ByteBuffer } @@ -52,11 +53,11 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( override def hasNext: Boolean = buffer.hasRemaining - override def extractTo(row: MutableRow, ordinal: Int): Unit = { + override def extractTo(row: InternalRow, ordinal: Int): Unit = { extractSingle(row, ordinal) } - def extractSingle(row: MutableRow, ordinal: Int): Unit = { + def extractSingle(row: InternalRow, ordinal: Int): Unit = { columnType.extract(buffer, row, ordinal) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index d27d8c362dd9a..703bde25316df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -92,7 +92,7 @@ private[columnar] sealed abstract class ColumnType[JvmType] { * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever * possible. */ - def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { setField(row, ordinal, extract(buffer)) } @@ -125,13 +125,13 @@ private[columnar] sealed abstract class ColumnType[JvmType] { * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + def setField(row: InternalRow, ordinal: Int, value: JvmType): Unit /** * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. 
*/ - def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int): Unit = { setField(to, toOrdinal, getField(from, fromOrdinal)) } @@ -149,7 +149,7 @@ private[columnar] object NULL extends ColumnType[Any] { override def defaultSize: Int = 0 override def append(v: Any, buffer: ByteBuffer): Unit = {} override def extract(buffer: ByteBuffer): Any = null - override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) + override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) override def getField(row: InternalRow, ordinal: Int): Any = null } @@ -177,18 +177,18 @@ private[columnar] object INT extends NativeColumnType(IntegerType, 4) { ByteBufferHelper.getInt(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -206,17 +206,17 @@ private[columnar] object LONG extends NativeColumnType(LongType, 8) { ByteBufferHelper.getLong(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -234,17 +234,17 @@ private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { ByteBufferHelper.getFloat(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -262,17 +262,17 @@ private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { ByteBufferHelper.getDouble(buffer) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): 
Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer)) } - override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) } } @@ -288,17 +288,17 @@ private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setBoolean(ordinal, buffer.get() == 1) } - override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -316,17 +316,17 @@ private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { buffer.get() } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setByte(ordinal, buffer.get()) } - override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -344,17 +344,17 @@ private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { buffer.getShort() } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { row.setShort(ordinal, buffer.getShort()) } - override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } @@ -366,7 +366,7 @@ private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow - override def extract(buffer: ByteBuffer, row: 
MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { val numBytes = buffer.getInt val cursor = buffer.position() @@ -407,7 +407,7 @@ private[columnar] object STRING UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) } - override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) } else { @@ -419,7 +419,7 @@ private[columnar] object STRING row.getUTF8String(ordinal) } - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { setField(to, toOrdinal, getField(from, fromOrdinal)) } @@ -433,7 +433,7 @@ private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) Decimal(ByteBufferHelper.getLong(buffer), precision, scale) } - override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { if (row.isInstanceOf[MutableUnsafeRow]) { // copy it as Long row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) @@ -459,11 +459,11 @@ private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) row.getDecimal(ordinal, precision, scale) } - override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { row.setDecimal(ordinal, value, precision) } - override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) { setField(to, toOrdinal, getField(from, fromOrdinal)) } } @@ -497,7 +497,7 @@ private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Array[Byte]): Unit = { row.update(ordinal, value) } @@ -522,7 +522,7 @@ private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) row.getDecimal(ordinal, precision, scale) } - override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { row.setDecimal(ordinal, value, precision) } @@ -553,7 +553,7 @@ private[columnar] case class STRUCT(dataType: StructType) override def defaultSize: Int = 20 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UnsafeRow): Unit = { row.update(ordinal, value) } @@ -591,7 +591,7 @@ private[columnar] case class ARRAY(dataType: ArrayType) override def defaultSize: Int = 28 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = { + override def setField(row: InternalRow, ordinal: Int, value: UnsafeArrayData): Unit = { row.update(ordinal, value) } @@ -630,7 +630,7 @@ private[columnar] case class MAP(dataType: MapType) override def defaultSize: Int = 68 - override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { + override def setField(row: InternalRow, ordinal: 
Int, value: UnsafeMapData): Unit = { row.update(ordinal, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 96bd338f092e5..14024d6c10558 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -36,8 +36,7 @@ abstract class ColumnarIterator extends Iterator[InternalRow] { * * WARNING: These setter MUST be called in increasing order of ordinals. */ -class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { - +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalRow { override def isNullAt(i: Int): Boolean = writer.isNullAt(i) override def setNullAt(i: Int): Unit = writer.setNullAt(i) @@ -55,6 +54,9 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(nu override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException // all other methods inherited from GenericMutableRow are not need + override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException + override def numFields: Int = throw new UnsupportedOperationException + override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 2465633162c4e..2f09757aa341c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.InternalRow private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ @@ -39,7 +39,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { + abstract override def extractTo(row: InternalRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index 6579b5068e65a..e1d13ad0e94e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType @@ -33,7 +33,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext - override def 
extractSingle(row: MutableRow, ordinal: Int): Unit = { + override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index b90d00b15b180..6e4f1c5b80684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType @@ -39,7 +38,7 @@ private[columnar] trait Encoder[T <: AtomicType] { } private[columnar] trait Decoder[T <: AtomicType] { - def next(row: MutableRow, ordinal: Int): Unit + def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 941f03b745a07..ee99c90a751d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ @@ -56,7 +56,7 @@ private[columnar] case object PassThrough extends CompressionScheme { class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { columnType.extract(buffer, row, ordinal) } @@ -86,7 +86,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. 
- private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) + private val lastValue = new SpecificInternalRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize: Int = _uncompressedSize @@ -117,9 +117,9 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) + val currentValue = new SpecificInternalRow(Seq(columnType.dataType)) var currentRun = 1 - val value = new SpecificMutableRow(Seq(columnType.dataType)) + val value = new SpecificInternalRow(Seq(columnType.dataType)) columnType.extract(from, currentValue, 0) @@ -156,7 +156,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#InternalType = _ - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = ByteBufferHelper.getInt(buffer) @@ -273,7 +273,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) } @@ -356,7 +356,7 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -443,7 +443,7 @@ private[columnar] case object IntDelta extends CompressionScheme { override def hasNext: Boolean = buffer.hasRemaining - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val delta = buffer.get() prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) @@ -523,7 +523,7 @@ private[columnar] case object LongDelta extends CompressionScheme { override def hasNext: Boolean = buffer.hasRemaining - override def next(row: MutableRow, ordinal: Int): Unit = { + override def next(row: InternalRow, ordinal: Int): Unit = { val delta = buffer.get() prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 693b4c4d0e5e9..6f9ed50a02b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -273,7 +273,7 @@ object DataSourceStrategy extends Strategy with Logging { // Get the bucket ID based on the bucketing values. // Restriction: Bucket pruning works iff the bucketing column has one and only one column. 
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType)) + val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 33b170bc31f62..55cb26d6513af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer} import org.apache.spark.sql.types._ @@ -88,7 +88,7 @@ object CSVRelation extends Logging { case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index } val requiredSize = requiredFields.length - val row = new GenericMutableRow(requiredSize) + val row = new GenericInternalRow(requiredSize) (tokens: Array[String], numMalformedRows) => { if (params.dropMalformed && schemaFields.length != tokens.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 66f2bada2e3d8..47549637b5813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ @@ -283,7 +283,7 @@ object JdbcUtils extends Logging { new NextIterator[InternalRow] { private[this] val rs = resultSet private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) - private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) + private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) override protected def close(): Unit = { try { @@ -314,22 +314,22 @@ object JdbcUtils extends Logging { // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field // for `MutableRow`. The last argument `Int` means the index for the value to be set in // the row and also used for the value in `ResultSet`. 
- private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit + private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit /** * Creates `JDBCValueGetter`s according to [[StructType]], which can set - * each value from `ResultSet` to each field of [[MutableRow]] correctly. + * each value from `ResultSet` to each field of [[InternalRow]] correctly. */ private def makeGetters(schema: StructType): Array[JDBCValueGetter] = schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { case BooleanType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) case DateType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos + 1) if (dateVal != null) { @@ -347,25 +347,25 @@ object JdbcUtils extends Logging { // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. case DecimalType.Fixed(p, s) => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val decimal = nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s)) row.update(pos, decimal) case DoubleType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setDouble(pos, rs.getDouble(pos + 1)) case FloatType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setFloat(pos, rs.getFloat(pos + 1)) case IntegerType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1)) case LongType if metadata.contains("binarylong") => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val bytes = rs.getBytes(pos + 1) var ans = 0L var j = 0 @@ -376,20 +376,20 @@ object JdbcUtils extends Logging { row.setLong(pos, ans) case LongType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1)) case ShortType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1)) case StringType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 row.update(pos, UTF8String.fromString(rs.getString(pos + 1))) case TimestampType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val t = rs.getTimestamp(pos + 1) if (t != null) { row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t)) @@ -398,7 +398,7 @@ object JdbcUtils extends Logging { } case BinaryType => - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => row.update(pos, rs.getBytes(pos + 1)) case ArrayType(et, _) => @@ -437,7 +437,7 @@ object JdbcUtils extends Logging { case _ => (array: Object) => array.asInstanceOf[Array[Any]] } - (rs: ResultSet, row: MutableRow, pos: Int) => + (rs: ResultSet, row: InternalRow, pos: Int) => val array = nullSafeConvert[Object]( rs.getArray(pos + 1).getArray, array => new 
GenericArrayData(elementConversion.apply(array))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 9ffc2b5dd8a56..33dcf2f3fd167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -40,7 +40,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some * corresponding parent container. For example, a converter for a `StructType` field may set - * converted values to a [[MutableRow]]; or a converter for array elements may append converted + * converted values to a [[InternalRow]]; or a converter for array elements may append converted * values to an [[ArrayBuffer]]. */ private[parquet] trait ParentContainerUpdater { @@ -155,7 +155,7 @@ private[parquet] class ParquetRowConverter( * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates * converted filed values to the `ordinal`-th cell in `currentRow`. */ - private final class RowUpdater(row: MutableRow, ordinal: Int) extends ParentContainerUpdater { + private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater { override def set(value: Any): Unit = row(ordinal) = value override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) override def setByte(value: Byte): Unit = row.setByte(ordinal, value) @@ -166,7 +166,7 @@ private[parquet] class ParquetRowConverter( override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) } - private val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + private val currentRow = new SpecificInternalRow(catalystType.map(_.dataType)) private val unsafeProjection = UnsafeProjection.create(catalystType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 43cdce7de8c7f..bfe7e3dea45df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -119,7 +119,7 @@ case class BroadcastNestedLoopJoinExec( streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) + val nulls = new GenericInternalRow(broadcast.output.size) // Returns an iterator to avoid copy the rows. 
new Iterator[InternalRow] { @@ -205,14 +205,14 @@ case class BroadcastNestedLoopJoinExec( val joinedRow = new JoinedRow if (condition.isDefined) { - val resultRow = new GenericMutableRow(Array[Any](null)) + val resultRow = new GenericInternalRow(Array[Any](null)) streamedIter.map { row => val result = buildRows.exists(r => boundCondition(joinedRow(row, r))) resultRow.setBoolean(0, result) joinedRow(row, resultRow) } } else { - val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty)) + val resultRow = new GenericInternalRow(Array[Any](buildRows.nonEmpty)) streamedIter.map { row => joinedRow(row, resultRow) } @@ -293,7 +293,7 @@ case class BroadcastNestedLoopJoinExec( } val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) + val nulls = new GenericInternalRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() val joinedRow = new JoinedRow joinedRow.withLeft(nulls) @@ -311,7 +311,7 @@ case class BroadcastNestedLoopJoinExec( val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) + val nulls = new GenericInternalRow(broadcast.output.size) streamedIter.flatMap { streamedRow => var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb6bfa7b2735c..8ddac19bf1b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -192,7 +192,7 @@ trait HashJoin { streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { val joinKeys = streamSideKeyGenerator() - val result = new GenericMutableRow(Array[Any](null)) + val result = new GenericInternalRow(Array[Any](null)) val joinedRow = new JoinedRow streamIter.map { current => val key = joinKeys(current) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 81b3e1d224ab6..ecf7cf289f034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -275,7 +275,7 @@ case class SortMergeJoinExec( case j: ExistenceJoin => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] val result: MutableRow = new GenericMutableRow(Array[Any](null)) + private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null)) private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c7e267152b5cd..2acc5110e8950 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -141,7 +141,7 @@ object ObjectOperator { def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head - val objRow = new SpecificMutableRow(objType :: 
Nil) + val objRow = new SpecificInternalRow(objType :: Nil) (o: Any) => { objRow(0) = o proj(objRow) @@ -149,7 +149,7 @@ object ObjectOperator { } def wrapObjectToRow(objType: DataType): Any => InternalRow = { - val outputRow = new SpecificMutableRow(objType :: Nil) + val outputRow = new SpecificInternalRow(objType :: Nil) (o: Any) => { outputRow(0) = o outputRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index f9d20ad090056..dcaf2c76d479d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -147,7 +147,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) val joined = new JoinedRow val resultType = if (udfs.length == 1) { udfs.head.dataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 822f49ecab47b..c02b15498748f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.functions._ @@ -186,7 +186,7 @@ object StatFunctions extends Logging { require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => - val countsRow = new GenericMutableRow(columnSize + 1) + val countsRow = new GenericInternalRow(columnSize + 1) rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index d3a46d020dbbf..c9f5d3b3d92d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -123,7 +123,7 @@ private[window] final class AggregateProcessor( private[this] val join = new JoinedRow private[this] val numImperatives = imperatives.length - private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + private[this] val buffer = new SpecificInternalRow(bufferSchema.toSeq.map(_.dataType)) initialProjection.target(buffer) updateProjection.target(buffer) @@ -154,6 +154,6 @@ private[window] final class AggregateProcessor( } /** Evaluate buffer. 
*/ - def evaluate(target: MutableRow): Unit = + def evaluate(target: InternalRow): Unit = evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 7a6a30f120386..1dd281ebf1034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -204,7 +204,7 @@ case class WindowExec( val factory = key match { // Offset Frame case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => - target: MutableRow => + target: InternalRow => new OffsetWindowFunctionFrame( target, ordinal, @@ -217,7 +217,7 @@ case class WindowExec( // Growing Frame. case ("AGGREGATE", frameType, None, Some(high)) => - target: MutableRow => { + target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, @@ -226,7 +226,7 @@ case class WindowExec( // Shrinking Frame. case ("AGGREGATE", frameType, Some(low), None) => - target: MutableRow => { + target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, @@ -235,7 +235,7 @@ case class WindowExec( // Moving Frame. case ("AGGREGATE", frameType, Some(low), Some(high)) => - target: MutableRow => { + target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, @@ -245,7 +245,7 @@ case class WindowExec( // Entire Partition Frame. case ("AGGREGATE", frameType, None, None) => - target: MutableRow => { + target: InternalRow => { new UnboundedWindowFunctionFrame(target, processor) } } @@ -312,7 +312,7 @@ case class WindowExec( val inputFields = child.output.length var sorter: UnsafeExternalSorter = null var rowBuffer: RowBuffer = null - val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 2ab9faab7a59b..70efc0f78ddb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -56,7 +56,7 @@ private[window] abstract class WindowFunctionFrame { * @param offset by which rows get moved within a partition. */ private[window] final class OffsetWindowFunctionFrame( - target: MutableRow, + target: InternalRow, ordinal: Int, expressions: Array[OffsetWindowFunction], inputSchema: Seq[Attribute], @@ -136,7 +136,7 @@ private[window] final class OffsetWindowFunctionFrame( * @param ubound comparator used to identify the upper bound of an output row. */ private[window] final class SlidingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering, ubound: BoundOrdering) @@ -217,7 +217,7 @@ private[window] final class SlidingWindowFunctionFrame( * @param processor to calculate the row values with. 
*/ private[window] final class UnboundedWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor) extends WindowFunctionFrame { @@ -255,7 +255,7 @@ private[window] final class UnboundedWindowFunctionFrame( * @param ubound comparator used to identify the upper bound of an output row. */ private[window] final class UnboundedPrecedingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, ubound: BoundOrdering) extends WindowFunctionFrame { @@ -317,7 +317,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( * @param lbound comparator used to identify the lower bound of an output row. */ private[window] final class UnboundedFollowingWindowFunctionFrame( - target: MutableRow, + target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering) extends WindowFunctionFrame { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 34936b38fb5d4..7516be315dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -27,7 +27,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { - val expected = new GenericMutableRow(4) + val expected = new GenericInternalRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) @@ -49,7 +49,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { } test("SpecificMutableRow.update with null") { - val row = new SpecificMutableRow(Seq(IntegerType)) + val row = new SpecificInternalRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b5eb16b6f650b..ffa26f1f8250f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.expressions.Window @@ -64,7 +64,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { assert(agg.eval(mergeBuffer) == data.map(_._1).max) // Tests low level eval(row: InternalRow) API. 
- val row = new GenericMutableRow(Array(mergeBuffer): Array[Any]) + val row = new GenericInternalRow(Array(mergeBuffer): Array[Any]) // Evaluates directly on row consist of aggregation buffer object. assert(agg.eval(row) == data.map(_._1).max) @@ -73,7 +73,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { test("supports SpecificMutableRow as mutable row") { val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType) val aggBufferOffset = 2 - val buffer = new SpecificMutableRow(aggregationBufferSchema) + val buffer = new SpecificInternalRow(aggregationBufferSchema) val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false)) .withNewMutableAggBufferOffset(aggBufferOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 805b5667287ea..8bf9f521e2f06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) row.update(0, CatalystTypeConverters.convertToCatalyst(value)) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) columnType.actualSize(proj(row), 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 1529313dfbd51..686c8fa6f5fa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -21,14 +21,14 @@ import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { - def makeNullRow(length: Int): GenericMutableRow = { - val row = new GenericMutableRow(length) + def makeNullRow(length: Int): GenericInternalRow = { + val row = new GenericInternalRow(length) (0 until length).foreach(row.setNullAt) row } @@ -86,7 +86,7 @@ object ColumnarTestUtils { tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { - val row = new GenericMutableRow(columnTypes.length) + val row = new GenericInternalRow(columnTypes.length) 
makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value } @@ -95,11 +95,11 @@ object ColumnarTestUtils { def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { + count: Int): (Seq[T#InternalType], Seq[GenericInternalRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) row(0) = value row } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index dc22d3e8e4d3a..8f4ca3cea77a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( @@ -72,7 +72,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { } val accessor = TestNullableColumnAccessor(builder.build(), columnType) - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index cdd4551d64b50..b2b6e92e9a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -94,7 +94,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite { (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values - val actual = new GenericMutableRow(new Array[Any](1)) + val actual = new GenericInternalRow(new Array[Any](1)) (0 until 4).foreach { _ => columnType.extract(buffer, actual, 0) assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index f67e9c7dae278..d01bf911e3a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ @@ -72,7 +72,7 @@ class BooleanBitSetSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index babf944e6aa8e..9005ec93e786e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType import org.apache.spark.util.Benchmark @@ -111,7 +111,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { input.rewind() benchmark.addCase(label)({ i: Int => - val rowBuf = new GenericMutableRow(1) + val rowBuf = new GenericInternalRow(1) for (n <- 0L until iters) { compressedBuf.rewind.position(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 830ca0294e1b8..67139b13d7882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType @@ -97,7 +97,7 @@ class DictionaryEncodingSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index a530e270746c5..411d31fa0e29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType @@ -48,7 +48,7 @@ class IntegralDeltaSuite extends SparkFunSuite { } input.foreach { value => - val row = new GenericMutableRow(1) + val row = new GenericInternalRow(1) columnType.setField(row, 0, value) builder.appendFrom(row, 0) } @@ -95,7 +95,7 @@ class IntegralDeltaSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (input.nonEmpty) { input.foreach{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 95642e93ae9f0..dffa9b364ebfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType @@ -80,7 +80,7 @@ class RunLengthEncodingSuite extends SparkFunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) - val mutableRow = new GenericMutableRow(1) + val mutableRow = new GenericInternalRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3161a630af0f1..580eade4b1412 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,7 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -716,7 +716,7 @@ class ParquetIOSuite extends QueryTest 
with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) val vectorizedReader = new VectorizedParquetRecordReader - val partitionValues = new GenericMutableRow(Array(v)) + val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 9dd8d9f80496c..4c4a7d86f2bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement} import org.apache.spark.sql.internal.SQLConf @@ -719,7 +719,7 @@ object TestingUDT { .add("c", DoubleType, nullable = false) override def serialize(n: NestedStruct): Any = { - val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + val row = new SpecificInternalRow(sqlType.asInstanceOf[StructType].map(_.dataType)) row.setInt(0, n.a) row.setLong(1, n.b) row.setDouble(2, n.c) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index fe34caa0a3e48..1625116803505 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -688,25 +688,25 @@ private[hive] trait HiveInspectors { * @return A function that performs in-place updating of a MutableRow. * Use the overloaded ObjectInspector version for assignments. 
*/ - def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + def unwrapperFor(field: HiveStructField): (Any, InternalRow, Int) => Unit = field.getFieldObjectInspector match { case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi => val unwrapper = unwrapperFor(oi) - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) + (value: Any, row: InternalRow, ordinal: Int) => row(ordinal) = unwrapper(value) } def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index ec7e53efc87f9..2a54163a04e9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -120,7 +120,7 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value @@ -215,7 +215,7 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHadoopConf val localDeserializer = partDeserializer - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) // Splits all attributes into two groups, partition key attributes and those that are not. // Attached indices indicate the position of each attribute in the output schema. 
@@ -224,7 +224,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { + def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) @@ -360,7 +360,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { iterator: Iterator[Writable], rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow, + mutableRow: InternalRow, tableDeser: Deserializer): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { @@ -381,43 +381,43 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { * Builds specific unwrappers ahead of time according to object inspector * types to avoid pattern matching and branching costs per row. */ - val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + val unwrappers: Seq[(Any, InternalRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + (value: Any, row: InternalRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveCharObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: 
DateObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => + (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) case oi => val unwrapper = unwrapperFor(oi) - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value) + (value: Any, row: InternalRow, ordinal: Int) => row(ordinal) = unwrapper(value) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c553c03a9b708..1025b8f70d9ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -124,7 +124,7 @@ case class ScriptTransformation( } else { null } - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) @transient lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d54913518bb33..42033080dc34b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -329,17 +329,17 @@ private[hive] case class HiveUDAFFunction( // buffer for it. override def aggBufferSchema: StructType = StructType(Nil) - override def update(_buffer: MutableRow, input: InternalRow): Unit = { + override def update(_buffer: InternalRow, input: InternalRow): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { throw new UnsupportedOperationException( "Hive UDAF doesn't support partial aggregate") } - override def initialize(_buffer: MutableRow): Unit = { + override def initialize(_buffer: InternalRow): Unit = { buffer = function.getNewAggregationBuffer } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 15b72d8d2179f..e94f49ea81177 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -281,7 +281,7 @@ private[orc] object OrcRelation extends HiveInspectors { maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificMutableRow(dataSchema.map(_.dataType)) + val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) val unsafeProjection = UnsafeProjection.create(dataSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { From 94b24b84a666517e31e9c9d693f92d9bbfd7f9ad Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Oct 2016 15:03:47 -0700 Subject: [PATCH 222/306] [SPARK-17806] [SQL] fix bug in join key rewritten in HashJoin ## What changes were proposed in this pull request? 
In HashJoin, we try to rewrite the join key as Long to improve the performance of finding a match. The rewriting part is not well tested and has a bug that could produce wrong results when the join key contains at least three integral columns and the total width of the key exceeds 8 bytes. ## How was this patch tested? Added unit tests covering the rewriting with different numbers of columns and different data types. Manually tested the reported case and confirmed that this PR fixes the bug. Author: Davies Liu Closes #15390 from davies/rewrite_key. --- .../spark/sql/execution/joins/HashJoin.scala | 65 +++++++---------- .../execution/joins/BroadcastJoinSuite.scala | 47 ++++++++++++++ 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8ddac19bf1b57..05c5e2f4cd77b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -63,45 +63,16 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = HashJoin.rewriteKeyExpr(rightKeys) + .map(BindReferences.bindReference(_, right.output)) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } - /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ - private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width -= bits - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } - } - keyExpr :: Nil - } + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(buildKeys) @@ -247,3 +218,31 @@ trait HashJoin { } } } + +object HashJoin { + /** + * Try to rewrite the key as LongType so we can use getLong(), if the key can fit in a long. + * + * If not, returns the original expressions.
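+ * For example, two int keys a and b are packed into a single long, + * (a.toLong << 32) | (b.toLong & 0xFFFFFFFFL), so the packed key can be read back with getLong().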
+ */ + private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + assert(keys.nonEmpty) + // TODO: support BooleanType, DateType and TimestampType + if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) + || keys.map(_.dataType.defaultSize).sum > 8) { + return keys + } + + var keyExpr: Expression = if (keys.head.dataType != LongType) { + Cast(keys.head, LongType) + } else { + keys.head + } + keys.tail.foreach { e => + val bits = e.dataType.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 97adffa8ce101..83db81ea3f1c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,11 +21,13 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{LongType, ShortType} /** * Test various broadcast join operators. @@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { cases.foreach(assertBroadcastJoin) } } + + test("join key rewritten") { + val l = Literal(1L) + val i = Literal(2) + val s = Literal.create(3, ShortType) + val ss = Literal("hello") + + assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === + BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), + BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === + s :: s :: s :: s 
:: s :: Nil) + + assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) + } } From 24850c9415bfe18dc1edf66e5a7b4c554fff4f23 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Oct 2016 17:59:24 -0700 Subject: [PATCH 223/306] [HOTFIX][BUILD] Do not use contains in Option in JdbcRelationProvider ## What changes were proposed in this pull request? This PR proposes the fix the use of `contains` API which only exists from Scala 2.11. ## How was this patch tested? Manually checked: ```scala scala> val o: Option[Boolean] = None o: Option[Boolean] = None scala> o == Some(false) res17: Boolean = false scala> val o: Option[Boolean] = Some(true) o: Option[Boolean] = Some(true) scala> o == Some(false) res18: Boolean = false scala> val o: Option[Boolean] = Some(false) o: Option[Boolean] = Some(false) scala> o == Some(false) res19: Boolean = true ``` Author: hyukjinkwon Closes #15393 from HyukjinKwon/hotfix. --- .../sql/execution/datasources/jdbc/JdbcRelationProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 3a8a197ef5241..b1a061b6f7422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -70,7 +70,7 @@ class JdbcRelationProvider extends CreatableRelationProvider if (tableExists) { mode match { case SaveMode.Overwrite => - if (isTruncate && isCascadingTruncateTable(url).contains(false)) { + if (isTruncate && isCascadingTruncateTable(url) == Some(false)) { // In this case, we should truncate table and then load. truncateTable(conn, table) saveTable(df, url, table, props) From 471690f90f3bf29735faecd83d4671842c57b164 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Fri, 7 Oct 2016 18:00:26 -0700 Subject: [PATCH 224/306] [MINOR][ML] remove redundant comment in LogisticRegression ## What changes were proposed in this pull request? While adding R wrapper for LogisticRegression, I found one extra comment. It is minor and I just remove it. ## How was this patch tested? Unit tests Author: wm624@hotmail.com Closes #15391 from wangmiao1981/mlordoc. --- .../org/apache/spark/ml/classification/LogisticRegression.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 329961a25d984..862a468745fbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -78,7 +78,6 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Param for the name of family which is a description of the label distribution * to be used in the model. - * Supported options: "auto", "multinomial", "binomial". * Supported options: * - "auto": Automatically select the family based on the number of classes: * If numClasses == 1 || numClasses == 2, set to "binomial". 
From 362ba4b6f8e8fc2355368742c5adced7573fec00 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Sat, 8 Oct 2016 11:24:00 +0100 Subject: [PATCH 225/306] =?UTF-8?q?[SPARK-17793][WEB=20UI]=20Sorting=20on?= =?UTF-8?q?=20the=20description=20on=20the=20Job=20or=20Stage=20page=20doe?= =?UTF-8?q?sn=E2=80=99t=20always=20work?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Added secondary sorting on stage name for the description column. This provide a clearer behavior in the common case where the Description column only comprises of Stage names instead of the option description value. ## How was this patch tested? manual testing and dev/run-tests Screenshots of sorting on both description and stage name as well as an example of both: ![screen shot 2016-10-04 at 1 09 39 pm](https://cloud.githubusercontent.com/assets/13952758/19135523/067b042e-8b1a-11e6-912e-e6371d006d21.png) ![screen shot 2016-10-04 at 1 09 51 pm](https://cloud.githubusercontent.com/assets/13952758/19135526/06960936-8b1a-11e6-85e9-8aaf694c5f7b.png) ![screen shot 2016-10-05 at 1 14 45 pm](https://cloud.githubusercontent.com/assets/13952758/19135525/069547da-8b1a-11e6-8692-6524c75c4c07.png) ![screen shot 2016-10-05 at 1 14 51 pm](https://cloud.githubusercontent.com/assets/13952758/19135524/0694b4d2-8b1a-11e6-92dc-c8aa514e4f62.png) ![screen shot 2016-10-05 at 4 42 52 pm](https://cloud.githubusercontent.com/assets/13952758/19135618/e232eafe-8b1a-11e6-88b3-ff0bbb26b7f8.png) Author: Alex Bozarth Closes #15366 from ajbozarth/spark17793. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 25 +--- .../org/apache/spark/ui/jobs/StagePage.scala | 134 ++++-------------- .../org/apache/spark/ui/jobs/StageTable.scala | 51 ++----- .../org/apache/spark/ui/storage/RDDPage.scala | 27 +--- 4 files changed, 49 insertions(+), 188 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 19bb41a1417c7..f6713097b9349 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -457,23 +457,11 @@ private[ui] class JobDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[JobTableRowData] = { - val ordering = sortColumn match { - case "Job Id" | "Job Id (Job Group)" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Int.compare(x.jobData.jobId, y.jobData.jobId) - } - case "Description" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.String.compare(x.lastStageDescription, y.lastStageDescription) - } - case "Submitted" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Long.compare(x.submissionTime, y.submissionTime) - } - case "Duration" => new Ordering[JobTableRowData] { - override def compare(x: JobTableRowData, y: JobTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } + val ordering: Ordering[JobTableRowData] = sortColumn match { + case "Job Id" | "Job Id (Job Group)" => Ordering.by(_.jobData.jobId) + case "Description" => Ordering.by(x => (x.lastStageDescription, x.lastStageName)) + case "Submitted" => Ordering.by(_.submissionTime) + case "Duration" => Ordering.by(_.duration) case 
"Stages: Succeeded/Total" | "Tasks (for all stages): Succeeded/Total" => throw new IllegalArgumentException(s"Unsortable column: $sortColumn") case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") @@ -501,8 +489,7 @@ private[ui] class JobPagedTable( sortColumn: String, desc: Boolean ) extends PagedTable[JobTableRowData] { - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + - parameterOtherTable.mkString("&") + val parameterPath = basePath + s"/$subPath/?" + parameterOtherTable.mkString("&") override def tableId: String = jobTag + "-table" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c322ae0972ad7..8c7cefe200739 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1050,89 +1050,38 @@ private[ui] class TaskDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering = sortColumn match { - case "Index" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Int.compare(x.index, y.index) - } - case "ID" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.taskId, y.taskId) - } - case "Attempt" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Int.compare(x.attempt, y.attempt) - } - case "Status" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.status, y.status) - } - case "Locality Level" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.taskLocality, y.taskLocality) - } - case "Executor ID / Host" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) - } - case "Launch Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.launchTime, y.launchTime) - } - case "Duration" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } - case "Scheduler Delay" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) - } - case "Task Deserialization Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) - } - case "GC Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.gcTime, y.gcTime) - } - case "Result Serialization Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.serializationTime, y.serializationTime) - } - case "Getting Result Time" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - 
Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) - } - case "Peak Execution Memory" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) - } + val ordering: Ordering[TaskTableRowData] = sortColumn match { + case "Index" => Ordering.by(_.index) + case "ID" => Ordering.by(_.taskId) + case "Attempt" => Ordering.by(_.attempt) + case "Status" => Ordering.by(_.status) + case "Locality Level" => Ordering.by(_.taskLocality) + case "Executor ID / Host" => Ordering.by(_.executorIdAndHost) + case "Launch Time" => Ordering.by(_.launchTime) + case "Duration" => Ordering.by(_.duration) + case "Scheduler Delay" => Ordering.by(_.schedulerDelay) + case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) + case "GC Time" => Ordering.by(_.gcTime) + case "Result Serialization Time" => Ordering.by(_.serializationTime) + case "Getting Result Time" => Ordering.by(_.gettingResultTime) + case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) case "Accumulators" => if (hasAccumulators) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.accumulators.get, y.accumulators.get) - } + Ordering.by(_.accumulators.get) } else { throw new IllegalArgumentException( "Cannot sort by Accumulators because of no accumulators") } case "Input Size / Records" => if (hasInput) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable) - } + Ordering.by(_.input.get.inputSortable) } else { throw new IllegalArgumentException( "Cannot sort by Input Size / Records because of no inputs") } case "Output Size / Records" => if (hasOutput) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) - } + Ordering.by(_.output.get.outputSortable) } else { throw new IllegalArgumentException( "Cannot sort by Output Size / Records because of no outputs") @@ -1140,33 +1089,21 @@ private[ui] class TaskDataSource( // ShuffleRead case "Shuffle Read Blocked Time" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, - y.shuffleRead.get.shuffleReadBlockedTimeSortable) - } + Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") } case "Shuffle Read Size / Records" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, - y.shuffleRead.get.shuffleReadSortable) - } + Ordering.by(_.shuffleRead.get.shuffleReadSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") } case "Shuffle Remote Reads" => if (hasShuffleRead) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, - y.shuffleRead.get.shuffleReadRemoteSortable) - } + 
Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Remote Reads because of no shuffle reads") @@ -1174,22 +1111,14 @@ private[ui] class TaskDataSource( // ShuffleWrite case "Write Time" => if (hasShuffleWrite) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, - y.shuffleWrite.get.writeTimeSortable) - } + Ordering.by(_.shuffleWrite.get.writeTimeSortable) } else { throw new IllegalArgumentException( "Cannot sort by Write Time because of no shuffle writes") } case "Shuffle Write Size / Records" => if (hasShuffleWrite) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, - y.shuffleWrite.get.shuffleWriteSortable) - } + Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") @@ -1197,30 +1126,19 @@ private[ui] class TaskDataSource( // BytesSpilled case "Shuffle Spill (Memory)" => if (hasBytesSpilled) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, - y.bytesSpilled.get.memoryBytesSpilledSortable) - } + Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Spill (Memory) because of no spills") } case "Shuffle Spill (Disk)" => if (hasBytesSpilled) { - new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, - y.bytesSpilled.get.diskBytesSpilledSortable) - } + Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) } else { throw new IllegalArgumentException( "Cannot sort by Shuffle Spill (Disk) because of no spills") } - case "Errors" => new Ordering[TaskTableRowData] { - override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = - Ordering.String.compare(x.error, y.error) - } + case "Errors" => Ordering.by(_.error) case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } if (desc) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 40a6762c281ce..9b9b4681ba5db 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -109,7 +109,6 @@ private[ui] class StageTableRowData( val stageId: Int, val attemptId: Int, val schedulingPool: String, - val description: String, val descriptionOption: Option[String], val submissionTime: Long, val formattedSubmissionTime: String, @@ -128,7 +127,7 @@ private[ui] class MissingStageTableRowData( stageInfo: StageInfo, stageId: Int, attemptId: Int) extends StageTableRowData( - stageInfo, None, stageId, attemptId, "", "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") + stageInfo, None, stageId, attemptId, "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( @@ -470,7 +469,6 @@ private[ui] class StageDataSource( s.stageId, s.attemptId, stageData.schedulingPool, - description.getOrElse(""), 
description, s.submissionTime.getOrElse(0), formattedSubmissionTime, @@ -491,43 +489,16 @@ private[ui] class StageDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[StageTableRowData] = { - val ordering = sortColumn match { - case "Stage Id" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Int.compare(x.stageId, y.stageId) - } - case "Pool Name" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.String.compare(x.schedulingPool, y.schedulingPool) - } - case "Description" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.String.compare(x.description, y.description) - } - case "Submitted" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.submissionTime, y.submissionTime) - } - case "Duration" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.duration, y.duration) - } - case "Input" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.inputRead, y.inputRead) - } - case "Output" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.outputWrite, y.outputWrite) - } - case "Shuffle Read" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.shuffleRead, y.shuffleRead) - } - case "Shuffle Write" => new Ordering[StageTableRowData] { - override def compare(x: StageTableRowData, y: StageTableRowData): Int = - Ordering.Long.compare(x.shuffleWrite, y.shuffleWrite) - } + val ordering: Ordering[StageTableRowData] = sortColumn match { + case "Stage Id" => Ordering.by(_.stageId) + case "Pool Name" => Ordering.by(_.schedulingPool) + case "Description" => Ordering.by(x => (x.descriptionOption, x.stageInfo.name)) + case "Submitted" => Ordering.by(_.submissionTime) + case "Duration" => Ordering.by(_.duration) + case "Input" => Ordering.by(_.inputRead) + case "Output" => Ordering.by(_.outputWrite) + case "Shuffle Read" => Ordering.by(_.shuffleRead) + case "Shuffle Write" => Ordering.by(_.shuffleWrite) case "Tasks: Succeeded/Total" => throw new IllegalArgumentException(s"Unsortable column: $sortColumn") case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 606d15d599e81..227e940c9c50c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -197,27 +197,12 @@ private[ui] class BlockDataSource( * Return Ordering according to sortColumn and desc */ private def ordering(sortColumn: String, desc: Boolean): Ordering[BlockTableRowData] = { - val ordering = sortColumn match { - case "Block Name" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.String.compare(x.blockName, y.blockName) - } - case "Storage Level" => new Ordering[BlockTableRowData] { - override def compare(x: 
BlockTableRowData, y: BlockTableRowData): Int = - Ordering.String.compare(x.storageLevel, y.storageLevel) - } - case "Size in Memory" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.Long.compare(x.memoryUsed, y.memoryUsed) - } - case "Size on Disk" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.Long.compare(x.diskUsed, y.diskUsed) - } - case "Executors" => new Ordering[BlockTableRowData] { - override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = - Ordering.String.compare(x.executors, y.executors) - } + val ordering: Ordering[BlockTableRowData] = sortColumn match { + case "Block Name" => Ordering.by(_.blockName) + case "Storage Level" => Ordering.by(_.storageLevel) + case "Size in Memory" => Ordering.by(_.memoryUsed) + case "Size on Disk" => Ordering.by(_.diskUsed) + case "Executors" => Ordering.by(_.executors) case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } if (desc) { From 4201ddcc07ca2e9af78bf4a74fdb3900c1783347 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 8 Oct 2016 11:31:12 +0100 Subject: [PATCH 226/306] [SPARK-17768][CORE] Small (Sum,Count,Mean)Evaluator problems and suboptimalities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fix: - GroupedMeanEvaluator and GroupedSumEvaluator are unused, as is the StudentTCacher support class - CountEvaluator can return a lower bound < 0, when counts can't be negative - MeanEvaluator will actually fail on exactly 1 datum (yields t-test with 0 DOF) - CountEvaluator uses a normal distribution, which may be an inappropriate approximation (leading to above) - Test for SumEvaluator asserts incorrect expected sums – e.g. after observing 10% of data has sum of 2, expectation should be 20, not 38 - CountEvaluator, MeanEvaluator have no unit tests to catch these - Duplication of distribution code across CountEvaluator, GroupedCountEvaluator - The stats in each could use a bit of documentation as I had to guess at them - (Code could use a few cleanups and optimizations too) ## How was this patch tested? Existing and new tests Author: Sean Owen Closes #15341 from srowen/SPARK-17768. 
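To make the reworked interval math concrete, here is a small standalone sketch. It is not part of the patch itself: it assumes commons-math3 on the classpath, and the object and method names are illustrative. It mirrors the bound computation the new `CountEvaluator` applies to a partially observed count:

```scala
import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution}

object ApproxCountSketch {
  // Given `sum` elements counted after merging a fraction `p` of the output partitions
  // (0 < p < 1, sum > 0), return (estimate, low, high) at the requested confidence level.
  def countBound(sum: Long, p: Double, confidence: Double): (Double, Double, Double) = {
    require(sum > 0 && p > 0 && p < 1 && confidence > 0 && confidence < 1)
    val dist =
      if (sum <= 10000) {
        // Unobserved remainder modeled as negative binomial (Pascal):
        // `sum` successes have already been observed, each with probability p.
        new PascalDistribution(sum.toInt, p)
      } else {
        // For large counts, approximate the unobserved remainder with a Poisson distribution.
        new PoissonDistribution(sum * (1 - p) / p)
      }
    val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
    val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
    // The distribution describes only the unobserved remainder, so shift by the observed sum.
    (sum + dist.getNumericalMean, (sum + low).toDouble, (sum + high).toDouble)
  }

  def main(args: Array[String]): Unit = {
    // e.g. 1000 rows counted after merging 1 of 10 partitions, at 95% confidence
    println(countBound(1000L, 0.1, 0.95))
  }
}
```

The new `SumEvaluator` follows the same idea, adding the observed sum back onto the interval computed for the unobserved remainder.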
--- .../apache/spark/partial/CountEvaluator.scala | 53 +++++++---- .../spark/partial/GroupedCountEvaluator.scala | 30 ++----- .../spark/partial/GroupedMeanEvaluator.scala | 80 ----------------- .../spark/partial/GroupedSumEvaluator.scala | 88 ------------------- .../apache/spark/partial/MeanEvaluator.scala | 23 +++-- .../apache/spark/partial/StudentTCacher.scala | 46 ---------- .../apache/spark/partial/SumEvaluator.scala | 33 ++++--- .../spark/partial/CountEvaluatorSuite.scala | 43 +++++++++ .../spark/partial/MeanEvaluatorSuite.scala | 57 ++++++++++++ .../spark/partial/SumEvaluatorSuite.scala | 82 ++++++----------- 10 files changed, 203 insertions(+), 332 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala delete mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala delete mode 100644 core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala create mode 100644 core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 637492a97551b..5a5bd7fbbe2f8 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,21 +17,18 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution} /** * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. */ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[Long, BoundedDouble] { - var outputsMerged = 0 - var sum: Long = 0 + private var outputsMerged = 0 + private var sum: Long = 0 - override def merge(outputId: Int, taskResult: Long) { + override def merge(outputId: Int, taskResult: Long): Unit = { outputsMerged += 1 sum += taskResult } @@ -39,18 +36,40 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (outputsMerged == 0 || sum == 0) { + new BoundedDouble(0, 0.0, 0.0, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = new NormalDistribution(). - inverseCumulativeProbability(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) + CountEvaluator.bound(confidence, sum, p) } } } + +private[partial] object CountEvaluator { + + def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = { + // Let the total count be N. A fraction p has been counted already, with sum 'sum', + // as if each element from the total data set had been seen with probability p. 
+ val dist = + if (sum <= 10000) { + // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal), + // where there have been 'sum' successes of probability p already. (There are several + // conventions, but this is the one followed by Commons Math3.) + new PascalDistribution(sum.toInt, p) + } else { + // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has + // a different interpretation. "sum" elements have been observed having scanned a fraction + // p of the data. This suggests data is counted at a rate of sum / p across the whole data + // set. The total expected count from the rest is distributed as + // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) + new PoissonDistribution(sum * (1 - p) / p) + } + // Not quite symmetric; calculate interval straight from discrete distribution + val low = dist.inverseCumulativeProbability((1 - confidence) / 2) + val high = dist.inverseCumulativeProbability((1 + confidence) / 2) + // Add 'sum' to each because distribution is just of remaining count, not observed + new BoundedDouble(sum + dist.getNumericalMean, confidence, sum + low, sum + high) + } + + +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 5afce75680f94..d2b4187df5d50 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -17,15 +17,10 @@ package org.apache.spark.partial -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import org.apache.commons.math3.distribution.NormalDistribution - import org.apache.spark.util.collection.OpenHashMap /** @@ -34,10 +29,10 @@ import org.apache.spark.util.collection.OpenHashMap private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { - var outputsMerged = 0 - var sums = new OpenHashMap[T, Long]() // Sum of counts for each key + private var outputsMerged = 0 + private val sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]): Unit = { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) @@ -46,27 +41,12 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf override def currentResult(): Map[T, BoundedDouble] = { if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala + sums.map { case (key, sum) => (key, new BoundedDouble(sum, 1.0, sum, sum)) }.toMap } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { val p = outputsMerged.toDouble / totalOutputs - val confFactor = new NormalDistribution(). 
- inverseCumulativeProbability(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(key, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala + sums.map { case (key, sum) => (key, CountEvaluator.bound(confidence, sum, p)) }.toMap } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index a164040684803..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. 
- */ -private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index 54a1beab3514b..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. 
- */ -private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala index 787a21a61fdcf..3fb2d30a800b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -27,10 +27,10 @@ import org.apache.spark.util.StatCounter private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { - var outputsMerged = 0 - var counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -38,19 +38,24 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { + } else if (outputsMerged == 0 || counter.count == 0) { new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (counter.count == 1) { + new BoundedDouble(counter.mean, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { val mean = counter.mean val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - new 
NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + val confFactor = if (counter.count > 100) { + // For large n, the normal distribution is a good approximation to t-distribution + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { + // t-distribution describes distribution of actual population mean + // note that if this goes to 0, TDistribution will throw an exception. + // Hence special casing 1 above. val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - } + // Symmetric, so confidence interval is symmetric about mean of distribution val low = mean - confFactor * stdev val high = mean + confFactor * stdev new BoundedDouble(mean, confidence, low, high) diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala deleted file mode 100644 index 55acb9ca64d3f..0000000000000 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. 
- */ -private[spark] class StudentTCacher(confidence: Double) { - - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - - val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - val tDist = new TDistribution(size - 1) - cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 5fe33583166c3..1988052b733e6 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -30,10 +30,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { // modified in merge - var outputsMerged = 0 - val counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -45,34 +45,45 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs + // Expected value of unobserved is presumed equal to that of the observed data val meanEstimate = counter.mean - val countEstimate = (counter.count + 1 - p) / p + // Expected size of rest of the data is proportional + val countEstimate = counter.count * (1 - p) / p + // Expected sum is simply their product val sumEstimate = meanEstimate * countEstimate + // Variance of unobserved data is presumed equal to that of the observed data val meanVar = counter.sampleVariance / counter.count - // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan + // branch at this point because count == 1 implies counter.sampleVariance == Nan // and we don't want to ever return a bound of NaN if (meanVar.isNaN || counter.count == 1) { - new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { - val countVar = (counter.count + 1) * (1 - p) / (p * p) + // See CountEvaluator. Variance of population count here follows from negative binomial + val countVar = counter.count * (1 - p) / (p * p) + // Var(Sum) = Var(Mean*Count) = + // [E(Mean)]^2 * Var(Count) + [E(Count)]^2 * Var(Mean) + Var(Mean) * Var(Count) val sumVar = (meanEstimate * meanEstimate * countVar) + (countEstimate * countEstimate * meanVar) + (meanVar * countVar) val sumStdev = math.sqrt(sumVar) val confFactor = if (counter.count > 100) { - new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { // note that if this goes to 0, TDistribution will throw an exception. // Hence special casing 1 above. 
val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - + // Symmetric, so confidence interval is symmetric about mean of distribution val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, counter.sum + low, counter.sum + high) } } } diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala new file mode 100644 index 0000000000000..da3256bd882e8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite + +class CountEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new CountEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + evaluator.merge(1, 0) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + } + + test("test count >= 1") { + val evaluator = new CountEvaluator(10, 0.95) + evaluator.merge(1, 1) + assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult()) + evaluator.merge(1, 3) + assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult()) + evaluator.merge(1, 8) + assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, 10)) + assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala new file mode 100644 index 0000000000000..eaa1262b4199f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.StatCounter + +class MeanEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new MeanEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(0.0))) + assert(new BoundedDouble(0.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + assert(new BoundedDouble(1.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count > 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + evaluator.merge(1, new StatCounter(Seq(3.0))) + assert(new BoundedDouble(2.0, 0.95, -10.706204736174746, 14.706204736174746) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(8.0))) + assert(new BoundedDouble(4.0, 0.95, -4.9566858949231225, 12.956685894923123) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter(Seq(9.0)))) + assert(new BoundedDouble(7.5, 1.0, 7.5, 7.5) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala index a79f5b4d74467..e212db73627e7 100644 --- a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala @@ -17,61 +17,34 @@ package org.apache.spark.partial -import org.apache.spark._ +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter -class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { +class SumEvaluatorSuite extends SparkFunSuite { test("correct handling of count 1") { + // sanity check: + assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - // setup - val counter = new StatCounter(List(2.0)) // count of 10 because it's larger than 1, // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // 38.0 - 7.1E-15 because that's how the maths shakes out - val targetMean = 38.0 - 7.1E-15 - - // Sanity check that equality works on BoundedDouble - assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - - // actual test - assert(res == - new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter(Seq(2.0))) + assert(new BoundedDouble(20.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + 
evaluator.currentResult()) } test("correct handling of count 0") { - - // setup - val counter = new StatCounter(List()) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // assert - assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) } test("correct handling of NaN") { - - // setup - val counter = new StatCounter(List(1, Double.NaN, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1, Double.NaN, 2))) val res = evaluator.currentResult() // assert - note semantics of == in face of NaN assert(res.mean.isNaN) @@ -81,27 +54,24 @@ class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { } test("correct handling of > 1 values") { - - // setup - val counter = new StatCounter(List(1, 3, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1.0, 3.0, 2.0))) val res = evaluator.currentResult() + assert(new BoundedDouble(60.0, 0.95, -101.7362525347778, 221.7362525347778) == + evaluator.currentResult()) + } - // These vals because that's how the maths shakes out - val targetMean = 78.0 - val targetLow = -117.617 + 2.732357258139473E-5 - val targetHigh = 273.617 - 2.7323572624027292E-5 - val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh) - - - // check that values are within expected tolerance of expectation - assert(res == target) + test("test count > 1") { + val evaluator = new SumEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter().merge(1.0)) + evaluator.merge(1, new StatCounter().merge(3.0)) + assert(new BoundedDouble(20.0, 0.95, -186.4513905077019, 226.4513905077019) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter().merge(8.0)) + assert(new BoundedDouble(40.0, 0.95, -72.75723361226733, 152.75723361226733) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter().merge(9.0))) + assert(new BoundedDouble(75.0, 1.0, 75.0, 75.0) == evaluator.currentResult()) } } From 8a6bbe095b6a9aa33989c0deaa5ed0128d70320f Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sat, 8 Oct 2016 12:12:35 +0100 Subject: [PATCH 227/306] [MINOR][SQL] Use resource path for test_script.sh ## What changes were proposed in this pull request? This PR modified the test case `test("script")` to use resource path for `test_script.sh`. Make the test case portable (even in IntelliJ). ## How was this patch tested? Passed the test case. Before: Run `test("script")` in IntelliJ: ``` Caused by: org.apache.spark.SparkException: Subprocess exited with status 127. Error: bash: src/test/resources/test_script.sh: No such file or directory ``` After: Test passed. Author: Weiqing Yang Closes #15246 from weiqingy/hivetest. 
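To make the approach concrete, here is a minimal, self-contained sketch of the class-loader lookup that this patch adds to `SparkFunSuite` as `getTestResourceFile`/`getTestResourcePath`; the standalone `TestResources` object, the `require` check and the sample command string are illustrative additions, not part of the change itself:

```
import java.io.File

object TestResources {
  // Resolve a file under src/test/resources through the class loader, so the
  // test works whatever the working directory is (sbt, Maven, or IntelliJ).
  def getTestResourcePath(resource: String): String = {
    val url = getClass.getClassLoader.getResource(resource)
    require(url != null, s"test resource not found: $resource")
    new File(url.getFile).getCanonicalPath
  }

  def main(args: Array[String]): Unit = {
    // Build the script command from the resolved absolute path instead of
    // hard-coding the relative "src/test/resources/test_script.sh".
    val scriptPath = getTestResourcePath("test_script.sh")
    println(s"REDUCE c1, c2, c3 USING 'bash $scriptPath' AS (col1 STRING, col2 STRING)")
  }
}
```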
--- .../scala/org/apache/spark/SparkFunSuite.scala | 11 +++++++++++ .../spark/deploy/history/HistoryServerSuite.scala | 6 +++--- .../test/scala/org/apache/spark/ui/UISuite.scala | 3 ++- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 3 ++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 +++++++++------ 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index cd876807f890e..18077c08c9dcc 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark // scalastyle:off +import java.io.File + import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} import org.apache.spark.internal.Logging @@ -41,6 +43,15 @@ abstract class SparkFunSuite } } + // helper function + protected final def getTestResourceFile(file: String): File = { + new File(getClass.getClassLoader.getResource(file).getFile) + } + + protected final def getTestResourcePath(file: String): String = { + getTestResourceFile(file).getCanonicalPath + } + /** * Log the suite name and the test name before and after each test. * diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 5b316b2f6b4b7..a595bc174a310 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -59,8 +59,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with JsonTestUtils with Eventually with WebBrowser with LocalSparkContext with ResetSystemProperties { - private val logDir = new File("src/test/resources/spark-events") - private val expRoot = new File("src/test/resources/HistoryServerExpectations/") + private val logDir = getTestResourcePath("spark-events") + private val expRoot = getTestResourceFile("HistoryServerExpectations") private var provider: FsHistoryProvider = null private var server: HistoryServer = null @@ -68,7 +68,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def init(): Unit = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") provider = new FsHistoryProvider(conf) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index dbb8dca4c8dab..4abcfb7e51914 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -53,9 +53,10 @@ class UISuite extends SparkFunSuite { } private def sslEnabledConf(): (SparkConf, SSLOptions) = { + val keyStoreFilePath = getTestResourcePath("spark.keystore") val conf = new SparkConf() .set("spark.ssl.ui.enabled", "true") - .set("spark.ssl.ui.keyStore", "./src/test/resources/spark.keystore") + .set("spark.ssl.ui.keyStore", keyStoreFilePath) .set("spark.ssl.ui.keyStorePassword", "123456") .set("spark.ssl.ui.keyPassword", "123456") (conf, new SecurityManager(conf).getSSLOptions("ui")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 9ac1e86fc82cb..c7f10e569fa4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -45,7 +45,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // Used for generating new query answer files by saving private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" - private val goldenSQLPath = "src/test/resources/sqlgen/" + private val goldenSQLPath = getTestResourcePath("sqlgen") protected override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 29317e2887861..d3873cf6c8231 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -152,7 +152,8 @@ class HiveSparkSubmitSuite case v if v.startsWith("2.10") || v.startsWith("2.11") => v.substring(0, 4) case x => throw new Exception(s"Unsupported Scala Version: $x") } - val testJar = s"sql/hive/src/test/resources/regression-test-SPARK-8489/test-$version.jar" + val jarDir = getTestResourcePath("regression-test-SPARK-8489") + val testJar = s"$jarDir/test-$version.jar" val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6c77a0deb52a4..6f2a16662bf10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -66,13 +66,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ test("script") { + val scriptFilePath = getTestResourcePath("test_script.sh") if (testCommandAvailable("bash") && testCommandAvailable("echo | sed")) { val df = Seq(("x1", "y1", "z1"), ("x2", "y2", "z2")).toDF("c1", "c2", "c3") df.createOrReplaceTempView("script_table") val query1 = sql( - """ + s""" |SELECT col1 FROM (from(SELECT c1, c2, c3 FROM script_table) tempt_table - |REDUCE c1, c2, c3 USING 'bash src/test/resources/test_script.sh' AS + |REDUCE c1, c2, c3 USING 'bash $scriptFilePath' AS |(col1 STRING, col2 STRING)) script_test_table""".stripMargin) checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil) } @@ -1290,11 +1291,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .selectExpr("id AS a", "id AS b") .createOrReplaceTempView("test") + val scriptFilePath = getTestResourcePath("data") checkAnswer( sql( - """FROM( + s"""FROM( | FROM test SELECT TRANSFORM(a, b) - | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | USING 'python $scriptFilePath/scripts/test_transform.py "\t"' | AS (c STRING, d STRING) |) t |SELECT c @@ -1308,12 +1310,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .selectExpr("id AS a", "id AS b") .createOrReplaceTempView("test") + val scriptFilePath = getTestResourcePath("data") val df = sql( - """FROM test + s"""FROM test |SELECT TRANSFORM(a, b) |ROW FORMAT SERDE 
'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES('field.delim' = '|') - |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |USING 'python $scriptFilePath/scripts/test_transform.py "|"' |AS (c STRING, d STRING) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES('field.delim' = '|') From 26fbca480604ba258f97b9590cfd6dda1ecd31db Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sun, 9 Oct 2016 21:52:46 -0700 Subject: [PATCH 228/306] [SPARK-17832][SQL] TableIdentifier.quotedString creates un-parseable names when name contains a backtick ## What changes were proposed in this pull request? The `quotedString` method in `TableIdentifier` and `FunctionIdentifier` produce an illegal (un-parseable) name when the name contains a backtick. For example: ``` import org.apache.spark.sql.catalyst.parser.CatalystSqlParser._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1")) parseTableIdentifier(complexName.unquotedString) // Does not work parseTableIdentifier(complexName.quotedString) // Does not work parseExpression(complexName.unquotedString) // Does not work parseExpression(complexName.quotedString) // Does not work ``` We should handle the backtick properly to make `quotedString` parseable. ## How was this patch tested? Add new testcases in `TableIdentifierParserSuite` and `ExpressionParserSuite`. Author: jiangxingbo Closes #15403 from jiangxb1987/backtick. --- .../org/apache/spark/sql/catalyst/identifiers.scala | 11 +++++++++-- .../sql/catalyst/parser/ExpressionParserSuite.scala | 11 ++++++++++- .../catalyst/parser/TableIdentifierParserSuite.scala | 10 ++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index d7b48ceca591a..834897b85023d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst - /** * An identifier that optionally specifies a database. * @@ -29,8 +28,16 @@ sealed trait IdentifierWithDatabase { def database: Option[String] + /* + * Escapes back-ticks within the identifier name with double-back-ticks. 
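+   * For example, an identifier named weird`table is escaped to weird``table, so
+   * quotedString renders it as `weird``table`, which the parser can read back unchanged.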
+ */ + private def quoteIdentifier(name: String): String = name.replace("`", "``") + def quotedString: String = { - if (database.isDefined) s"`${database.get}`.`$identifier`" else s"`$identifier`" + val replacedId = quoteIdentifier(identifier) + val replacedDb = database.map(quoteIdentifier(_)) + + if (replacedDb.isDefined) s"`${replacedDb.get}`.`$replacedId`" else s"`$replacedId`" } def unquotedString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 0fb1138478a9b..17cfc8158803b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -535,4 +535,13 @@ class ExpressionParserSuite extends PlanTest { // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL assertEqual("a.123BD_column", UnresolvedAttribute("a.123BD_column")) } + + test("SPARK-17832 function identifier contains backtick") { + val complexName = FunctionIdentifier("`ba`r", Some("`fo`o")) + assertEqual(complexName.quotedString, UnresolvedAttribute("`fo`o.`ba`r")) + intercept(complexName.unquotedString, "mismatched input") + // Function identifier contains countious backticks should be treated correctly. + val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) + assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 793be8953d07a..7d46011b410e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -104,4 +104,14 @@ class TableIdentifierParserSuite extends SparkFunSuite { // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL assert(parseTableIdentifier("a.123BD_LIST") == TableIdentifier("123BD_LIST", Some("a"))) } + + test("SPARK-17832 table identifier - contains backtick") { + val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1")) + assert(complexName === parseTableIdentifier("```d``b``1`.```weird``table``name`")) + assert(complexName === parseTableIdentifier(complexName.quotedString)) + intercept[ParseException](parseTableIdentifier(complexName.unquotedString)) + // Table identifier contains countious backticks should be treated correctly. + val complexName2 = TableIdentifier("x``y", Some("d``b")) + assert(complexName2 === parseTableIdentifier(complexName2.quotedString)) + } } From 16590030c15b32e83b584283697b6f783cffe043 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sun, 9 Oct 2016 22:00:54 -0700 Subject: [PATCH 229/306] [SPARK-17741][SQL] Grammar to parse top level and nested data fields separately ## What changes were proposed in this pull request? 
Currently we use the same rule to parse top level and nested data fields. For example: ``` create table tbl_x( id bigint, nested struct ) ``` Shows both syntaxes. In this PR we split this rule in a top-level and nested rule. Before this PR, ``` sql("CREATE TABLE my_tab(column1: INT)") ``` works fine. After this PR, it will throw a `ParseException`: ``` scala> sql("CREATE TABLE my_tab(column1: INT)") org.apache.spark.sql.catalyst.parser.ParseException: no viable alternative at input 'CREATE TABLE my_tab(column1:'(line 1, pos 27) ``` ## How was this patch tested? Add new testcases in `SparkSqlParserSuite`. Author: jiangxingbo Closes #15346 from jiangxb1987/cdt. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 12 +- .../sql/catalyst/parser/AstBuilder.scala | 32 +++- .../catalyst/parser/DataTypeParserSuite.scala | 14 +- .../spark/sql/execution/SparkSqlParser.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 152 +++++++++++++++++- .../sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 2 +- 7 files changed, 195 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6a94def65f360..a3bbaceca371b 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -584,7 +584,7 @@ intervalValue dataType : complex=ARRAY '<' dataType '>' #complexDataType | complex=MAP '<' dataType ',' dataType '>' #complexDataType - | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType ; @@ -593,7 +593,15 @@ colTypeList ; colType - : identifier ':'? dataType (COMMENT STRING)? + : identifier dataType (COMMENT STRING)? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':' dataType (COMMENT STRING)? ; whenClause diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bf3f30279a6fe..929c1c4f2d9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -316,7 +316,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the attributes. val (attributes, schemaLess) = if (colTypeList != null) { // Typed return columns. - (createStructType(colTypeList).toAttributes, false) + (createSchema(colTypeList).toAttributes, false) } else if (identifierSeq != null) { // Untyped return columns. val attrs = visitIdentifierSeq(identifierSeq).map { name => @@ -1450,14 +1450,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case SqlBaseParser.MAP => MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) case SqlBaseParser.STRUCT => - createStructType(ctx.colTypeList()) + createStructType(ctx.complexColTypeList()) } } /** - * Create a [[StructType]] from a sequence of [[StructField]]s. + * Create top level table schema. 
*/ - protected def createStructType(ctx: ColTypeListContext): StructType = { + protected def createSchema(ctx: ColTypeListContext): StructType = { StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) } @@ -1476,4 +1476,28 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) if (STRING == null) structField else structField.withComment(string(STRING)) } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) + if (STRING == null) structField else structField.withComment(string(STRING)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 020fb16f6f3d5..3964fa3924b24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -116,6 +116,7 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("it is not a data type") unsupported("struct") unsupported("struct") // DataType parser accepts certain reserved keywords. checkDataType( @@ -125,16 +126,11 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("DATE", BooleanType, true) :: Nil) ) - // Define struct columns without ':' - checkDataType( - "struct", - (new StructType).add("x", IntegerType).add("y", StringType)) - - checkDataType( - "struct<`x``y` int>", - (new StructType).add("x`y", IntegerType)) - // Use SQL keywords. checkDataType("struct", (new StructType).add("end", LongType).add("select", IntegerType).add("from", StringType)) + + // DataType parser accepts comments. 
+ checkDataType("Struct", + (new StructType).add("x", IntegerType).add("y", StringType, true, "test")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 085bb9fc3c6cc..5f87b71210d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -340,7 +340,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (provider.toLowerCase == "hive") { throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING") } - val schema = Option(ctx.colTypeList()).map(createStructType) + val schema = Option(ctx.colTypeList()).map(createSchema) val partitionColumnNames = Option(ctx.partitionColumnNames) .map(visitIdentifierList(_).toArray) @@ -399,7 +399,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { CreateTempViewUsing( tableIdent = visitTableIdentifier(ctx.tableIdentifier()), - userSpecifiedSchema = Option(ctx.colTypeList()).map(createStructType), + userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema), replace = ctx.REPLACE != null, provider = ctx.tableProvider.qualifiedName.getText, options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 6712d32924890..e0976ae95001e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, ShowFunctionsCommand} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} /** * Parser test cases for rules defined in [[SparkSqlParser]]. @@ -35,8 +39,23 @@ class SparkSqlParserSuite extends PlanTest { private lazy val parser = new SparkSqlParser(new SQLConf) + /** + * Normalizes plans: + * - CreateTable the createTime in tableDesc will replaced by -1L. 
+ */ + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case CreateTable(tableDesc, mode, query) => + val newTableDesc = tableDesc.copy(createTime = -1L) + CreateTable(newTableDesc, mode, query) + case _ => plan // Don't transform + } + } + private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parser.parsePlan(sqlCommand), plan) + val normalized1 = normalizePlan(parser.parsePlan(sqlCommand)) + val normalized2 = normalizePlan(plan) + comparePlans(normalized1, normalized2) } private def intercept(sqlCommand: String, messages: String*): Unit = { @@ -68,9 +87,134 @@ class SparkSqlParserSuite extends PlanTest { DescribeFunctionCommand(FunctionIdentifier("bar", database = None), isExtended = true)) assertEqual("describe function foo.bar", DescribeFunctionCommand( - FunctionIdentifier("bar", database = Option("foo")), isExtended = false)) + FunctionIdentifier("bar", database = Some("foo")), isExtended = false)) assertEqual("describe function extended f.bar", - DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true)) + DescribeFunctionCommand(FunctionIdentifier("bar", database = Some("f")), isExtended = true)) + } + + private def createTableUsing( + table: String, + database: Option[String] = None, + tableType: CatalogTableType = CatalogTableType.MANAGED, + storage: CatalogStorageFormat = CatalogStorageFormat.empty, + schema: StructType = new StructType, + provider: Option[String] = Some("parquet"), + partitionColumnNames: Seq[String] = Seq.empty, + bucketSpec: Option[BucketSpec] = None, + mode: SaveMode = SaveMode.ErrorIfExists, + query: Option[LogicalPlan] = None): CreateTable = { + CreateTable( + CatalogTable( + identifier = TableIdentifier(table, database), + tableType = tableType, + storage = storage, + schema = schema, + provider = provider, + partitionColumnNames = partitionColumnNames, + bucketSpec = bucketSpec + ), mode, query + ) + } + + private def createTempViewUsing( + table: String, + database: Option[String] = None, + schema: Option[StructType] = None, + replace: Boolean = true, + provider: String = "parquet", + options: Map[String, String] = Map.empty): LogicalPlan = { + CreateTempViewUsing(TableIdentifier(table, database), schema, replace, provider, options) + } + + private def createTable( + table: String, + database: Option[String] = None, + tableType: CatalogTableType = CatalogTableType.MANAGED, + storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy( + inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat, + outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat), + schema: StructType = new StructType, + provider: Option[String] = Some("hive"), + partitionColumnNames: Seq[String] = Seq.empty, + comment: Option[String] = None, + mode: SaveMode = SaveMode.ErrorIfExists, + query: Option[LogicalPlan] = None): CreateTable = { + CreateTable( + CatalogTable( + identifier = TableIdentifier(table, database), + tableType = tableType, + storage = storage, + schema = schema, + provider = provider, + partitionColumnNames = partitionColumnNames, + comment = comment + ), mode, query + ) + } + + test("create table - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING)", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (c INT, d STRING COMMENT 
'test2')", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("c", IntegerType) + .add("d", StringType, nullable = true, "test2"), + partitionColumnNames = Seq("c", "d") + ) + ) + assertEqual("CREATE TABLE my_tab(id BIGINT, nested STRUCT)", + createTable( + table = "my_tab", + schema = (new StructType) + .add("id", LongType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ) + ) + ) + // Partitioned by a StructType should be accepted by `SparkSqlParser` but will fail an analyze + // rule in `AnalyzeCreateTable`. + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (nested STRUCT)", + createTable( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ), + partitionColumnNames = Seq("nested") + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", + "no viable alternative at input") + } + + test("create table using - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", + createTableUsing( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") } test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b5499f2884c61..1bcb810a1564f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -642,7 +642,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val csvFile = Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " + + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '$csvFile')") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 54e27b6f73502..9ce3338647398 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -243,7 +243,7 @@ class HiveDDLCommandSuite extends PlanTest { .asInstanceOf[ScriptTransformation].copy(ioschema = null) val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") .asInstanceOf[ScriptTransformation].copy(ioschema = null) val p = ScriptTransformation( From 23ddff4b2b2744c3dc84d928e144c541ad5df376 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 10 Oct 2016 15:48:57 +0800 Subject: [PATCH 230/306] [SPARK-17338][SQL] add 
global temp view

## What changes were proposed in this pull request?

Global temporary view is a cross-session temporary view, which means it's shared among all sessions. Its lifetime is the lifetime of the Spark application, i.e. it will be automatically dropped when the application terminates. It's tied to a system preserved database `global_temp` (configurable via SparkConf), and we must use the qualified name to refer to a global temp view, e.g. SELECT * FROM global_temp.view1.

Changes for `SessionCatalog`:

1. add a new field `globalTempViewManager: GlobalTempViewManager`, to access the shared global temp views and the global temp db name.
2. `createDatabase` will fail if users want to create `global_temp`, which is system preserved.
3. `setCurrentDatabase` will fail if users want to set `global_temp`, which is system preserved.
4. add `createGlobalTempView`, which is used in `CreateViewCommand` to create global temp views.
5. add `dropGlobalTempView`, which is used in `CatalogImpl` to drop global temp views.
6. add `alterTempViewDefinition`, which is used in `AlterViewAsCommand` to update the view definition for local/global temp views.
7. `renameTable`/`dropTable`/`isTemporaryTable`/`lookupRelation`/`getTempViewOrPermanentTableMetadata`/`refreshTable` will handle global temp views.

Changes for SQL commands:

1. `CreateViewCommand`/`AlterViewAsCommand` are updated to support global temp views.
2. `ShowTablesCommand` outputs a new column `database`, which is used to distinguish global and local temp views.
3. other commands can also handle global temp views if they call `SessionCatalog` APIs which accept global temp views, e.g. `DropTableCommand`, `AlterTableRenameCommand`, `ShowColumnsCommand`, etc.

Changes for other public APIs:

1. add a new method `dropGlobalTempView` in `Catalog`.
2. `Catalog.findTable` can find global temp views.
3. add a new method `createGlobalTempView` in `Dataset`.

## How was this patch tested?

new tests in `SQLViewSuite`

Author: Wenchen Fan

Closes #14897 from cloud-fan/global-temp-view.
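As a usage illustration only (not part of the patch itself), a short sketch of how the new API fits together, using the `createGlobalTempView`, `global_temp` and `dropGlobalTempView` names introduced above; the sample data and application name are made up:

```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("GlobalTempViewDemo").master("local[*]").getOrCreate()
import spark.implicits._

val people = Seq(("Michael", 29), ("Andy", 30)).toDF("name", "age")

// Register a global temporary view; it lives in the system preserved database
// `global_temp` and is shared by all sessions of this application.
people.createGlobalTempView("people")

// It must be referenced with the qualified name.
spark.sql("SELECT name, age FROM global_temp.people").show()

// A new session of the same application still sees it.
spark.newSession().sql("SELECT count(*) FROM global_temp.people").show()

// Drop it explicitly, or let it go away when the application terminates.
spark.catalog.dropGlobalTempView("people")

spark.stop()
```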
--- .../spark/internal/config/package.scala | 7 + docs/sql-programming-guide.md | 45 ++++- .../examples/sql/JavaSparkSQLExample.java | 30 ++- examples/src/main/python/sql/basic.py | 25 +++ .../spark/examples/sql/SparkSQLExample.scala | 25 +++ project/MimaExcludes.scala | 4 +- python/pyspark/sql/catalog.py | 18 +- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 25 ++- .../spark/sql/catalyst/parser/SqlBase.g4 | 8 +- .../sql/catalyst/analysis/Analyzer.scala | 10 +- .../catalog/GlobalTempViewManager.scala | 121 +++++++++++ .../sql/catalyst/catalog/SessionCatalog.scala | 189 ++++++++++++++---- .../scala/org/apache/spark/sql/Dataset.scala | 48 ++++- .../apache/spark/sql/catalog/Catalog.scala | 20 +- .../spark/sql/execution/QueryExecution.scala | 8 +- .../spark/sql/execution/SparkSqlParser.scala | 19 +- .../spark/sql/execution/command/ddl.scala | 25 ++- .../spark/sql/execution/command/tables.scala | 11 +- .../spark/sql/execution/command/views.scala | 150 +++++++------- .../spark/sql/execution/datasources/ddl.scala | 20 +- .../spark/sql/internal/CatalogImpl.scala | 26 ++- .../spark/sql/internal/SessionState.scala | 1 + .../spark/sql/internal/SharedState.scala | 75 ++++--- .../apache/spark/sql/SQLContextSuite.scala | 11 +- .../sql/execution/GlobalTempViewSuite.scala | 168 ++++++++++++++++ .../sql/execution/command/DDLSuite.scala | 10 +- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../spark/sql/hive/HiveSessionState.scala | 1 + .../hive/HiveContextCompatibilitySuite.scala | 4 +- .../spark/sql/hive/ListTablesSuite.scala | 8 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/HiveCommandSuite.scala | 10 +- .../sql/hive/execution/SQLViewSuite.scala | 6 +- 34 files changed, 906 insertions(+), 230 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d536cc5097b2d..0896e68eca7dc 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -98,6 +98,13 @@ package object config { .checkValues(Set("hive", "in-memory")) .createWithDefault("in-memory") + // Note: This is a SQL config but needs to be in core because it's cross-session and can not put + // in SQLConf. + private[spark] val GLOBAL_TEMP_DATABASE = ConfigBuilder("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") .intConf diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 71bdd19c16dbb..835cb6981f5bd 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -220,6 +220,41 @@ The `sql` function enables applications to run SQL queries programmatically and
      +## Global Temporary View + +Temporay views in Spark SQL are session-scoped and will disappear if the session that creates it +terminates. If you want to have a temporary view that is shared among all sessions and keep alive +until the Spark application terminiates, you can create a global temporary view. Global temporary +view is tied to a system preserved database `global_temp`, and we must use the qualified name to +refer it, e.g. `SELECT * FROM global_temp.view1`. + +
      +
      +{% include_example global_temp_view scala/org/apache/spark/examples/sql/SparkSQLExample.scala %} +
      + +
      +{% include_example global_temp_view java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %} +
      + +
      +{% include_example global_temp_view python/sql/basic.py %} +
      + +
      + +{% highlight sql %} + +CREATE GLOBAL TEMPORARY VIEW temp_view AS SELECT a + 1, b * 2 FROM tbl + +SELECT * FROM global_temp.temp_view + +{% endhighlight %} + +
      +
      + + ## Creating Datasets Datasets are similar to RDDs, however, instead of using Java serialization or Kryo they use @@ -1058,14 +1093,14 @@ the Data Sources API. The following options are supported: The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). - + truncate - This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g. indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. + This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g. indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. - + createTableOptions @@ -1101,11 +1136,11 @@ USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:postgresql:dbserver", dbtable "schema.tablename", - user 'username', + user 'username', password 'password' ) -INSERT INTO TABLE jdbcTable +INSERT INTO TABLE jdbcTable SELECT * FROM resultTable {% endhighlight %} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java index cff9032f52b5a..c5770d147a6b5 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; // $example off:programmatic_schema$ +import org.apache.spark.sql.AnalysisException; // $example on:untyped_ops$ // col("...") is preferable to df.col("...") @@ -84,7 +85,7 @@ public void setAge(int age) { } // $example off:create_ds$ - public static void main(String[] args) { + public static void main(String[] args) throws AnalysisException { // $example on:init_session$ SparkSession spark = SparkSession .builder() @@ -101,7 +102,7 @@ public static void main(String[] args) { spark.stop(); } - private static void runBasicDataFrameExample(SparkSession spark) { + private static void runBasicDataFrameExample(SparkSession spark) throws AnalysisException { // $example on:create_df$ Dataset df = spark.read().json("examples/src/main/resources/people.json"); @@ -176,6 +177,31 @@ private static void runBasicDataFrameExample(SparkSession spark) { // | 19| Justin| // +----+-------+ // $example off:run_sql$ + + // $example on:global_temp_view$ + // Register the DataFrame as a global temporary view + df.createGlobalTempView("people"); + + // Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + + // Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:global_temp_view$ } 
private static void runDatasetCreationExample(SparkSession spark) { diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index fdc017aed97c1..ebcf66995b477 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -114,6 +114,31 @@ def basic_df_example(spark): # +----+-------+ # $example off:run_sql$ + # $example on:global_temp_view$ + # Register the DataFrame as a global temporary view + df.createGlobalTempView("people") + + # Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + + # Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + # $example off:global_temp_view$ + def schema_inference_example(spark): # $example on:schema_inferring$ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 129b81d5fbbf3..f27c403c5b388 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -135,6 +135,31 @@ object SparkSQLExample { // | 19| Justin| // +----+-------+ // $example off:run_sql$ + + // $example on:global_temp_view$ + // Register the DataFrame as a global temporary view + df.createGlobalTempView("people") + + // Global temporary view is tied to a system preserved database `global_temp` + spark.sql("SELECT * FROM global_temp.people").show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + + // Global temporary view is cross-session + spark.newSession().sql("SELECT * FROM global_temp.people").show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:global_temp_view$ } private def runDatasetCreationExample(spark: SparkSession): Unit = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 163e3f2fdea40..e3d9a17469a35 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,9 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + // [SPARK-17338][SQL] add global temp view + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView") ) } diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 3c5030722f307..df3bf4254d4d3 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -167,7 +167,7 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, 
** @since(2.0) def dropTempView(self, viewName): - """Drops the temporary view with the given view name in the catalog. + """Drops the local temporary view with the given view name in the catalog. If the view has been cached before, then it will also be uncached. >>> spark.createDataFrame([(1, 1)]).createTempView("my_table") @@ -181,6 +181,22 @@ def dropTempView(self, viewName): """ self._jcatalog.dropTempView(viewName) + @since(2.1) + def dropGlobalTempView(self, viewName): + """Drops the global temporary view with the given view name in the catalog. + If the view has been cached before, then it will also be uncached. + + >>> spark.createDataFrame([(1, 1)]).createGlobalTempView("my_table") + >>> spark.table("global_temp.my_table").collect() + [Row(_1=1, _2=1)] + >>> spark.catalog.dropGlobalTempView("my_table") + >>> spark.table("global_temp.my_table") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + AnalysisException: ... + """ + self._jcatalog.dropGlobalTempView(viewName) + @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7482be8bda5c4..8264dcf8a97d2 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -386,7 +386,7 @@ def tables(self, dbName=None): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() - Row(tableName=u'table1', isTemporary=True) + Row(database=u'', tableName=u'table1', isTemporary=True) """ if dbName is None: return DataFrame(self._ssql_ctx.tables(), self) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0ac481a8a8b56..14e80ea4615ef 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -131,7 +131,7 @@ def registerTempTable(self, name): @since(2.0) def createTempView(self, name): - """Creates a temporary view with this DataFrame. + """Creates a local temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. @@ -153,7 +153,7 @@ def createTempView(self, name): @since(2.0) def createOrReplaceTempView(self, name): - """Creates or replaces a temporary view with this DataFrame. + """Creates or replaces a local temporary view with this DataFrame. The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. @@ -169,6 +169,27 @@ def createOrReplaceTempView(self, name): """ self._jdf.createOrReplaceTempView(name) + @since(2.1) + def createGlobalTempView(self, name): + """Creates a global temporary view with this DataFrame. + + The lifetime of this temporary view is tied to this Spark application. + throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the + catalog. + + >>> df.createGlobalTempView("people") + >>> df2 = spark.sql("select * from global_temp.people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + >>> df.createGlobalTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... 
+ AnalysisException: u"Temporary table 'people' already exists;" + >>> spark.catalog.dropGlobalTempView("people") + + """ + self._jdf.createGlobalTempView(name) + @property @since(1.4) def write(self): diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a3bbaceca371b..b599a884957a8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -111,11 +111,12 @@ statement | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable - | CREATE (OR REPLACE)? TEMPORARY? VIEW (IF NOT EXISTS)? tableIdentifier + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? (TBLPROPERTIES tablePropertyList)? AS query #createView - | CREATE (OR REPLACE)? TEMPORARY VIEW + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW tableIdentifier ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? #createTempViewUsing | ALTER VIEW tableIdentifier AS? query #alterViewQuery @@ -676,7 +677,7 @@ nonReserved | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF @@ -864,6 +865,7 @@ CACHE: 'CACHE'; UNCACHE: 'UNCACHE'; LAZY: 'LAZY'; FORMATTED: 'FORMATTED'; +GLOBAL: 'GLOBAL'; TEMPORARY: 'TEMPORARY' | 'TEMP'; OPTIONS: 'OPTIONS'; UNSET: 'UNSET'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ae8869ff25f2d..536d38777f89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -458,12 +458,12 @@ class Analyzer( i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) case u: UnresolvedRelation => val table = u.tableIdentifier - if (table.database.isDefined && conf.runSQLonFile && + if (table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) && (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))) { - // If the table does not exist, and the database part is specified, and we support - // running SQL directly on files, then let's just return the original UnresolvedRelation. - // It is possible we are matching a query like "select * from parquet.`/path/to/query`". - // The plan will get resolved later. + // If the database part is specified, and we support running SQL directly on files, and + // it's not a temporary view, and the table does not exist, then let's just return the + // original UnresolvedRelation. It is possible we are matching a query like "select * + // from parquet.`/path/to/query`". The plan will get resolved later. 
// Note that we are testing (!db_exists || !table_exists) because the catalog throws // an exception from tableExists if the database does not exist. u diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala new file mode 100644 index 0000000000000..6095ac0bc9c50 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/GlobalTempViewManager.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.StringUtils + + +/** + * A thread-safe manager for global temporary views, providing atomic operations to manage them, + * e.g. create, update, remove, etc. + * + * Note that, the view name is always case-sensitive here, callers are responsible to format the + * view name w.r.t. case-sensitive config. + * + * @param database The system preserved virtual database that keeps all the global temporary views. + */ +class GlobalTempViewManager(val database: String) { + + /** List of view definitions, mapping from view name to logical plan. */ + @GuardedBy("this") + private val viewDefinitions = new mutable.HashMap[String, LogicalPlan] + + /** + * Returns the global view definition which matches the given name, or None if not found. + */ + def get(name: String): Option[LogicalPlan] = synchronized { + viewDefinitions.get(name) + } + + /** + * Creates a global temp view, or issue an exception if the view already exists and + * `overrideIfExists` is false. + */ + def create( + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = synchronized { + if (!overrideIfExists && viewDefinitions.contains(name)) { + throw new TempTableAlreadyExistsException(name) + } + viewDefinitions.put(name, viewDefinition) + } + + /** + * Updates the global temp view if it exists, returns true if updated, false otherwise. + */ + def update( + name: String, + viewDefinition: LogicalPlan): Boolean = synchronized { + if (viewDefinitions.contains(name)) { + viewDefinitions.put(name, viewDefinition) + true + } else { + false + } + } + + /** + * Removes the global temp view if it exists, returns true if removed, false otherwise. 
+ */ + def remove(name: String): Boolean = synchronized { + viewDefinitions.remove(name).isDefined + } + + /** + * Renames the global temp view if the source view exists and the destination view not exists, or + * issue an exception if the source view exists but the destination view already exists. Returns + * true if renamed, false otherwise. + */ + def rename(oldName: String, newName: String): Boolean = synchronized { + if (viewDefinitions.contains(oldName)) { + if (viewDefinitions.contains(newName)) { + throw new AnalysisException( + s"rename temporary view from '$oldName' to '$newName': destination view already exists") + } + + val viewDefinition = viewDefinitions(oldName) + viewDefinitions.remove(oldName) + viewDefinitions.put(newName, viewDefinition) + true + } else { + false + } + } + + /** + * Lists the names of all global temporary views. + */ + def listViewNames(pattern: String): Seq[String] = synchronized { + StringUtils.filterPattern(viewDefinitions.keys.toSeq, pattern) + } + + /** + * Clears all the global temporary views. + */ + def clear(): Unit = synchronized { + viewDefinitions.clear() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 8c01c7a3f2bd5..e44e30ec648f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.GLOBAL_TEMP_DATABASE import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -47,6 +48,7 @@ object SessionCatalog { */ class SessionCatalog( externalCatalog: ExternalCatalog, + globalTempViewManager: GlobalTempViewManager, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: CatalystConf, @@ -61,6 +63,7 @@ class SessionCatalog( conf: CatalystConf) { this( externalCatalog, + new GlobalTempViewManager(GLOBAL_TEMP_DATABASE.defaultValueString), DummyFunctionResourceLoader, functionRegistry, conf, @@ -142,8 +145,13 @@ class SessionCatalog( // ---------------------------------------------------------------------------- def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString val dbName = formatDatabaseName(dbDefinition.name) + if (dbName == globalTempViewManager.database) { + throw new AnalysisException( + s"${globalTempViewManager.database} is a system preserved database, " + + "you cannot create a database with this name.") + } + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -154,7 +162,7 @@ class SessionCatalog( if (dbName == DEFAULT_DATABASE) { throw new AnalysisException(s"Can not drop default database") } else if (dbName == getCurrentDatabase) { - throw new AnalysisException(s"Can not drop current database `${dbName}`") + throw new AnalysisException(s"Can not drop current database `$dbName`") } externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) } @@ -188,6 +196,13 @@ class 
SessionCatalog( def setCurrentDatabase(db: String): Unit = { val dbName = formatDatabaseName(db) + if (dbName == globalTempViewManager.database) { + throw new AnalysisException( + s"${globalTempViewManager.database} is a system preserved database, " + + "you cannot use it as current database. To access global temporary views, you should " + + "use qualified name with the GLOBAL_TEMP_DATABASE, e.g. SELECT * FROM " + + s"${globalTempViewManager.database}.viewName.") + } requireDbExists(dbName) synchronized { currentDb = dbName } } @@ -329,7 +344,7 @@ class SessionCatalog( // ---------------------------------------------- /** - * Create a temporary table. + * Create a local temporary view. */ def createTempView( name: String, @@ -343,19 +358,65 @@ class SessionCatalog( } /** - * Return a temporary view exactly as it was stored. + * Create a global temporary view. + */ + def createGlobalTempView( + name: String, + viewDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = { + globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists) + } + + /** + * Alter the definition of a local/global temp view matching the given name, returns true if a + * temp view is matched and altered, false otherwise. + */ + def alterTempViewDefinition( + name: TableIdentifier, + viewDefinition: LogicalPlan): Boolean = synchronized { + val viewName = formatTableName(name.table) + if (name.database.isEmpty) { + if (tempTables.contains(viewName)) { + createTempView(viewName, viewDefinition, overrideIfExists = true) + true + } else { + false + } + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.update(viewName, viewDefinition) + } else { + false + } + } + + /** + * Return a local temporary view exactly as it was stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { tempTables.get(formatTableName(name)) } /** - * Drop a temporary view. + * Return a global temporary view exactly as it was stored. + */ + def getGlobalTempView(name: String): Option[LogicalPlan] = { + globalTempViewManager.get(formatTableName(name)) + } + + /** + * Drop a local temporary view. */ def dropTempView(name: String): Unit = synchronized { tempTables.remove(formatTableName(name)) } + /** + * Drop a global temporary view. 
+ */ + def dropGlobalTempView(name: String): Boolean = { + globalTempViewManager.remove(formatTableName(name)) + } + // ------------------------------------------------------------- // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- @@ -371,9 +432,7 @@ class SessionCatalog( */ def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { val table = formatTableName(name.table) - if (name.database.isDefined) { - getTableMetadata(name) - } else { + if (name.database.isEmpty) { getTempView(table).map { plan => CatalogTable( identifier = TableIdentifier(table), @@ -381,6 +440,16 @@ class SessionCatalog( storage = CatalogStorageFormat.empty, schema = plan.output.toStructType) }.getOrElse(getTableMetadata(name)) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(table).map { plan => + CatalogTable( + identifier = TableIdentifier(table, Some(globalTempViewManager.database)), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = plan.output.toStructType) + }.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table)) + } else { + getTableMetadata(name) } } @@ -393,21 +462,25 @@ class SessionCatalog( */ def renameTable(oldName: TableIdentifier, newName: String): Unit = synchronized { val db = formatDatabaseName(oldName.database.getOrElse(currentDb)) - requireDbExists(db) val oldTableName = formatTableName(oldName.table) val newTableName = formatTableName(newName) - if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { - requireTableExists(TableIdentifier(oldTableName, Some(db))) - requireTableNotExists(TableIdentifier(newTableName, Some(db))) - externalCatalog.renameTable(db, oldTableName, newTableName) + if (db == globalTempViewManager.database) { + globalTempViewManager.rename(oldTableName, newTableName) } else { - if (tempTables.contains(newTableName)) { - throw new AnalysisException( - s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': destination table already exists") + requireDbExists(db) + if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + requireTableExists(TableIdentifier(oldTableName, Some(db))) + requireTableNotExists(TableIdentifier(newTableName, Some(db))) + externalCatalog.renameTable(db, oldTableName, newTableName) + } else { + if (tempTables.contains(newTableName)) { + throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + + "destination table already exists") + } + val table = tempTables(oldTableName) + tempTables.remove(oldTableName) + tempTables.put(newTableName, table) } - val table = tempTables(oldTableName) - tempTables.remove(oldTableName) - tempTables.put(newTableName, table) } } @@ -424,17 +497,24 @@ class SessionCatalog( purge: Boolean): Unit = synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.contains(table)) { - requireDbExists(db) - // When ignoreIfNotExists is false, no exception is issued when the table does not exist. - // Instead, log it as an error message. 
- if (tableExists(TableIdentifier(table, Option(db)))) { - externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge) - } else if (!ignoreIfNotExists) { - throw new NoSuchTableException(db = db, table = table) + if (db == globalTempViewManager.database) { + val viewExists = globalTempViewManager.remove(table) + if (!viewExists && !ignoreIfNotExists) { + throw new NoSuchTableException(globalTempViewManager.database, table) } } else { - tempTables.remove(table) + if (name.database.isDefined || !tempTables.contains(table)) { + requireDbExists(db) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. + if (tableExists(TableIdentifier(table, Option(db)))) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge) + } else if (!ignoreIfNotExists) { + throw new NoSuchTableException(db = db, table = table) + } + } else { + tempTables.remove(table) + } } } @@ -445,6 +525,9 @@ class SessionCatalog( * If no database is specified, this will first attempt to return a temporary table/view with * the same name, then, if that does not exist, return the table/view from the current database. * + * Note that, the global temp view database is also valid here, this will return the global temp + * view matching the given name. + * * If the relation is a view, the relation will be wrapped in a [[SubqueryAlias]] which will * track the name of the view. */ @@ -453,7 +536,11 @@ class SessionCatalog( val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) val relationAlias = alias.getOrElse(table) - if (name.database.isDefined || !tempTables.contains(table)) { + if (db == globalTempViewManager.database) { + globalTempViewManager.get(table).map { viewDef => + SubqueryAlias(relationAlias, viewDef, Some(name)) + }.getOrElse(throw new NoSuchTableException(db, table)) + } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) val view = Option(metadata.tableType).collect { case CatalogTableType.VIEW => name @@ -472,27 +559,48 @@ class SessionCatalog( * explicitly specified. */ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { - name.database.isEmpty && tempTables.contains(formatTableName(name.table)) + val table = formatTableName(name.table) + if (name.database.isEmpty) { + tempTables.contains(table) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(table).isDefined + } else { + false + } } /** - * List all tables in the specified database, including temporary tables. + * List all tables in the specified database, including local temporary tables. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. */ def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") /** - * List all matching tables in the specified database, including temporary tables. + * List all matching tables in the specified database, including local temporary tables. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. 
*/ def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbName = formatDatabaseName(db) - requireDbExists(dbName) - val dbTables = - externalCatalog.listTables(dbName, pattern).map { t => TableIdentifier(t, Some(dbName)) } - synchronized { - val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) - .map { t => TableIdentifier(t) } - dbTables ++ _tempTables + val dbTables = if (dbName == globalTempViewManager.database) { + globalTempViewManager.listViewNames(pattern).map { name => + TableIdentifier(name, Some(globalTempViewManager.database)) + } + } else { + requireDbExists(dbName) + externalCatalog.listTables(dbName, pattern).map { name => + TableIdentifier(name, Some(dbName)) + } + } + val localTempViews = synchronized { + StringUtils.filterPattern(tempTables.keys.toSeq, pattern).map { name => + TableIdentifier(name) + } } + dbTables ++ localTempViews } /** @@ -504,6 +612,8 @@ class SessionCatalog( // If the database is not defined, there is a good chance this is a temp table. if (name.database.isEmpty) { tempTables.get(formatTableName(name.table)).foreach(_.refresh()) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(formatTableName(name.table)).foreach(_.refresh()) } } @@ -919,6 +1029,7 @@ class SessionCatalog( } } tempTables.clear() + globalTempViewManager.clear() functionRegistry.clear() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9cfbdffd02582..4b52508740bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} -import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand} +import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython @@ -2433,9 +2433,13 @@ class Dataset[T] private[sql]( } /** - * Creates a temporary view using the given name. The lifetime of this + * Creates a local temporary view using the given name. The lifetime of this * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. * + * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that + * created it, i.e. it will be automatically dropped when the session terminates. It's not + * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * * @throws AnalysisException if the view name already exists * * @group basic @@ -2443,21 +2447,51 @@ class Dataset[T] private[sql]( */ @throws[AnalysisException] def createTempView(viewName: String): Unit = withPlan { - createViewCommand(viewName, replace = false) + createTempViewCommand(viewName, replace = false, global = false) } + + /** - * Creates a temporary view using the given name. The lifetime of this + * Creates a local temporary view using the given name. 
The lifetime of this * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. * * @group basic * @since 2.0.0 */ def createOrReplaceTempView(viewName: String): Unit = withPlan { - createViewCommand(viewName, replace = true) + createTempViewCommand(viewName, replace = true, global = false) } - private def createViewCommand(viewName: String, replace: Boolean): CreateViewCommand = { + /** + * Creates a global temporary view using the given name. The lifetime of this + * temporary view is tied to this Spark application. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @throws TempTableAlreadyExistsException if the view name already exists + * + * @group basic + * @since 2.1.0 + */ + @throws[AnalysisException] + def createGlobalTempView(viewName: String): Unit = withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } + + private def createTempViewCommand( + viewName: String, + replace: Boolean, + global: Boolean): CreateViewCommand = { + val viewType = if (global) { + GlobalTempView + } else { + LocalTempView + } + CreateViewCommand( name = sparkSession.sessionState.sqlParser.parseTableIdentifier(viewName), userSpecifiedColumns = Nil, @@ -2467,7 +2501,7 @@ class Dataset[T] private[sql]( child = logicalPlan, allowExisting = false, replace = replace, - isTemporary = true) + viewType = viewType) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 7f2762c7dac92..717fb291901bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -262,14 +262,32 @@ abstract class Catalog { options: Map[String, String]): DataFrame /** - * Drops the temporary view with the given view name in the catalog. + * Drops the local temporary view with the given view name in the catalog. * If the view has been cached before, then it will also be uncached. * + * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that + * created it, i.e. it will be automatically dropped when the session terminates. It's not + * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * * @param viewName the name of the view to be dropped. * @since 2.0.0 */ def dropTempView(viewName: String): Unit + /** + * Drops the global temporary view with the given view name in the catalog. + * If the view has been cached before, then it will also be uncached. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @param viewName the name of the view to be dropped. + * @since 2.1.0 + */ + def dropGlobalTempView(viewName: String): Boolean + /** * Returns true if the table is currently cached in-memory. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 383b3a233fc27..cb45a6d78b9b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,15 +21,14 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -125,6 +124,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } } + // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. + case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => + command.executeCollect().map(_.getString(1)) case command: ExecutedCommandExec => command.executeCollect().map(_.getString(0)) case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 5f87b71210d31..be2eddbb0e423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing, _} +import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. @@ -385,7 +385,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... 
instead") - CreateTempViewUsing(table, schema, replace = true, provider, options) + CreateTempViewUsing(table, schema, replace = true, global = false, provider, options) } else { CreateTable(tableDesc, mode, None) } @@ -401,6 +401,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { tableIdent = visitTableIdentifier(ctx.tableIdentifier()), userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema), replace = ctx.REPLACE != null, + global = ctx.GLOBAL != null, provider = ctx.tableProvider.qualifiedName.getText, options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) } @@ -1269,7 +1270,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * * For example: * {{{ - * CREATE [OR REPLACE] [TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name + * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name * [(column_name [COMMENT column_comment], ...) ] * [COMMENT view_comment] * [TBLPROPERTIES (property_name = property_value, ...)] @@ -1286,6 +1287,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } } + val viewType = if (ctx.TEMPORARY == null) { + PersistedView + } else if (ctx.GLOBAL != null) { + GlobalTempView + } else { + LocalTempView + } + CreateViewCommand( name = visitTableIdentifier(ctx.tableIdentifier), userSpecifiedColumns = userSpecifiedColumns, @@ -1295,7 +1304,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { child = plan(ctx.query), allowExisting = ctx.EXISTS != null, replace = ctx.REPLACE != null, - isTemporary = ctx.TEMPORARY != null) + viewType = viewType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 01ac89868d100..45fa293e58951 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -183,17 +183,20 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") - case _ => - }) + + if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadata(tableName).tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead") + case _ => + } + } try { sparkSession.sharedState.cacheManager.uncacheQuery( sparkSession.table(tableName.quotedString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 08de6cd4242c5..424ef58d76c5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -579,9 +579,10 @@ case class ShowTablesCommand( databaseName: Option[String], tableIdentifierPattern: Option[String]) extends RunnableCommand { - // The result of SHOW TABLES has two columns, tableName and isTemporary. + // The result of SHOW TABLES has three columns: database, tableName and isTemporary. override val output: Seq[Attribute] = { - AttributeReference("tableName", StringType, nullable = false)() :: + AttributeReference("database", StringType, nullable = false)() :: + AttributeReference("tableName", StringType, nullable = false)() :: AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil } @@ -592,9 +593,9 @@ case class ShowTablesCommand( val db = databaseName.getOrElse(catalog.getCurrentDatabase) val tables = tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - tables.map { t => - val isTemp = t.database.isEmpty - Row(t.table, isTemp) + tables.map { tableIdent => + val isTemp = catalog.isTemporaryTable(tableIdent) + Row(tableIdent.database.getOrElse(""), tableIdent.table, isTemp) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 15340ee921f68..bbcd9c4ef564c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -19,13 +19,46 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.{SQLBuilder, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} +import org.apache.spark.sql.types.{MetadataBuilder, StructType} + + +/** + * ViewType is used to specify the expected view type when we want to create or replace a view in + * [[CreateViewCommand]]. + */ +sealed trait ViewType + +/** + * LocalTempView means session-scoped local temporary views. Its lifetime is the lifetime of the + * session that created it, i.e. it will be automatically dropped when the session terminates. It's + * not tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + */ +object LocalTempView extends ViewType + +/** + * GlobalTempView means cross-session global temporary views. Its lifetime is the lifetime of the + * Spark application, i.e. it will be automatically dropped when the application terminates. 
It's + * tied to a system preserved database `_global_temp`, and we must use the qualified name to refer a + * global temp view, e.g. SELECT * FROM _global_temp.view1. + */ +object GlobalTempView extends ViewType + +/** + * PersistedView means cross-session persisted views. Persisted views stay until they are + * explicitly dropped by user command. It's always tied to a database, default to the current + * database if not specified. + * + * Note that, Existing persisted view with the same name are not visible to the current session + * while the local temporary view exists, unless the view name is qualified by database. + */ +object PersistedView extends ViewType /** @@ -46,10 +79,7 @@ import org.apache.spark.sql.types.StructType * already exists, throws analysis exception. * @param replace if true, and if the view already exists, updates it; if false, and if the view * already exists, throws analysis exception. - * @param isTemporary if true, the view is created as a temporary view. Temporary views are dropped - * at the end of current Spark session. Existing permanent relations with the same - * name are not visible to the current session while the temporary view exists, - * unless they are specified with full qualified table name with database prefix. + * @param viewType the expected view type to be created with this command. */ case class CreateViewCommand( name: TableIdentifier, @@ -60,20 +90,21 @@ case class CreateViewCommand( child: LogicalPlan, allowExisting: Boolean, replace: Boolean, - isTemporary: Boolean) + viewType: ViewType) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) - if (!isTemporary) { - require(originalText.isDefined, - "The table to created with CREATE VIEW must have 'originalText'.") + if (viewType == PersistedView) { + require(originalText.isDefined, "'originalText' must be provided to create permanent view") } if (allowExisting && replace) { throw new AnalysisException("CREATE VIEW with both IF NOT EXISTS and REPLACE is not allowed.") } + private def isTemporary = viewType == LocalTempView || viewType == GlobalTempView + // Disallows 'CREATE TEMPORARY VIEW IF NOT EXISTS' to be consistent with 'CREATE TEMPORARY TABLE' if (allowExisting && isTemporary) { throw new AnalysisException( @@ -99,72 +130,53 @@ case class CreateViewCommand( s"(num: `${analyzedPlan.output.length}`) does not match the number of column names " + s"specified by CREATE VIEW (num: `${userSpecifiedColumns.length}`).") } - val sessionState = sparkSession.sessionState - - if (isTemporary) { - createTemporaryView(sparkSession, analyzedPlan) - } else { - // Adds default database for permanent table if it doesn't exist, so that tableExists() - // only check permanent tables. - val database = name.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val qualifiedName = name.copy(database = Option(database)) - - if (sessionState.catalog.tableExists(qualifiedName)) { - val tableMetadata = sessionState.catalog.getTableMetadata(qualifiedName) - if (allowExisting) { - // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view - // already exists. - } else if (tableMetadata.tableType != CatalogTableType.VIEW) { - throw new AnalysisException(s"$qualifiedName is not a view") - } else if (replace) { - // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) - } else { - // Handles `CREATE VIEW v0 AS SELECT ...`. 
Throws exception when the target view already - // exists. - throw new AnalysisException( - s"View $qualifiedName already exists. If you want to update the view definition, " + - "please use ALTER VIEW AS or CREATE OR REPLACE VIEW AS") - } - } else { - // Create the view if it doesn't exist. - sessionState.catalog.createTable( - prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) - } - } - Seq.empty[Row] - } - - private def createTemporaryView(sparkSession: SparkSession, analyzedPlan: LogicalPlan): Unit = { - val catalog = sparkSession.sessionState.catalog - // Projects column names to alias names - val logicalPlan = if (userSpecifiedColumns.isEmpty) { + val aliasedPlan = if (userSpecifiedColumns.isEmpty) { analyzedPlan } else { val projectList = analyzedPlan.output.zip(userSpecifiedColumns).map { - case (attr, (colName, _)) => Alias(attr, colName)() + case (attr, (colName, None)) => Alias(attr, colName)() + case (attr, (colName, Some(colComment))) => + val meta = new MetadataBuilder().putString("comment", colComment).build() + Alias(attr, colName)(explicitMetadata = Some(meta)) } sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed } - catalog.createTempView(name.table, logicalPlan, replace) + val catalog = sparkSession.sessionState.catalog + if (viewType == LocalTempView) { + catalog.createTempView(name.table, aliasedPlan, overrideIfExists = replace) + } else if (viewType == GlobalTempView) { + catalog.createGlobalTempView(name.table, aliasedPlan, overrideIfExists = replace) + } else if (catalog.tableExists(name)) { + val tableMetadata = catalog.getTableMetadata(name) + if (allowExisting) { + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + } else if (tableMetadata.tableType != CatalogTableType.VIEW) { + throw new AnalysisException(s"$name is not a view") + } else if (replace) { + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + catalog.alterTable(prepareTable(sparkSession, aliasedPlan)) + } else { + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. + throw new AnalysisException( + s"View $name already exists. If you want to update the view definition, " + + "please use ALTER VIEW AS or CREATE OR REPLACE VIEW AS") + } + } else { + // Create the view if it doesn't exist. + catalog.createTable(prepareTable(sparkSession, aliasedPlan), ignoreIfExists = false) + } + Seq.empty[Row] } /** * Returns a [[CatalogTable]] that can be used to save in the catalog. This comment canonicalize * SQL based on the analyzed plan, and also creates the proper schema for the view. */ - private def prepareTable(sparkSession: SparkSession, analyzedPlan: LogicalPlan): CatalogTable = { - val aliasedPlan = if (userSpecifiedColumns.isEmpty) { - analyzedPlan - } else { - val projectList = analyzedPlan.output.zip(userSpecifiedColumns).map { - case (attr, (colName, _)) => Alias(attr, colName)() - } - sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed - } - + private def prepareTable(sparkSession: SparkSession, aliasedPlan: LogicalPlan): CatalogTable = { val viewSQL: String = new SQLBuilder(aliasedPlan).toSQL // Validate the view SQL - make sure we can parse it and analyze it. 
@@ -176,19 +188,11 @@ case class CreateViewCommand( throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) } - val viewSchema = if (userSpecifiedColumns.isEmpty) { - aliasedPlan.schema - } else { - StructType(aliasedPlan.schema.zip(userSpecifiedColumns).map { - case (field, (_, comment)) => comment.map(field.withComment).getOrElse(field) - }) - } - CatalogTable( identifier = name, tableType = CatalogTableType.VIEW, storage = CatalogStorageFormat.empty, - schema = viewSchema, + schema = aliasedPlan.schema, properties = properties, viewOriginalText = originalText, viewText = Some(viewSQL), @@ -222,8 +226,8 @@ case class AlterViewAsCommand( qe.assertAnalyzed() val analyzedPlan = qe.analyzed - if (session.sessionState.catalog.isTemporaryTable(name)) { - session.sessionState.catalog.createTempView(name.table, analyzedPlan, overrideIfExists = true) + if (session.sessionState.catalog.alterTempViewDefinition(name, analyzedPlan)) { + // a local/global temp view has been altered, we are done. } else { alterPermanentView(session, analyzedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fa95af2648cf9..59fb48ffea598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -40,16 +40,20 @@ case class CreateTable( override def innerChildren: Seq[QueryPlan[_]] = query.toSeq } +/** + * Create or replace a local/global temporary view with given data source. + */ case class CreateTempViewUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], replace: Boolean, + global: Boolean, provider: String, options: Map[String, String]) extends RunnableCommand { if (tableIdent.database.isDefined) { throw new AnalysisException( - s"Temporary table '$tableIdent' should not have specified a database") + s"Temporary view '$tableIdent' should not have specified a database") } def run(sparkSession: SparkSession): Seq[Row] = { @@ -58,10 +62,16 @@ case class CreateTempViewUsing( userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) - sparkSession.sessionState.catalog.createTempView( - tableIdent.table, - Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan, - replace) + + val catalog = sparkSession.sessionState.catalog + val viewDefinition = Dataset.ofRows( + sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan + + if (global) { + catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace) + } else { + catalog.createTempView(tableIdent.table, viewDefinition, replace) + } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e412e1b4b302a..c05bda3f1b526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -94,20 +94,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listTables(dbName: String): Dataset[Table] = { - requireDatabaseExists(dbName) val tables = sessionCatalog.listTables(dbName).map(makeTable) CatalogImpl.makeDataset(tables, sparkSession) } private def makeTable(tableIdent: TableIdentifier): Table = { 
val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) - val database = metadata.identifier.database + val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = database.orNull, + database = metadata.identifier.database.orNull, description = metadata.comment.orNull, - tableType = if (database.isEmpty) "TEMPORARY" else metadata.tableType.name, - isTemporary = database.isEmpty) + tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + isTemporary = isTemp) } /** @@ -365,7 +364,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Drops the temporary view with the given view name in the catalog. + * Drops the local temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * * @param viewName the name of the view to be dropped. @@ -379,6 +378,21 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } + /** + * Drops the global temporary view with the given view name in the catalog. + * If the view has been cached/persisted before, it's also unpersisted. + * + * @param viewName the name of the view to be dropped. + * @group ddl_ops + * @since 2.1.0 + */ + override def dropGlobalTempView(viewName: String): Boolean = { + sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sessionCatalog.dropGlobalTempView(viewName) + } + } + /** * Returns true if the table is currently cached in-memory. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 9f7d0019c6b92..8759dfe39ce1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -95,6 +95,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { */ lazy val catalog = new SessionCatalog( sparkSession.sharedState.externalCatalog, + sparkSession.sharedState.globalTempViewManager, functionResourceLoader, functionRegistry, conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 6387f0150631c..c555a43cd2581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -22,11 +22,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, GlobalTempViewManager, InMemoryCatalog} import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -37,39 +37,14 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} */ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { - /** - * Class for caching query results reused in future executions. 
- */ - val cacheManager: CacheManager = new CacheManager - - /** - * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. - */ - val listener: SQLListener = createListenerAndUI(sparkContext) - + // Load hive-site.xml into hadoopConf and determine the warehouse path we want to use, based on + // the config from both hive and Spark SQL. Finally set the warehouse config value to sparkConf. { val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") if (configFile != null) { sparkContext.hadoopConfiguration.addResource(configFile) } - } - - /** - * A catalog that interacts with external systems. - */ - lazy val externalCatalog: ExternalCatalog = - SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( - SharedState.externalCatalogClassName(sparkContext.conf), - sparkContext.conf, - sparkContext.hadoopConfiguration) - - /** - * A classloader used to load all user-added jar. - */ - val jarClassLoader = new NonClosableMutableURLClassLoader( - org.apache.spark.util.Utils.getContextOrSparkClassLoader) - { // Set the Hive metastore warehouse path to the one we use val tempConf = new SQLConf sparkContext.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -93,6 +68,48 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { logInfo(s"Warehouse path is '${tempConf.warehousePath}'.") } + /** + * Class for caching query results reused in future executions. + */ + val cacheManager: CacheManager = new CacheManager + + /** + * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. + */ + val listener: SQLListener = createListenerAndUI(sparkContext) + + /** + * A catalog that interacts with external systems. + */ + val externalCatalog: ExternalCatalog = + SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + SharedState.externalCatalogClassName(sparkContext.conf), + sparkContext.conf, + sparkContext.hadoopConfiguration) + + /** + * A manager for global temporary views. + */ + val globalTempViewManager = { + // System preserved database should not exists in metastore. However it's hard to guarantee it + // for every session, because case-sensitivity differs. Here we always lowercase it to make our + // life easier. + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + if (externalCatalog.databaseExists(globalTempDB)) { + throw new SparkException( + s"$globalTempDB is a system preserved database, please rename your existing database " + + "to resolve the name conflict, or set a different value for " + + s"${GLOBAL_TEMP_DATABASE.key}, and launch your Spark application again.") + } + new GlobalTempViewManager(globalTempDB) + } + + /** + * A classloader used to load all user-added jar. + */ + val jarClassLoader = new NonClosableMutableURLClassLoader( + org.apache.spark.util.Utils.getContextOrSparkClassLoader) + /** * Create a SQLListener then add it into SparkContext, and create a SQLTab if there is SparkUI. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 001c1a1d85313..2b35db411e2ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -88,11 +88,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") assert( sqlContext.tables().filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) assert( sqlContext.sql("SHOW tables").filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) sqlContext.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true, purge = false) @@ -105,11 +105,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") assert( sqlContext.tables("default").filter("tableName = 'listtablessuitetable'").collect().toSeq == - Row("listtablessuitetable", true) :: Nil) + Row("", "listtablessuitetable", true) :: Nil) assert( sqlContext.sql("show TABLES in default").filter("tableName = 'listtablessuitetable'") - .collect().toSeq == Row("listtablessuitetable", true) :: Nil) + .collect().toSeq == Row("", "listtablessuitetable", true) :: Nil) sqlContext.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true, purge = false) @@ -122,7 +122,8 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { df.createOrReplaceTempView("listtablessuitetable") val expectedSchema = StructType( - StructField("tableName", StringType, false) :: + StructField("database", StringType, false) :: + StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) Seq(sqlContext.tables(), sqlContext.sql("SHOW TABLes")).foreach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala new file mode 100644 index 0000000000000..391bcb8b35d02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalog.Table +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class GlobalTempViewSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + override protected def beforeAll(): Unit = { + super.beforeAll() + globalTempDB = spark.sharedState.globalTempViewManager.database + } + + private var globalTempDB: String = _ + + test("basic semantic") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + + // If there is no database in table name, we should try local temp view first, if not found, + // try table/view in current database, which is "default" in this case. So we expect + // NoSuchTableException here. + intercept[NoSuchTableException](spark.table("src")) + + // Use qualified name to refer to the global temp view explicitly. + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + + // Table name without database will never refer to a global temp view. + intercept[NoSuchTableException](sql("DROP VIEW src")) + + sql(s"DROP VIEW $globalTempDB.src") + // The global temp view should be dropped successfully. + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + + // We can also use Dataset API to create global temp view + Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + + // Use qualified name to rename a global temp view. + sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) + + // Use qualified name to alter a global temp view. 
+ sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) + + // We can also use Catalog API to drop global temp view + spark.catalog.dropGlobalTempView("src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + } + + test("global temp view is shared among all sessions") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) + val newSession = spark.newSession() + checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } + + test("global temp view database should be preserved") { + val e = intercept[AnalysisException](sql(s"CREATE DATABASE $globalTempDB")) + assert(e.message.contains("system preserved database")) + + val e2 = intercept[AnalysisException](sql(s"USE $globalTempDB")) + assert(e2.message.contains("system preserved database")) + } + + test("CREATE GLOBAL TEMP VIEW USING") { + withTempPath { path => + try { + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) + sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.getAbsolutePath}')") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } + } + + test("CREATE TABLE LIKE should work for global temp view") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) + } finally { + spark.catalog.dropGlobalTempView("src") + sql("DROP TABLE default.cloned") + } + } + + test("list global temp views") { + try { + sql("CREATE GLOBAL TEMP VIEW v1 AS SELECT 3, 4") + sql("CREATE TEMP VIEW v2 AS SELECT 1, 2") + + checkAnswer(sql(s"SHOW TABLES IN $globalTempDB"), + Row(globalTempDB, "v1", true) :: + Row("", "v2", true) :: Nil) + + assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) + } finally { + spark.catalog.dropTempView("v1") + spark.catalog.dropGlobalTempView("v2") + } + } + + test("should lookup global temp view if and only if global temp db is specified") { + try { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) + + // Use qualified name to lookup a global temp view. 
+ checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } finally { + spark.catalog.dropTempView("same_name") + spark.catalog.dropGlobalTempView("same_name") + } + } + + test("public Catalog should recognize global temp view") { + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") + + assert(spark.catalog.tableExists(globalTempDB, "src")) + assert(spark.catalog.getTable(globalTempDB, "src").toString == new Table( + name = "src", + database = globalTempDB, + description = null, + tableType = "TEMPORARY", + isTemporary = true).toString) + } finally { + spark.catalog.dropGlobalTempView("src") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 1bcb810a1564f..19885156cc722 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -969,17 +969,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { """.stripMargin) checkAnswer( sql("SHOW TABLES IN default 'show1*'"), - Row("show1a", true) :: Nil) + Row("", "show1a", true) :: Nil) checkAnswer( sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("show1a", true) :: - Row("show2b", true) :: Nil) + Row("", "show1a", true) :: + Row("", "show2b", true) :: Nil) checkAnswer( sql("SHOW TABLES 'show1*|show2*'"), - Row("show1a", true) :: - Row("show2b", true) :: Nil) + Row("", "show1a", true) :: + Row("", "show2b", true) :: Nil) assert( sql("SHOW TABLES").count() >= 2) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 85c509847d8ef..85ecf0ce70756 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule @@ -41,6 +41,7 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, + globalTempViewManager: GlobalTempViewManager, sparkSession: SparkSession, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, @@ -48,6 +49,7 @@ private[sql] class HiveSessionCatalog( hadoopConf: Configuration) extends SessionCatalog( externalCatalog, + globalTempViewManager, functionResourceLoader, functionRegistry, conf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index eb10c11382e83..6d4fe1a941a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -45,6 +45,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override lazy val catalog = { new HiveSessionCatalog( sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], + sparkSession.sharedState.globalTempViewManager, sparkSession, functionResourceLoader, functionRegistry, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 57363b7259c61..939fd71b4f1ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -87,11 +87,11 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac assert( hc.sql("SELECT * FROM moo_table order by name").collect().toSeq == df.collect().toSeq.sortBy(_.getString(0))) - val tables = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) + val tables = hc.sql("SHOW TABLES IN mee_db").select("tableName").collect().map(_.getString(0)) assert(tables.toSet == Set("moo_table", "mee_table")) hc.sql("DROP TABLE moo_table") hc.sql("DROP TABLE mee_table") - val tables2 = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) + val tables2 = hc.sql("SHOW TABLES IN mee_db").select("tableName").collect().map(_.getString(0)) assert(tables2.isEmpty) hc.sql("USE default") hc.sql("DROP DATABASE mee_db CASCADE") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 6eeb67510c735..15ba61646d03f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -58,10 +58,10 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft // We are using default DB. 
checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) + Row("", "listtablessuitetable", true)) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), - Row("hivelisttablessuitetable", false)) + Row("default", "hivelisttablessuitetable", false)) assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) } } @@ -71,11 +71,11 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft case allTables => checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) + Row("", "listtablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), - Row("hiveindblisttablessuitetable", false)) + Row("listtablessuitedb", "hiveindblisttablessuitetable", false)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 8ae6868c9848a..51670649ad1d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -984,7 +984,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv checkAnswer( spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), - Row("ttt3", false)) + Row("testdb8156", "ttt3", false)) spark.sql("""use default""") spark.sql("""drop database if exists testdb8156 CASCADE""") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index b2103b3bfc36c..2c772ce2155ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -94,15 +94,15 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("CREATE TABLE show2b(c2 int)") checkAnswer( sql("SHOW TABLES IN default 'show1*'"), - Row("show1a", false) :: Nil) + Row("default", "show1a", false) :: Nil) checkAnswer( sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) + Row("default", "show1a", false) :: + Row("default", "show2b", false) :: Nil) checkAnswer( sql("SHOW TABLES 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) + Row("default", "show1a", false) :: + Row("default", "show2b", false) :: Nil) assert( sql("SHOW TABLES").count() >= 2) assert( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index f5c605fe5e2fa..2af935da689c9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -62,15 +62,15 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { var e = intercept[AnalysisException] { sql("CREATE OR REPLACE VIEW tab1 AS SELECT * FROM jt") }.getMessage - assert(e.contains("`default`.`tab1` is not a view")) + assert(e.contains("`tab1` is not a view")) e = intercept[AnalysisException] { sql("CREATE VIEW tab1 AS SELECT * FROM jt") }.getMessage - 
assert(e.contains("`default`.`tab1` is not a view")) + assert(e.contains("`tab1` is not a view")) e = intercept[AnalysisException] { sql("ALTER VIEW tab1 AS SELECT * FROM jt") }.getMessage - assert(e.contains("`default`.`tab1` is not a view")) + assert(e.contains("`tab1` is not a view")) } } From 7e16c94f18ec07e4de63e66e06ad757b9e2550b9 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 10 Oct 2016 13:49:25 +0100 Subject: [PATCH 231/306] [HOT-FIX][SQL][TESTS] Remove unused function in `SparkSqlParserSuite` ## What changes were proposed in this pull request? The function `SparkSqlParserSuite.createTempViewUsing` is not used for now and causes build failure, this PR simply removes it. ## How was this patch tested? N/A Author: jiangxingbo Closes #15418 from jiangxb1987/parserSuite. --- .../spark/sql/execution/SparkSqlParserSuite.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index e0976ae95001e..679150e9ae4c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -116,16 +116,6 @@ class SparkSqlParserSuite extends PlanTest { ) } - private def createTempViewUsing( - table: String, - database: Option[String] = None, - schema: Option[StructType] = None, - replace: Boolean = true, - provider: String = "parquet", - options: Map[String, String] = Map.empty): LogicalPlan = { - CreateTempViewUsing(TableIdentifier(table, database), schema, replace, provider, options) - } - private def createTable( table: String, database: Option[String] = None, From 4bafacaa5f50a3e986c14a38bc8df9bae303f3a0 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Mon, 10 Oct 2016 10:55:57 -0500 Subject: [PATCH 232/306] [SPARK-17417][CORE] Fix # of partitions for Reliable RDD checkpointing ## What changes were proposed in this pull request? Currently the no. of partition files are limited to 10000 files (%05d format). If there are more than 10000 part files, the logic goes for a toss while recreating the RDD as it sorts them by string. More details can be found in the JIRA desc [here](https://issues.apache.org/jira/browse/SPARK-17417). ## How was this patch tested? I tested this patch by checkpointing a RDD and then manually renaming part files to the old format and tried to access the RDD. It was successfully created from the old format. Also verified loading a sample parquet file and saving it as multiple formats - CSV, JSON, Text, Parquet, ORC and read them successfully back from the saved files. I couldn't launch the unit test from my local box, so will wait for the Jenkins output. Author: Dhruve Ashar Closes #15370 from dhruve/bug/SPARK-17417. 
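As a hypothetical illustration of the ordering problem described above (the file names below are made up, not taken from a real checkpoint directory): sorting part-file names as plain strings diverges from numeric partition order as soon as an index outgrows the zero-padded width, while sorting on the parsed index stays consistent, which is what this patch switches to:

```
// Illustration only: hypothetical part-file names, one index past the padded width.
val names = Seq("part-99999", "part-100000")

names.sorted
// => List("part-100000", "part-99999")   string order: '1' < '9', so 100000 sorts first

names.sortBy(_.stripPrefix("part-").toInt)
// => List("part-99999", "part-100000")   numeric order matches the partition index
```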
--- .../scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index ab6554fd8a7e7..eac901d10067c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -69,10 +69,10 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( val inputFiles = fs.listStatus(cpath) .map(_.getPath) .filter(_.getName.startsWith("part-")) - .sortBy(_.toString) + .sortBy(_.getName.stripPrefix("part-").toInt) // Fail fast if input files are invalid inputFiles.zipWithIndex.foreach { case (path, i) => - if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + if (path.getName != ReliableCheckpointRDD.checkpointFileName(i)) { throw new SparkException(s"Invalid checkpoint file: $path") } } From 689de920056ae20fe203c2b6faf5b1462e8ea73c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Oct 2016 11:29:09 -0700 Subject: [PATCH 233/306] [SPARK-17830] Annotate spark.sql package with InterfaceStability ## What changes were proposed in this pull request? This patch annotates the InterfaceStability level for top level classes in o.a.spark.sql and o.a.spark.sql.util packages, to experiment with this new annotation. ## How was this patch tested? N/A Author: Reynold Xin Closes #15392 from rxin/SPARK-17830. --- .../scala/org/apache/spark/sql/Column.scala | 5 +++- .../spark/sql/DataFrameNaFunctions.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 3 +- .../scala/org/apache/spark/sql/Dataset.scala | 29 ++++++++++++++++--- .../org/apache/spark/sql/DatasetHolder.scala | 3 ++ .../spark/sql/ExperimentalMethods.scala | 5 ++-- .../org/apache/spark/sql/ForeachWriter.scala | 5 +++- .../spark/sql/KeyValueGroupedDataset.scala | 3 +- .../spark/sql/RelationalGroupedDataset.scala | 4 +-- .../org/apache/spark/sql/RuntimeConfig.scala | 2 ++ .../org/apache/spark/sql/SQLContext.scala | 18 +++++++++++- .../org/apache/spark/sql/SQLImplicits.scala | 2 ++ .../org/apache/spark/sql/SparkSession.scala | 23 ++++++++++++++- .../apache/spark/sql/UDFRegistration.scala | 2 ++ .../org/apache/spark/sql/functions.scala | 8 +++-- .../scala/org/apache/spark/sql/package.scala | 5 ++-- .../sql/util/QueryExecutionListener.scala | 4 ++- 19 files changed, 107 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 63da501f18cca..d22bb17934ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -59,6 +59,7 @@ private[sql] object Column { * * @since 1.6.0 */ +@InterfaceStability.Stable class TypedColumn[-T, U]( expr: Expression, private[sql] val encoder: ExpressionEncoder[U]) @@ -124,6 +125,7 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ 
+@InterfaceStability.Stable class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -1185,6 +1187,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @Experimental +@InterfaceStability.Evolving class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ad00966a917ad..65a9c008f9650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -34,6 +34,7 @@ import org.apache.spark.sql.types._ * @since 1.3.1 */ @Experimental +@InterfaceStability.Evolving final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b84fb2fb95914..b54e695db3b5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.Partition +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD @@ -38,6 +39,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ +@InterfaceStability.Stable class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index d69be36917360..a212bb6205328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.types._ @@ -34,6 +34,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * @since 1.4.0 */ @Experimental +@InterfaceStability.Evolving final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7374a8e045035..35ef050dcb169 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,12 +21,12 @@ import java.util.Properties import scala.collection.JavaConverters._ +import org.apache.spark.annotation.InterfaceStability import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.StructType /** @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ +@InterfaceStability.Stable final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4b52508740bfb..30349ba3cb452 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} @@ -149,9 +149,10 @@ private[sql] object Dataset { * * @since 1.6.0 */ +@InterfaceStability.Stable class Dataset[T] private[sql]( @transient val sparkSession: SparkSession, - @DeveloperApi @transient val queryExecution: QueryExecution, + @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -369,6 +370,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** @@ -477,6 +479,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming /** @@ -798,6 +801,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. 
@@ -869,6 +873,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -1071,6 +1076,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, @@ -1105,6 +1111,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] @@ -1116,6 +1123,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1130,6 +1138,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1145,6 +1154,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1315,6 +1325,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def reduce(func: (T, T) => T): T = rdd.reduce(func) /** @@ -1327,6 +1338,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** @@ -1338,6 +1350,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) @@ -1360,6 +1373,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) @@ -2028,6 +2042,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2041,6 +2056,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2054,6 +2070,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { MapElements[T, U](func, logicalPlan) } @@ -2067,6 +2084,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) @@ -2081,6 +2099,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, @@ -2097,6 +2116,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: 
Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) @@ -2127,6 +2147,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) @@ -2140,6 +2161,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental + @InterfaceStability.Evolving def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -2505,13 +2527,11 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Interface for saving the content of the non-streaming Dataset out into external storage. * * @group basic * @since 1.6.0 */ - @Experimental def write: DataFrameWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -2528,6 +2548,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 47b81c17a31dc..18bccee98f610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.annotation.InterfaceStability + /** * A container for a [[Dataset]], used for implicit conversions in Scala. * @@ -27,6 +29,7 @@ package org.apache.spark.sql * * @since 1.6.0 */ +@InterfaceStability.Stable case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index a435734b0caef..1e8ba51e59e33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental +@InterfaceStability.Unstable class ExperimentalMethods private[sql]() { /** @@ -41,10 +42,8 @@ class ExperimentalMethods private[sql]() { * * @since 1.3.0 */ - @Experimental @volatile var extraStrategies: Seq[Strategy] = Nil - @Experimental @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index f56b25b5576f1..1163035e315fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.StreamingQuery /** @@ -68,8 +68,11 @@ import org.apache.spark.sql.streaming.StreamingQuery * @since 
2.0.0 */ @Experimental +@InterfaceStability.Evolving abstract class ForeachWriter[T] extends Serializable { + // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. + /** * Called when starting to process one partition of new data in the executor. The `version` is * for data deduplication when there are failures. When recovering from a failure, some data may diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cea16fba76e47..828eb94efe598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} @@ -36,6 +36,7 @@ import org.apache.spark.sql.expressions.ReduceAggregator * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c3fe07709fa3..f019d1e9daceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} @@ -43,6 +42,7 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ +@InterfaceStability.Stable class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 7e07e0cb84a87..c2baa74ed7d2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} import org.apache.spark.sql.internal.SQLConf @@ -28,6 +29,7 @@ import org.apache.spark.sql.internal.SQLConf * * @since 2.0.0 */ +@InterfaceStability.Stable class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2edf2e1972053..3c5cf037c578d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,7 +24,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry @@ -55,6 +55,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ +@InterfaceStability.Stable class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { @@ -95,6 +96,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * that listen for execution metrics. */ @Experimental + @InterfaceStability.Evolving def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** @@ -166,6 +168,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ @Experimental @transient + @InterfaceStability.Unstable def experimental: ExperimentalMethods = sparkSession.experimental /** @@ -261,6 +264,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self } @@ -274,6 +278,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { sparkSession.createDataFrame(rdd) } @@ -286,6 +291,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { sparkSession.createDataFrame(data) } @@ -333,6 +339,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -376,6 +383,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -413,6 +421,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -436,6 +445,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -450,6 +460,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.6.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rows, schema) } @@ -515,6 +526,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def 
readStream: DataStreamReader = sparkSession.readStream @@ -632,6 +644,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** @@ -643,6 +656,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** @@ -654,6 +668,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long): DataFrame = { sparkSession.range(start, end, step).toDF() } @@ -668,6 +683,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { sparkSession.range(start, end, step, numPartitions).toDF() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 440952572d8c4..73d16d8a10fd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -28,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * * @since 1.6.0 */ +@InterfaceStability.Evolving abstract class SQLImplicits { protected def _sqlContext: SQLContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6d7ac0f6c1bb2..d26eea507284c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -26,7 +26,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION @@ -68,6 +68,7 @@ import org.apache.spark.util.Utils * .getOrCreate() * }}} */ +@InterfaceStability.Stable class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState]) @@ -137,6 +138,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** @@ -147,6 +149,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Unstable def experimental: ExperimentalMethods = sessionState.experimentalMethods /** @@ -190,6 +193,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager /** @@ -229,6 +233,7 @@ class SparkSession private( * @return 2.0.0 */ @Experimental + 
@InterfaceStability.Evolving def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) @@ -241,6 +246,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkSession.setActiveSession(this) val encoder = Encoders.product[A] @@ -254,6 +260,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -293,6 +300,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema, needsConversion = true) } @@ -306,6 +314,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD.rdd, schema) } @@ -319,6 +328,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi + @InterfaceStability.Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -410,6 +420,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -428,6 +439,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -449,6 +461,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -461,6 +474,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** @@ -471,6 +485,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -483,6 +498,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -496,6 +512,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } @@ -596,6 +613,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def readStream: DataStreamReader = new DataStreamReader(self) @@ -614,6 +632,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } @@ -670,11 
+689,13 @@ class SparkSession private( } +@InterfaceStability.Stable object SparkSession { /** * Builder for [[SparkSession]]. */ + @InterfaceStability.Stable class Builder extends Logging { private[this] val options = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index b006236481a29..617a14793697b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry @@ -36,6 +37,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ +@InterfaceStability.Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 40f82d895d43b..de4943152720c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Functions available for [[DataFrame]]. + * Functions available for DataFrame operations. 
* * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions @@ -54,6 +54,7 @@ import org.apache.spark.util.Utils * @since 1.3.0 */ @Experimental +@InterfaceStability.Evolving // scalastyle:off object functions { // scalastyle:on @@ -2730,6 +2731,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window( timeColumn: Column, windowDuration: String, @@ -2783,6 +2785,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { window(timeColumn, windowDuration, slideDuration, "0 second") } @@ -2821,6 +2824,7 @@ object functions { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String): Column = { window(timeColumn, windowDuration, windowDuration, "0 second") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 28d8bc3de68b8..161e0102f0b43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,8 +17,8 @@ package org.apache.spark -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability} +import org.apache.spark.sql.execution.SparkStrategy /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -40,6 +40,7 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi + @InterfaceStability.Unstable type Strategy = SparkStrategy type DataFrame = Dataset[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 3cae5355eecc6..5e93fc469a41f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.QueryExecution * multiple different threads. */ @Experimental +@InterfaceStability.Evolving trait QueryExecutionListener { /** @@ -68,6 +69,7 @@ trait QueryExecutionListener { * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. */ @Experimental +@InterfaceStability.Evolving class ExecutionListenerManager private[sql] () extends Logging { /** From 3f8a0222e2fa9351a3de09bd2636b000a88da67a Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Mon, 10 Oct 2016 23:16:40 +0200 Subject: [PATCH 234/306] [SPARK-17828][DOCS] Remove unused generate-changelist.py ## What changes were proposed in this pull request? We can remove this file based on discussion at https://issues.apache.org/jira/browse/SPARK-17828 it's evident this file has been redundant for a while, JIRA release notes serves this purpose for us already. 
For ease of future reference you can find detailed release notes at, for example: http://spark.apache.org/downloads.html -> http://spark.apache.org/releases/spark-release-2-0-1.html -> "Detailed changes" which links to https://issues.apache.org/jira/secure/ReleaseNote.jspa?projectId=12315420&version=12336857 ## How was this patch tested? Searched the codebase and saw nothing referencing this, hasn't been used in a while (probably manually invoked a long time ago) Author: Adam Roberts Closes #15419 from a-roberts/patch-7. --- dev/create-release/generate-changelist.py | 148 ---------------------- 1 file changed, 148 deletions(-) delete mode 100755 dev/create-release/generate-changelist.py diff --git a/dev/create-release/generate-changelist.py b/dev/create-release/generate-changelist.py deleted file mode 100755 index 2e1a35a629342..0000000000000 --- a/dev/create-release/generate-changelist.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/python - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Creates CHANGES.txt from git history. -# -# Usage: -# First set the new release version and old CHANGES.txt version in this file. -# Make sure you have SPARK_HOME set. -# $ python generate-changelist.py - - -import os -import sys -import subprocess -import time -import traceback - -SPARK_HOME = os.environ["SPARK_HOME"] -NEW_RELEASE_VERSION = "1.0.0" -PREV_RELEASE_GIT_TAG = "v0.9.1" - -CHANGELIST = "CHANGES.txt" -OLD_CHANGELIST = "%s.old" % (CHANGELIST) -NEW_CHANGELIST = "%s.new" % (CHANGELIST) -TMP_CHANGELIST = "%s.tmp" % (CHANGELIST) - -# date before first PR in TLP Spark repo -SPARK_REPO_CHANGE_DATE1 = time.strptime("2014-02-26", "%Y-%m-%d") -# date after last PR in incubator Spark repo -SPARK_REPO_CHANGE_DATE2 = time.strptime("2014-03-01", "%Y-%m-%d") -# Threshold PR number that differentiates PRs to TLP -# and incubator repos -SPARK_REPO_PR_NUM_THRESH = 200 - -LOG_FILE_NAME = "changes_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") -LOG_FILE = open(LOG_FILE_NAME, 'w') - - -def run_cmd(cmd): - try: - print >> LOG_FILE, "Running command: %s" % cmd - output = subprocess.check_output(cmd, shell=True, stderr=LOG_FILE) - print >> LOG_FILE, "Output: %s" % output - return output - except: - traceback.print_exc() - cleanup() - sys.exit(1) - - -def append_to_changelist(string): - with open(TMP_CHANGELIST, "a") as f: - print >> f, string - - -def cleanup(ask=True): - if ask is True: - print "OK to delete temporary and log files? 
(y/N): " - response = raw_input() - if ask is False or (ask is True and response == "y"): - if os.path.isfile(TMP_CHANGELIST): - os.remove(TMP_CHANGELIST) - if os.path.isfile(OLD_CHANGELIST): - os.remove(OLD_CHANGELIST) - LOG_FILE.close() - os.remove(LOG_FILE_NAME) - - -print "Generating new %s for Spark release %s" % (CHANGELIST, NEW_RELEASE_VERSION) -os.chdir(SPARK_HOME) -if os.path.isfile(TMP_CHANGELIST): - os.remove(TMP_CHANGELIST) -if os.path.isfile(OLD_CHANGELIST): - os.remove(OLD_CHANGELIST) - -append_to_changelist("Spark Change Log") -append_to_changelist("----------------") -append_to_changelist("") -append_to_changelist("Release %s" % NEW_RELEASE_VERSION) -append_to_changelist("") - -print "Getting commits between tag %s and HEAD" % PREV_RELEASE_GIT_TAG -hashes = run_cmd("git log %s..HEAD --pretty='%%h'" % PREV_RELEASE_GIT_TAG).split() - -print "Getting details of %s commits" % len(hashes) -for h in hashes: - date = run_cmd("git log %s -1 --pretty='%%ad' --date=iso | head -1" % h).strip() - subject = run_cmd("git log %s -1 --pretty='%%s' | head -1" % h).strip() - body = run_cmd("git log %s -1 --pretty='%%b'" % h) - committer = run_cmd("git log %s -1 --pretty='%%cn <%%ce>' | head -1" % h).strip() - body_lines = body.split("\n") - - if "Merge pull" in subject: - # Parse old format commit message - append_to_changelist(" %s %s" % (h, date)) - append_to_changelist(" %s" % subject) - append_to_changelist(" [%s]" % body_lines[0]) - append_to_changelist("") - - elif "maven-release" not in subject: - # Parse new format commit message - # Get authors from commit message, committer otherwise - authors = [committer] - if "Author:" in body: - authors = [line.split(":")[1].strip() for line in body_lines if "Author:" in line] - - # Generate GitHub PR URL for easy access if possible - github_url = "" - if "Closes #" in body: - pr_num = [line.split()[1].lstrip("#") for line in body_lines if "Closes #" in line][0] - github_url = "github.com/apache/spark/pull/%s" % pr_num - day = time.strptime(date.split()[0], "%Y-%m-%d") - if (day < SPARK_REPO_CHANGE_DATE1 or - (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH)): - github_url = "github.com/apache/incubator-spark/pull/%s" % pr_num - - append_to_changelist(" %s" % subject) - append_to_changelist(" %s" % ', '.join(authors)) - # for author in authors: - # append_to_changelist(" %s" % author) - append_to_changelist(" %s" % date) - if len(github_url) > 0: - append_to_changelist(" Commit: %s, %s" % (h, github_url)) - else: - append_to_changelist(" Commit: %s" % h) - append_to_changelist("") - -# Append old change list -print "Appending changelist from tag %s" % PREV_RELEASE_GIT_TAG -run_cmd("git show %s:%s | tail -n +3 >> %s" % (PREV_RELEASE_GIT_TAG, CHANGELIST, TMP_CHANGELIST)) -run_cmd("cp %s %s" % (TMP_CHANGELIST, NEW_CHANGELIST)) -print "New change list generated as %s" % NEW_CHANGELIST -cleanup(False) From 29f186bfdf929b1e8ffd8e33ee37b76d5dc5af53 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 10 Oct 2016 23:20:15 +0200 Subject: [PATCH 235/306] [SPARK-14082][MESOS] Enable GPU support with Mesos ## What changes were proposed in this pull request? Enable GPU resources to be used when running coarse grain mode with Mesos. ## How was this patch tested? Manual test with GPU. Author: Timothy Chen Closes #14644 from tnachen/gpu_mesos. 
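The scheduler-side accounting added below boils down to a small clamp: each accepted offer contributes at most the remaining GPU budget, and never more than the offer itself holds, so the default budget of 0 never requests GPUs while a non-zero budget acts as an upper bound rather than a guarantee. A simplified, standalone sketch of that logic (extracted and renamed for illustration, not the patch code itself):

```
// Sketch of the per-offer GPU budgeting; parameter names are illustrative.
def gpusToAcquire(maxGpus: Int, totalGpusAcquired: Int, offeredGpus: Int): Int =
  math.min(math.max(0, maxGpus - totalGpusAcquired), offeredGpus)

// spark.mesos.gpus.max = 5 with 4 GPUs already held: a 3-GPU offer yields 1 more.
assert(gpusToAcquire(5, 4, 3) == 1)
// Default budget of 0: no GPUs are ever requested, but the executor still launches.
assert(gpusToAcquire(0, 0, 3) == 0)
```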
--- docs/running-on-mesos.md | 9 +++ .../MesosCoarseGrainedSchedulerBackend.scala | 30 +++++++-- .../cluster/mesos/MesosSchedulerUtils.scala | 5 ++ ...osCoarseGrainedSchedulerBackendSuite.scala | 61 ++++++++++++++----- .../spark/scheduler/cluster/mesos/Utils.scala | 14 +++-- 5 files changed, 96 insertions(+), 23 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 173961deaadcb..77b06fcf33740 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -498,6 +498,15 @@ See the [configuration page](configuration.html) for information on Spark config in the history server. + + spark.mesos.gpus.max + 0 + + Set the maximum number GPU resources to acquire for this job. Note that executors will still launch when no GPU resources are found + since this configuration is just a upper limit and not a guaranteed amount. + + + diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index a64b5768c57b2..e67bf3e328f94 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -59,6 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") @@ -72,7 +74,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Cores we have acquired with each Mesos task ID val coresByTaskId = new mutable.HashMap[String, Int] + val gpusByTaskId = new mutable.HashMap[String, Int] var totalCoresAcquired = 0 + var totalGpusAcquired = 0 // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because @@ -396,6 +400,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( launchTasks = true val taskId = newMesosTaskId() val offerCPUs = getResource(resources, "cpus").toInt + val taskGPUs = Math.min( + Math.max(0, maxGpus - totalGpusAcquired), getResource(resources, "gpus").toInt) val taskCPUs = executorCores(offerCPUs) val taskMemory = executorMemory(sc) @@ -403,7 +409,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) val (resourcesLeft, resourcesToUse) = - partitionTaskResources(resources, taskCPUs, taskMemory) + partitionTaskResources(resources, taskCPUs, taskMemory, taskGPUs) val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) @@ -425,6 +431,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( remainingResources(offerId) = resourcesLeft.asJava totalCoresAcquired += taskCPUs coresByTaskId(taskId) = taskCPUs + if (taskGPUs > 0) { + totalGpusAcquired += taskGPUs + gpusByTaskId(taskId) = taskGPUs + } } } } @@ -432,21 +442,28 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } /** Extracts task needed resources from a list of available resources. 
*/ - private def partitionTaskResources(resources: JList[Resource], taskCPUs: Int, taskMemory: Int) + private def partitionTaskResources( + resources: JList[Resource], + taskCPUs: Int, + taskMemory: Int, + taskGPUs: Int) : (List[Resource], List[Resource]) = { // partition cpus & mem val (afterCPUResources, cpuResourcesToUse) = partitionResources(resources, "cpus", taskCPUs) val (afterMemResources, memResourcesToUse) = partitionResources(afterCPUResources.asJava, "mem", taskMemory) + val (afterGPUResources, gpuResourcesToUse) = + partitionResources(afterMemResources.asJava, "gpus", taskGPUs) // If user specifies port numbers in SparkConfig then consecutive tasks will not be launched // on the same host. This essentially means one executor per host. // TODO: handle network isolator case val (nonPortResources, portResourcesToUse) = - partitionPortResources(nonZeroPortValuesFromConfig(sc.conf), afterMemResources) + partitionPortResources(nonZeroPortValuesFromConfig(sc.conf), afterGPUResources) - (nonPortResources, cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse) + (nonPortResources, + cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { @@ -513,6 +530,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( totalCoresAcquired -= cores coresByTaskId -= taskId } + // Also remove the gpus we have remembered for this task, if it's in the hashmap + for (gpus <- gpusByTaskId.get(taskId)) { + totalGpusAcquired -= gpus + gpusByTaskId -= taskId + } // If it was a failure, mark the slave as failed for blacklisting purposes if (TaskState.isFailed(state)) { slave.taskFailures += 1 diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 2963d161d6700..73cc241239c4c 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import com.google.common.base.Splitter import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.FrameworkInfo.Capability import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} import org.apache.spark.{SparkConf, SparkContext, SparkException} @@ -93,6 +94,10 @@ trait MesosSchedulerUtils extends Logging { conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + if (maxGpus > 0) { + fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) + } if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index c3ab488e2aa69..75ba02e470e27 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -67,7 +67,7 @@ class 
MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val minMem = backend.executorMemory(sc) val minCpu = 4 - val offers = List((minMem, minCpu)) + val offers = List(Resources(minMem, minCpu)) // launches a task on a valid offer offerResources(offers) @@ -95,8 +95,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite // launches a task on a valid offer val minMem = backend.executorMemory(sc) + 1024 val minCpu = 4 - val offer1 = (minMem, minCpu) - val offer2 = (minMem, 1) + val offer1 = Resources(minMem, minCpu) + val offer2 = Resources(minMem, 1) offerResources(List(offer1, offer2)) verifyTaskLaunched(driver, "o1") @@ -115,7 +115,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map("spark.executor.cores" -> executorCores.toString)) val executorMemory = backend.executorMemory(sc) - val offers = List((executorMemory * 2, executorCores + 1)) + val offers = List(Resources(executorMemory * 2, executorCores + 1)) offerResources(offers) val taskInfos = verifyTaskLaunched(driver, "o1") @@ -130,7 +130,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val executorMemory = backend.executorMemory(sc) val offerCores = 10 - offerResources(List((executorMemory * 2, offerCores))) + offerResources(List(Resources(executorMemory * 2, offerCores))) val taskInfos = verifyTaskLaunched(driver, "o1") assert(taskInfos.length == 1) @@ -144,7 +144,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend(Map("spark.cores.max" -> maxCores.toString)) val executorMemory = backend.executorMemory(sc) - offerResources(List((executorMemory, maxCores + 1))) + offerResources(List(Resources(executorMemory, maxCores + 1))) val taskInfos = verifyTaskLaunched(driver, "o1") assert(taskInfos.length == 1) @@ -153,9 +153,38 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(cpus == maxCores) } + test("mesos does not acquire gpus if not specified") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + offerResources(List(Resources(executorMemory, 1, 1))) + + val taskInfos = verifyTaskLaunched(driver, "o1") + assert(taskInfos.length == 1) + + val gpus = backend.getResource(taskInfos.head.getResourcesList, "gpus") + assert(gpus == 0.0) + } + + + test("mesos does not acquire more than spark.mesos.gpus.max") { + val maxGpus = 5 + setBackend(Map("spark.mesos.gpus.max" -> maxGpus.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List(Resources(executorMemory, 1, maxGpus + 1))) + + val taskInfos = verifyTaskLaunched(driver, "o1") + assert(taskInfos.length == 1) + + val gpus = backend.getResource(taskInfos.head.getResourcesList, "gpus") + assert(gpus == maxGpus) + } + + test("mesos declines offers that violate attribute constraints") { setBackend(Map("spark.mesos.constraints" -> "x:true")) - offerResources(List((backend.executorMemory(sc), 4))) + offerResources(List(Resources(backend.executorMemory(sc), 4))) verifyDeclinedOffer(driver, createOfferId("o1"), true) } @@ -165,8 +194,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val executorMemory = backend.executorMemory(sc) offerResources(List( - (executorMemory, maxCores + 1), - (executorMemory, maxCores + 1))) + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1))) verifyTaskLaunched(driver, "o1") verifyDeclinedOffer(driver, createOfferId("o2"), true) @@ -180,8 +209,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val 
executorMemory = backend.executorMemory(sc) offerResources(List( - (executorMemory * 2, executorCores * 2), - (executorMemory * 2, executorCores * 2))) + Resources(executorMemory * 2, executorCores * 2), + Resources(executorMemory * 2, executorCores * 2))) verifyTaskLaunched(driver, "o1") verifyTaskLaunched(driver, "o2") @@ -193,7 +222,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite // offer with room for two executors val executorMemory = backend.executorMemory(sc) - offerResources(List((executorMemory * 2, executorCores * 2))) + offerResources(List(Resources(executorMemory * 2, executorCores * 2))) // verify two executors were started on a single offer val taskInfos = verifyTaskLaunched(driver, "o1") @@ -397,7 +426,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite setBackend() // launches a task on a valid offer - val offers = List((backend.executorMemory(sc), 1)) + val offers = List(Resources(backend.executorMemory(sc), 1)) offerResources(offers) verifyTaskLaunched(driver, "o1") @@ -434,6 +463,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getCommand.getUrisList.asScala(0).getValue == url) } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -444,9 +475,9 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } } - private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + private def offerResources(offers: List[Resources], startId: Int = 1): Unit = { val mesosOffers = offers.zipWithIndex.map {case (offer, i) => - createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} + createOffer(s"o${i + startId}", s"s${i + startId}", offer.mem, offer.cpus, None, offer.gpus)} backend.resourceOffers(driver, mesosOffers.asJava) } diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index fa9406f5f0553..7ebb294aa9080 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -32,8 +32,9 @@ object Utils { offerId: String, slaveId: String, mem: Int, - cpu: Int, - ports: Option[(Long, Long)] = None): Offer = { + cpus: Int, + ports: Option[(Long, Long)] = None, + gpus: Int = 0): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -42,7 +43,7 @@ object Utils { builder.addResourcesBuilder() .setName("cpus") .setType(Value.Type.SCALAR) - .setScalar(Scalar.newBuilder().setValue(cpu)) + .setScalar(Scalar.newBuilder().setValue(cpus)) ports.foreach { resourcePorts => builder.addResourcesBuilder() .setName("ports") @@ -50,6 +51,12 @@ object Utils { .setRanges(Ranges.newBuilder().addRange(MesosRange.newBuilder() .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) } + if (gpus > 0) { + builder.addResourcesBuilder() + .setName("gpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(gpus)) + } builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) @@ -82,4 +89,3 @@ object Utils { TaskID.newBuilder().setValue(taskId).build() } } - From 03c40202f36ea9fc93071b79fed21ed3f2190ba1 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 10 Oct 2016 17:04:11 -0700 Subject: [PATCH 236/306] [SPARK-14610][ML] Remove 
superfluous split for continuous features in decision tree training ## What changes were proposed in this pull request? A nonsensical split is produced from method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates unit tests accordingly. Additionally, an assertion to check that the number of found splits is `> 0` is removed, and instead features with zero possible splits are ignored. ## How was this patch tested? A unit test was added to check that finding splits for a constant feature produces an empty array. Author: sethah Closes #12374 from sethah/SPARK-14610. --- .../spark/ml/tree/impl/RandomForest.scala | 31 +++++++------ .../ml/tree/impl/RandomForestSuite.scala | 44 ++++++++++++++++--- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 0b7ad92b3cf30..b504f411d256d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging { node.stats } + val validFeatureSplits = + Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } + // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => val numSplits = binAggregates.metadata.numSplits(featureIndex) if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. 
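For readers skimming the hunks here, the counting rule behind this change: a continuous feature with n distinct values admits at most n - 1 thresholds, so a constant (or empty) feature admits none and is now skipped instead of tripping an assertion. The sketch below is a standalone illustration with made-up names, not Spark's `findSplitsForContinuousFeature`; the stride-based selection used when there are more candidates than bins is only roughly approximated.

```scala
object SplitCandidateSketch {
  // Illustrative only: n distinct values -> at most n - 1 thresholds; a constant or
  // empty feature yields an empty array and is simply ignored by the caller.
  def candidateSplits(featureSamples: Seq[Double], maxSplits: Int): Array[Double] = {
    val distinctSorted = featureSamples.distinct.sorted.toArray
    val possibleSplits = math.max(distinctSorted.length - 1, 0)
    if (possibleSplits == 0) {
      Array.empty[Double]            // nothing to split on
    } else if (possibleSplits <= maxSplits) {
      distinctSorted.init            // every distinct value except the largest
    } else {
      // Rough stand-in for the stride-based downsampling in the real implementation.
      val stride = distinctSorted.length.toDouble / (maxSplits + 1)
      (1 to maxSplits).map(i => distinctSorted((i * stride).toInt - 1)).toArray
    }
  }

  def main(args: Array[String]): Unit = {
    println(candidateSplits(Seq(1, 1, 2, 2, 3).map(_.toDouble), 5).mkString(", "))  // 1.0, 2.0
    println(candidateSplits(Seq(7.0, 7.0, 7.0), 5).mkString(", "))                  // (empty)
  }
}
```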
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging { * NOTE: `metadata.numbins` will be changed accordingly * if there are not enough splits to be found * @param featureIndex feature index to find splits - * @return array of splits + * @return array of split thresholds */ private[tree] def findSplitsForContinuousFeature( featureSamples: Iterable[Double], @@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = { + val splits = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value @@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging { val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray // if possible splits is not enough or just enough, just return all possible splits - val possibleSplits = valueCounts.length + val possibleSplits = valueCounts.length - 1 if (possibleSplits <= numSplits) { - valueCounts.map(_._1) + valueCounts.map(_._1).init } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging { splitsBuilder.result() } } - - // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, - s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") - splits } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 79b19ea5ad206..499d386e66413 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) + assert(splits === Array(1.0, 2.0)) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 2) - assert(splits(0) === 2.0) - assert(splits(1) === 3.0) + assert(splits === Array(2.0, 3.0)) } // find splits when most samples close to the maximum { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, + Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 1) - assert(splits(0) === 1.0) + assert(splits === Array(1.0)) } + + // find splits for constant feature + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamplesEmpty = 
Array.empty[Double] + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array[Double]()) + val splitsEmpty = + RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) + assert(splitsEmpty === Array[Double]()) + } + } + + test("train with constant features") { + val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) + assert(tree.rootNode.impurity === -1.0) + assert(tree.depth === 0) + assert(tree.rootNode.prediction === lp.label) } test("Multiclass classification with unordered categorical features: split calculations") { From d5ec4a3e014494a3d991a6350caffbc3b17be0fd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Oct 2016 19:14:01 -0700 Subject: [PATCH 237/306] [SPARK-17738][TEST] Fix flaky test in ColumnTypeSuite ## What changes were proposed in this pull request? The default buffer size is not big enough for randomly generated MapType. ## How was this patch tested? Ran the tests in 100 times, it never fail (it fail 8 times before the patch). Author: Davies Liu Closes #15395 from davies/flaky_map. --- .../spark/sql/execution/columnar/ColumnTypeSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 8bf9f521e2f06..5f2a3aaff634c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -101,14 +101,15 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) + val totalSize = seq.map(_.getSizeInBytes).sum + val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize) test(s"$columnType append/extract") { - buffer.rewind() - seq.foreach(columnType.append(_, 0, buffer)) + val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder()) + seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer)) buffer.rewind() seq.foreach { row => From 90217f9deed01ae187e28ef1531491aac8ee50c9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 11 Oct 2016 10:21:22 +0800 Subject: [PATCH 238/306] [SPARK-16896][SQL] Handle duplicated field names in header consistently with null or empty strings in CSV ## What changes were proposed in this pull request? Currently, CSV datasource allows to load duplicated empty string fields or fields having `nullValue` in the header. It'd be great if this can deal with normal fields as well. 
This PR proposes handling the duplicates consistently with the existing behaviour with considering case-sensitivity (`spark.sql.caseSensitive`) as below: data below: ``` fieldA,fieldB,,FIELDA,fielda,, 1,2,3,4,5,6,7 ``` is parsed as below: ```scala spark.read.format("csv").option("header", "true").load("test.csv").show() ``` - when `spark.sql.caseSensitive` is `false` (by default). ``` +-------+------+---+-------+-------+---+---+ |fieldA0|fieldB|_c2|FIELDA3|fieldA4|_c5|_c6| +-------+------+---+-------+-------+---+---+ | 1| 2| 3| 4| 5| 6| 7| +-------+------+---+-------+-------+---+---+ ``` - when `spark.sql.caseSensitive` is `true`. ``` +-------+------+---+-------+-------+---+---+ |fieldA0|fieldB|_c2| FIELDA|fieldA4|_c5|_c6| +-------+------+---+-------+-------+---+---+ | 1| 2| 3| 4| 5| 6| 7| +-------+------+---+-------+-------+---+---+ ``` **In more details**, There is a good reference about this problem, `read.csv()` in R. So, I initially wanted to propose the similar behaviour. In case of R, the CSV data below: ``` fieldA,fieldB,,fieldA,fieldA,, 1,2,3,4,5,6,7 ``` is parsed as below: ```r test <- read.csv(file="test.csv",header=TRUE,sep=",") > test fieldA fieldB X fieldA.1 fieldA.2 X.1 X.2 1 1 2 3 4 5 6 7 ``` However, Spark CSV datasource already is handling duplicated empty strings and `nullValue` as field names. So the data below: ``` ,,,fieldA,,fieldB, 1,2,3,4,5,6,7 ``` is parsed as below: ```scala spark.read.format("csv").option("header", "true").load("test.csv").show() ``` ``` +---+---+---+------+---+------+---+ |_c0|_c1|_c2|fieldA|_c4|fieldB|_c6| +---+---+---+------+---+------+---+ | 1| 2| 3| 4| 5| 6| 7| +---+---+---+------+---+------+---+ ``` R starts the number for each duplicate but Spark adds the number for its position for all fields for `nullValue` and empty strings. In terms of case-sensitivity, it seems R is case-sensitive as below: (it seems it is not configurable). ``` a,a,a,A,A 1,2,3,4,5 ``` is parsed as below: ```r test <- read.csv(file="test.csv",header=TRUE,sep=",") > test a a.1 a.2 A A.1 1 1 2 3 4 5 ``` ## How was this patch tested? Unit test in `CSVSuite`. Author: hyukjinkwon Closes #14745 from HyukjinKwon/SPARK-16896. 
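The renaming rule is compact enough to read on its own. Below is a standalone distillation of the new `makeSafeHeader` logic with the Spark plumbing stripped out; the object name and the simplified signature are illustrative, not the actual API.

```scala
object CsvHeaderSketch {
  // Null/empty/nullValue header fields become _c<index>; duplicated names (compared
  // case-insensitively unless caseSensitive = true) get their column index appended.
  // Unique names are kept as-is.
  def makeSafeHeader(
      row: Array[String],
      nullValue: String = "",
      caseSensitive: Boolean = false): Array[String] = {
    def normalize(name: String): String = if (caseSensitive) name else name.toLowerCase
    val names = row
      .filter(n => n != null && n.nonEmpty && n != nullValue)
      .map(normalize)
    val duplicates = names.diff(names.distinct).distinct.toSet

    row.zipWithIndex.map { case (value, index) =>
      if (value == null || value.isEmpty || value == nullValue) s"_c$index"
      else if (duplicates.contains(normalize(value))) s"$value$index"
      else value
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the case-insensitive test case: a,A,c,A,b,B -> a0, A1, c, A3, b4, B5
    println(makeSafeHeader(Array("a", "A", "c", "A", "b", "B")).mkString(", "))
  }
}
```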
--- .../datasources/csv/CSVFileFormat.scala | 50 +++++++++++++++---- .../execution/datasources/csv/CSVSuite.scala | 33 ++++++++++++ .../datasources/csv/CSVTypeCastSuite.scala | 2 - 3 files changed, 74 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 4e662a52a7bb7..a3691158ee758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -59,14 +59,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val rdd = baseRdd(sparkSession, csvOptions, paths) val firstLine = findFirstLine(csvOptions, rdd) val firstRow = new CsvReader(csvOptions).parseLine(firstLine) - - val header = if (csvOptions.headerFlag) { - firstRow.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value - } - } else { - firstRow.zipWithIndex.map { case (value, index) => s"_c$index" } - } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) val schema = if (csvOptions.inferSchemaFlag) { @@ -74,13 +68,51 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => - StructField(fieldName.toString, StringType, nullable = true) + StructField(fieldName, StringType, nullable = true) } StructType(schemaFields) } Some(schema) } + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + private def makeSafeHeader( + row: Array[String], + options: CSVOptions, + caseSensitive: Boolean): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. 
+ s"_c$index" + } + } + } + override def prepareWrite( sparkSession: SparkSession, job: Job, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 29aac9def6924..f7c22c6c93f7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -856,4 +857,36 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat) } } + + test("load duplicated field names consistently with null or empty strings - case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + Seq("a,a,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "a1", "c", "A", "b", "B").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } + + test("load duplicated field names consistently with null or empty strings - case insensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + Seq("a,A,c,A,b,B").toDF().write.text(path.getAbsolutePath) + val actualSchema = spark.read + .format("csv") + .option("header", true) + .load(path.getAbsolutePath) + .schema + val fields = Seq("a0", "A1", "c", "A3", "b4", "B5").map(StructField(_, StringType, true)) + val expectedSchema = StructType(fields) + assert(actualSchema == expectedSchema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index dae92f626c225..51832a13cfe0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat import java.util.Locale import org.apache.spark.SparkFunSuite From 19a5bae47f69929d00d9de43387c7df37a05ee25 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Mon, 10 Oct 2016 20:41:31 -0700 Subject: [PATCH 239/306] [SPARK-17816][CORE] Fix ConcurrentModificationException issue in BlockStatusesAccumulator ## What changes were proposed in this pull request? Change the BlockStatusesAccumulator to return immutable object when value method is called. ## How was this patch tested? Existing tests plus I verified this change by running a pipeline which consistently repro this issue. 
This is the stack trace for this exception: ` java.util.ConcurrentModificationException at java.util.ArrayList$Itr.checkForComodification(ArrayList.java:901) at java.util.ArrayList$Itr.next(ArrayList.java:851) at scala.collection.convert.Wrappers$JIteratorWrapper.next(Wrappers.scala:43) at scala.collection.Iterator$class.foreach(Iterator.scala:893) at scala.collection.AbstractIterator.foreach(Iterator.scala:1336) at scala.collection.IterableLike$class.foreach(IterableLike.scala:72) at scala.collection.AbstractIterable.foreach(Iterable.scala:54) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59) at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:183) at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:45) at scala.collection.TraversableLike$class.to(TraversableLike.scala:590) at scala.collection.AbstractTraversable.to(Traversable.scala:104) at scala.collection.TraversableOnce$class.toList(TraversableOnce.scala:294) at scala.collection.AbstractTraversable.toList(Traversable.scala:104) at org.apache.spark.util.JsonProtocol$.accumValueToJson(JsonProtocol.scala:314) at org.apache.spark.util.JsonProtocol$$anonfun$accumulableInfoToJson$5.apply(JsonProtocol.scala:291) at org.apache.spark.util.JsonProtocol$$anonfun$accumulableInfoToJson$5.apply(JsonProtocol.scala:291) at scala.Option.map(Option.scala:146) at org.apache.spark.util.JsonProtocol$.accumulableInfoToJson(JsonProtocol.scala:291) at org.apache.spark.util.JsonProtocol$$anonfun$taskInfoToJson$12.apply(JsonProtocol.scala:283) at org.apache.spark.util.JsonProtocol$$anonfun$taskInfoToJson$12.apply(JsonProtocol.scala:283) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.generic.TraversableForwarder$class.foreach(TraversableForwarder.scala:35) at scala.collection.mutable.ListBuffer.foreach(ListBuffer.scala:45) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.util.JsonProtocol$.taskInfoToJson(JsonProtocol.scala:283) at org.apache.spark.util.JsonProtocol$.taskEndToJson(JsonProtocol.scala:145) at org.apache.spark.util.JsonProtocol$.sparkEventToJson(JsonProtocol.scala:76) ` Author: Ergin Seyfe Closes #15371 from seyfe/race_cond_jsonprotocal. 
--- .../apache/spark/executor/TaskMetrics.scala | 42 +------------------ .../org/apache/spark/util/AccumulatorV2.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- 3 files changed, 6 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 2956768c16417..dfd2f818acdac 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,8 +17,6 @@ package org.apache.spark.executor -import java.util.{ArrayList, Collections} - import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} @@ -27,7 +25,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator} +import org.apache.spark.util._ /** @@ -56,7 +54,7 @@ class TaskMetrics private[spark] () extends Serializable { private val _memoryBytesSpilled = new LongAccumulator private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator - private val _updatedBlockStatuses = new BlockStatusesAccumulator + private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] /** * Time taken on the executor to deserialize this task. @@ -323,39 +321,3 @@ private[spark] object TaskMetrics extends Logging { tm } } - - -private[spark] class BlockStatusesAccumulator - extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] { - private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]()) - - override def isZero(): Boolean = _seq.isEmpty - - override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator - - override def copy(): BlockStatusesAccumulator = { - val newAcc = new BlockStatusesAccumulator - newAcc._seq.addAll(_seq) - newAcc - } - - override def reset(): Unit = _seq.clear() - - override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v) - - override def merge( - other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = { - other match { - case o: BlockStatusesAccumulator => _seq.addAll(o.value) - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - } - - override def value: java.util.List[(BlockId, BlockStatus)] = _seq - - def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = { - _seq.clear() - _seq.addAll(newValue) - } -} diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 470d912ecff13..d3ddd39131326 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -444,7 +444,9 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { override def copy(): CollectionAccumulator[T] = { val newAcc = new CollectionAccumulator[T] - newAcc._list.addAll(_list) + _list.synchronized { + newAcc._list.addAll(_list) + } newAcc } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 
f4fa7b4061640..c11eb3ffa4601 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -281,7 +281,7 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList)) + ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { From 0c0ad436ad909364915b910867d08262c62bc95d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 10 Oct 2016 22:22:41 -0700 Subject: [PATCH 240/306] [SPARK-17719][SPARK-17776][SQL] Unify and tie up options in a single place in JDBC datasource package ## What changes were proposed in this pull request? This PR proposes to fix arbitrary usages among `Map[String, String]`, `Properties` and `JDBCOptions` instances for options in `execution/jdbc` package and make the connection properties exclude Spark-only options. This PR includes some changes as below: - Unify `Map[String, String]`, `Properties` and `JDBCOptions` in `execution/jdbc` package to `JDBCOptions`. - Move `batchsize`, `fetchszie`, `driver` and `isolationlevel` options into `JDBCOptions` instance. - Document `batchSize` and `isolationlevel` with marking both read-only options and write-only options. Also, this includes minor types and detailed explanation for some statements such as url. - Throw exceptions fast by checking arguments first rather than in execution time (e.g. for `fetchsize`). - Exclude Spark-only options in connection properties. ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #15292 from HyukjinKwon/SPARK-17719. --- docs/sql-programming-guide.md | 36 ++++-- .../apache/spark/sql/DataFrameReader.scala | 13 +-- .../datasources/jdbc/JDBCOptions.scala | 110 +++++++++++++++--- .../execution/datasources/jdbc/JDBCRDD.scala | 45 +++---- .../datasources/jdbc/JDBCRelation.scala | 20 ++-- .../jdbc/JdbcRelationProvider.scala | 30 ++--- .../datasources/jdbc/JdbcUtils.scala | 42 ++----- .../spark/sql/jdbc/PostgresDialect.scala | 4 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 8 +- 10 files changed, 182 insertions(+), 137 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 835cb6981f5bd..d0f43ab0a9cc9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1049,16 +1049,20 @@ bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9. {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. Users can specify the JDBC connection properties in the data source options. +user and password are normally provided as connection properties for +logging into the data sources. In addition to the connection properties, Spark also supports +the following case-sensitive options: + + + + + + + + + + +
      Property NameMeaning
      url - The JDBC URL to connect to. + The JDBC URL to connect to. The source-specific connection properties may be specified in the URL. e.g., jdbc:postgresql://localhost/test?user=fred&password=secret
      dbtable @@ -1083,28 +1087,42 @@ the Data Sources API. The following options are supported: partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the partition stride, not for filtering the rows in table. So all rows in the table will be - partitioned and returned. + partitioned and returned. This option applies only to reading.
      fetchsize - The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). This option applies only to reading.
      batchsize + The JDBC batch size, which determines how many rows to insert per round trip. This can help performance on JDBC drivers. This option applies only to writing. It defaults to 1000. +
      isolationLevel + The transaction isolation level, which applies to current connection. It can be one of NONE, READ_COMMITTED, READ_UNCOMMITTED, REPEATABLE_READ, or SERIALIZABLE, corresponding to standard transaction isolation levels defined by JDBC's Connection object, with default of READ_UNCOMMITTED. This option applies only to writing. Please refer the documentation in java.sql.Connection. +
      truncate - This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g. indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. + This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing.
      createTableOptions - This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table. For example: CREATE TABLE t (name string) ENGINE=InnoDB. + This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing.
      @@ -1328,7 +1346,7 @@ options. - Dataset API and DataFrame API are unified. In Scala, `DataFrame` becomes a type alias for `Dataset[Row]`, while Java API users must replace `DataFrame` with `Dataset`. Both the typed - transformations (e.g. `map`, `filter`, and `groupByKey`) and untyped transformations (e.g. + transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the @@ -1377,7 +1395,7 @@ options. - Timestamps are now stored at a precision of 1us, rather than 1ns - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - The canonical name of SQL/DataFrame functions are now lower case (e.g., sum vs SUM). - JSON data source will not automatically load new files that are created by other applications (i.e. files that are not inserted to the dataset through Spark SQL). For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), @@ -1392,7 +1410,7 @@ options. Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) and writing data out (`DataFrame.write`), -and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). +and deprecated the old APIs (e.g., `SQLContext.parquetFile`, `SQLContext.jsonFile`). See the API docs for `SQLContext.read` (
      Scala, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b54e695db3b5e..a716a916b7f7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.InferSchema import org.apache.spark.sql.types.StructType @@ -231,13 +231,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { table: String, parts: Array[Partition], connectionProperties: Properties): DataFrame = { - val props = new Properties() - extraOptions.foreach { case (key, value) => - props.put(key, value) - } - // connectionProperties should override settings in extraOptions - props.putAll(connectionProperties) - val relation = JDBCRelation(url, table, parts, props)(sparkSession) + // connectionProperties should override settings in extraOptions. + val params = extraOptions.toMap ++ connectionProperties.asScala.toMap + val options = new JDBCOptions(url, table, params) + val relation = JDBCRelation(parts, options)(sparkSession) sparkSession.baseRelationToDataFrame(relation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index bcf65e53afa73..fcd7409159def 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.execution.datasources.jdbc +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import scala.collection.mutable.ArrayBuffer + /** * Options for the JDBC data source. */ @@ -24,40 +29,115 @@ class JDBCOptions( @transient private val parameters: Map[String, String]) extends Serializable { + import JDBCOptions._ + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table)) + } + + val asConnectionProperties: Properties = { + val properties = new Properties() + // We should avoid to pass the options into properties. See SPARK-17776. 
+ parameters.filterKeys(!jdbcOptionNames.contains(_)) + .foreach { case (k, v) => properties.setProperty(k, v) } + properties + } + // ------------------------------------------------------------ // Required parameters // ------------------------------------------------------------ - require(parameters.isDefinedAt("url"), "Option 'url' is required.") - require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.") + require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") + require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL - val url = parameters("url") + val url = parameters(JDBC_URL) // name of table - val table = parameters("dbtable") + val table = parameters(JDBC_TABLE_NAME) + + // ------------------------------------------------------------ + // Optional parameters + // ------------------------------------------------------------ + val driverClass = { + val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + } // ------------------------------------------------------------ - // Optional parameter list + // Optional parameters only for reading // ------------------------------------------------------------ // the column used to partition - val partitionColumn = parameters.getOrElse("partitionColumn", null) + val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null) // the lower bound of partition column - val lowerBound = parameters.getOrElse("lowerBound", null) + val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null) // the upper bound of the partition column - val upperBound = parameters.getOrElse("upperBound", null) + val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null) // the number of partitions - val numPartitions = parameters.getOrElse("numPartitions", null) - + val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null) require(partitionColumn == null || (lowerBound != null && upperBound != null && numPartitions != null), - "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + - " and 'numPartitions' are required.") + s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," + + s" and '$JDBC_NUM_PARTITIONS' are required.") + val fetchSize = { + val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt + require(size >= 0, + s"Invalid value `${size.toString}` for parameter " + + s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " + + "the JDBC driver ignores the value and does the estimates.") + size + } // ------------------------------------------------------------ - // The options for DataFrameWriter + // Optional parameters only for writing // ------------------------------------------------------------ // if to truncate the table from the JDBC database - val isTruncate = parameters.getOrElse("truncate", "false").toBoolean + val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean // the create table option , which can be table_options or partition_options. 
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options - val createTableOptions = parameters.getOrElse("createTableOptions", "") + val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val batchSize = { + val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt + require(size >= 1, + s"Invalid value `${size.toString}` for parameter " + + s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") + size + } + val isolationLevel = + parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { + case "NONE" => Connection.TRANSACTION_NONE + case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED + case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED + case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ + case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE + } +} + +object JDBCOptions { + private val jdbcOptionNames = ArrayBuffer.empty[String] + + private def newOption(name: String): String = { + jdbcOptionNames += name + name + } + + val JDBC_URL = newOption("url") + val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_DRIVER_CLASS = newOption("driver") + val JDBC_PARTITION_COLUMN = newOption("partitionColumn") + val JDBC_LOWER_BOUND = newOption("lowerBound") + val JDBC_UPPER_BOUND = newOption("upperBound") + val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") + val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") + val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index f10615ebe4bcf..c0fabc81e42a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp} -import java.util.Properties import scala.util.control.NonFatal @@ -46,17 +45,18 @@ object JDBCRDD extends Logging { * Takes a (schema, table) specification and returns the table's Catalyst * schema. * - * @param url - The JDBC url to fetch information from. - * @param table - The table name of the desired table. This may also be a - * SQL query wrapped in parentheses. + * @param options - JDBC options that contains url, table and other information. * * @return A StructType giving the table's Catalyst schema. * @throws SQLException if the table specification is garbage. * @throws SQLException if the table contains an unsupported type. 
*/ - def resolveTable(url: String, table: String, properties: Properties): StructType = { + def resolveTable(options: JDBCOptions): StructType = { + val url = options.url + val table = options.table + val properties = options.asConnectionProperties val dialect = JdbcDialects.get(url) - val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { @@ -143,43 +143,38 @@ object JDBCRDD extends Logging { }) } - - /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param url - The JDBC url to connect to. - * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. * @param filters - The filters to include in all WHERE clauses. * @param parts - An array of JDBCPartitions specifying partition ids and * per-partition WHERE clauses. + * @param options - JDBC options that contains url, table and other information. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ def scanTable( sc: SparkContext, schema: StructType, - url: String, - properties: Properties, - fqTable: String, requiredColumns: Array[String], filters: Array[Filter], - parts: Array[Partition]): RDD[InternalRow] = { + parts: Array[Partition], + options: JDBCOptions): RDD[InternalRow] = { + val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - JdbcUtils.createConnectionFactory(url, properties), + JdbcUtils.createConnectionFactory(options), pruneSchema(schema, requiredColumns), - fqTable, quotedColumns, filters, parts, url, - properties) + options) } } @@ -192,12 +187,11 @@ private[jdbc] class JDBCRDD( sc: SparkContext, getConnection: () => Connection, schema: StructType, - fqTable: String, columns: Array[String], filters: Array[Filter], partitions: Array[Partition], url: String, - properties: Properties) + options: JDBCOptions) extends RDD[InternalRow](sc, Nil) { /** @@ -211,7 +205,7 @@ private[jdbc] class JDBCRDD( private val columnList: String = { val sb = new StringBuilder() columns.foreach(x => sb.append(",").append(x)) - if (sb.length == 0) "1" else sb.substring(1) + if (sb.isEmpty) "1" else sb.substring(1) } /** @@ -286,7 +280,7 @@ private[jdbc] class JDBCRDD( conn = getConnection() val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ - dialect.beforeFetch(conn, properties.asScala.toMap) + dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -294,15 +288,10 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt - require(fetchSize >= 0, - s"Invalid value `${fetchSize.toString}` for parameter " + - s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. 
When the value is 0, " + - "the JDBC driver ignores the value and does the estimates.") - stmt.setFetchSize(fetchSize) + stmt.setFetchSize(options.fetchSize) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 11613dd912eca..672c21c6ac734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.util.Properties - import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging @@ -102,10 +100,7 @@ private[sql] object JDBCRelation extends Logging { } private[sql] case class JDBCRelation( - url: String, - table: String, - parts: Array[Partition], - properties: Properties = new Properties())(@transient val sparkSession: SparkSession) + parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -114,7 +109,7 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) + override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions) // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { @@ -126,15 +121,16 @@ private[sql] case class JDBCRelation( JDBCRDD.scanTable( sparkSession.sparkContext, schema, - url, - properties, - table, requiredColumns, filters, - parts).asInstanceOf[RDD[Row]] + parts, + jdbcOptions).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val url = jdbcOptions.url + val table = jdbcOptions.table + val properties = jdbcOptions.asConnectionProperties data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) @@ -142,6 +138,6 @@ private[sql] case class JDBCRelation( override def toString: String = { // credentials should not be included in the plan output, table information is sufficient. 
- s"JDBCRelation(${table})" + s"JDBCRelation(${jdbcOptions.table})" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index b1a061b6f7422..4420b3b18a907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.util.Properties - -import scala.collection.JavaConverters.mapAsJavaMapConverter - import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} @@ -46,9 +42,7 @@ class JdbcRelationProvider extends CreatableRelationProvider partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) + JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( @@ -56,15 +50,13 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) - val url = options.url - val table = options.table - val createTableOptions = options.createTableOptions - val isTruncate = options.isTruncate - val props = new Properties() - props.putAll(parameters.asJava) + val jdbcOptions = new JDBCOptions(parameters) + val url = jdbcOptions.url + val table = jdbcOptions.table + val createTableOptions = jdbcOptions.createTableOptions + val isTruncate = jdbcOptions.isTruncate - val conn = JdbcUtils.createConnectionFactory(url, props)() + val conn = JdbcUtils.createConnectionFactory(jdbcOptions)() try { val tableExists = JdbcUtils.tableExists(conn, url, table) if (tableExists) { @@ -73,16 +65,16 @@ class JdbcRelationProvider extends CreatableRelationProvider if (isTruncate && isCascadingTruncateTable(url) == Some(false)) { // In this case, we should truncate table and then load. 
truncateTable(conn, table) - saveTable(df, url, table, props) + saveTable(df, url, table, jdbcOptions) } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, table) createTable(df.schema, url, table, createTableOptions, conn) - saveTable(df, url, table, props) + saveTable(df, url, table, jdbcOptions) } case SaveMode.Append => - saveTable(df, url, table, props) + saveTable(df, url, table, jdbcOptions) case SaveMode.ErrorIfExists => throw new AnalysisException( @@ -95,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } } else { createTable(df.schema, url, table, createTableOptions, conn) - saveTable(df, url, table, props) + saveTable(df, url, table, jdbcOptions) } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 47549637b5813..e32db73bd6c6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} -import java.util.Properties import scala.collection.JavaConverters._ import scala.util.Try @@ -41,27 +40,13 @@ import org.apache.spark.util.NextIterator * Util functions for JDBC tables. */ object JdbcUtils extends Logging { - - // the property names are case sensitive - val JDBC_BATCH_FETCH_SIZE = "fetchsize" - val JDBC_BATCH_INSERT_SIZE = "batchsize" - val JDBC_TXN_ISOLATION_LEVEL = "isolationLevel" - /** * Returns a factory for creating connections to the given JDBC URL. * - * @param url the JDBC url to connect to. - * @param properties JDBC connection properties. + * @param options - JDBC options that contains url, table and other information. */ - def createConnectionFactory(url: String, properties: Properties): () => Connection = { - val userSpecifiedDriverClass = Option(properties.getProperty("driver")) - userSpecifiedDriverClass.foreach(DriverRegistry.register) - // Performing this part of the logic on the driver guards against the corner-case where the - // driver returned for a URL is different on the driver and executors due to classpath - // differences. - val driverClass: String = userSpecifiedDriverClass.getOrElse { - DriverManager.getDriver(url).getClass.getCanonicalName - } + def createConnectionFactory(options: JDBCOptions): () => Connection = { + val driverClass: String = options.driverClass () => { DriverRegistry.register(driverClass) val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { @@ -71,7 +56,7 @@ object JdbcUtils extends Logging { throw new IllegalStateException( s"Did not find registered driver with class $driverClass") } - driver.connect(url, properties) + driver.connect(options.url, options.asConnectionProperties) } } @@ -550,10 +535,6 @@ object JdbcUtils extends Logging { batchSize: Int, dialect: JdbcDialect, isolationLevel: Int): Iterator[Byte] = { - require(batchSize >= 1, - s"Invalid value `${batchSize.toString}` for parameter " + - s"`$JDBC_BATCH_INSERT_SIZE`. 
The minimum value is 1.") - val conn = getConnection() var committed = false @@ -676,23 +657,16 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties) { + options: JDBCOptions) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val getConnection: () => Connection = createConnectionFactory(url, properties) - val batchSize = properties.getProperty(JDBC_BATCH_INSERT_SIZE, "1000").toInt - val isolationLevel = - properties.getProperty(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { - case "NONE" => Connection.TRANSACTION_NONE - case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED - case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED - case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ - case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE - } + val getConnection: () => Connection = createConnectionFactory(options) + val batchSize = options.batchSize + val isolationLevel = options.isolationLevel df.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3f540d6258a0d..4f61a328f47ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types._ @@ -94,7 +94,7 @@ private object PostgresDialect extends JdbcDialect { // // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor // - if (properties.getOrElse(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { + if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7cc3989b791ad..71cf5e6a22916 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils} import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -84,7 +83,7 @@ class JDBCSuite extends SparkFunSuite |CREATE TEMPORARY TABLE fetchtwo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', - | ${JdbcUtils.JDBC_BATCH_FETCH_SIZE} '2') + | ${JDBCOptions.JDBC_BATCH_FETCH_SIZE} '2') """.stripMargin.replaceAll("\n", " ")) sql( @@ -354,8 +353,8 @@ class JDBCSuite 
extends SparkFunSuite test("Basic API with illegal fetchsize") { val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1") - val e = intercept[SparkException] { + properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "-1") + val e = intercept[IllegalArgumentException] { spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() }.getMessage assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) @@ -364,7 +363,7 @@ class JDBCSuite extends SparkFunSuite test("Basic API with FetchSize") { (0 to 4).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, size.toString) + properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, size.toString) assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 62b29db4d5524..96540ec92da73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.{Row, SaveMode} -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -113,8 +113,8 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { (-1 to 0).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) - val e = intercept[SparkException] { + properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) + val e = intercept[IllegalArgumentException] { df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) }.getMessage assert(e.contains(s"Invalid value `$size` for parameter `batchsize`")) @@ -126,7 +126,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { (1 to 3).foreach { size => val properties = new Properties() - properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) + properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) } From b515768f2668749ad37a3bdf9d265ce45ec447b1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Oct 2016 22:33:20 -0700 Subject: [PATCH 241/306] [SPARK-17844] Simplify DataFrame API for defining frame boundaries in window functions ## What changes were proposed in this pull request? When I was creating the example code for SPARK-10496, I realized it was pretty convoluted to define the frame boundaries for window functions when there is no partition column or ordering column. The reason is that we don't provide a way to create a WindowSpec directly with the frame boundaries. We can trivially improve this by adding rowsBetween and rangeBetween to Window object. 
As an example, to compute cumulative sum using the natural ordering, before this pr: ``` df.select('key, sum("value").over(Window.partitionBy(lit(1)).rowsBetween(Long.MinValue, 0))) ``` After this pr: ``` df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0))) ``` Note that you could argue there is no point specifying a window frame without partitionBy/orderBy -- but it is strange that only rowsBetween and rangeBetween are not the only two APIs not available. This also fixes https://issues.apache.org/jira/browse/SPARK-17656 (removing _root_.scala). ## How was this patch tested? Added test cases to compute cumulative sum in DataFrameWindowSuite for Scala/Java and tests.py for Python. Author: Reynold Xin Closes #15412 from rxin/SPARK-17844. --- python/pyspark/sql/tests.py | 9 ++++ python/pyspark/sql/window.py | 48 +++++++++++++++++++ .../apache/spark/sql/expressions/Window.scala | 46 ++++++++++++++++-- .../spark/sql/expressions/WindowSpec.scala | 10 ++-- .../apache/spark/sql/expressions/udaf.scala | 4 +- .../spark/sql/DataFrameWindowSuite.scala | 12 +++++ 6 files changed, 119 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a9e455565a6cd..7b6f9f0ef1c28 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1859,6 +1859,15 @@ def test_window_functions_without_partitionBy(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_cumulative_sum(self): + df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) + from pyspark.sql import functions as F + sel = df.select(df.key, F.sum(df.value).over(Window.rowsBetween(-sys.maxsize, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 46663f69a0881..87e9a988987ea 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -66,6 +66,54 @@ def orderBy(*cols): jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) + @staticmethod + @since(2.1) + def rowsBetween(start, end): + """ + Creates a :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = WindowSpec._JAVA_MIN_LONG + if end >= sys.maxsize: + end = WindowSpec._JAVA_MAX_LONG + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end) + return WindowSpec(jspec) + + @staticmethod + @since(2.1) + def rangeBetween(start, end): + """ + Creates a :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. 
For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = WindowSpec._JAVA_MIN_LONG + if end >= sys.maxsize: + end = WindowSpec._JAVA_MAX_LONG + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) + return WindowSpec(jspec) + class WindowSpec(object): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index c29ec6f426789..e8a0c5f43fe46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -42,7 +42,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { spec.partitionBy(colName, colNames : _*) } @@ -51,7 +51,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { spec.partitionBy(cols : _*) } @@ -60,7 +60,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { spec.orderBy(colName, colNames : _*) } @@ -69,11 +69,49 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { spec.orderBy(cols : _*) } + /** + * Creates a [[WindowSpec]] with the frame boundaries defined, + * from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 2.1.0 + */ + // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. + def rowsBetween(start: Long, end: Long): WindowSpec = { + spec.rowsBetween(start, end) + } + + /** + * Creates a [[WindowSpec]] with the frame boundaries defined, + * from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 2.1.0 + */ + // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. 
+ def rangeBetween(start: Long, end: Long): WindowSpec = { + spec.rangeBetween(start, end) + } + private[sql] def spec: WindowSpec = { new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index d716da2668675..82bc8f152d6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -39,7 +39,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { partitionBy((colName +: colNames).map(Column(_)): _*) } @@ -48,7 +48,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { new WindowSpec(cols.map(_.expr), orderSpec, frame) } @@ -57,7 +57,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { orderBy((colName +: colNames).map(Column(_)): _*) } @@ -66,7 +66,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => col.expr match { @@ -92,6 +92,7 @@ class WindowSpec private[sql]( * The frame is unbounded if this is the maximum long value. * @since 1.4.0 */ + // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { between(RowFrame, start, end) } @@ -109,6 +110,7 @@ class WindowSpec private[sql]( * The frame is unbounded if this is the maximum long value. * @since 1.4.0 */ + // Note: when updating the doc for this method, also update Window.rangeBetween. def rangeBetween(start: Long, end: Long): WindowSpec = { between(RangeFrame, start, end) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index eac658c6176cb..5417a0e481158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( @@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. 
*/ - @_root_.scala.annotation.varargs + @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index c2b47cae8f4c4..5bc386f291043 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -22,6 +22,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DataType, LongType, StructType} +/** + * Window function testing for DataFrame API. + */ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -47,6 +50,15 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } + test("Window.rowsBetween") { + val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") + // Running (cumulative) sum + checkAnswer( + df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0))), + Row("one", 1) :: Row("two", 3) :: Nil + ) + } + test("lead") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") From 19401a203b441e3355f0d3fc3fd062b6d5bdee1f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 10 Oct 2016 22:50:59 -0700 Subject: [PATCH 242/306] [SPARK-15957][ML] RFormula supports forcing to index label ## What changes were proposed in this pull request? ```RFormula``` will index label only when it is string type currently. If the label is numeric type and we use ```RFormula``` to present a classification model, there is no label attributes in label column metadata. The label attributes are useful when making prediction for classification, so we can force to index label by ```StringIndexer``` whether it is numeric or string type for classification. Then SparkR wrappers can extract label attributes from label column metadata successfully. This feature can help us to fix bug similar with [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153). For regression, we will still to keep label as numeric type. In this PR, we add a param ```indexLabel``` to control whether to force to index label for ```RFormula```. ## How was this patch tested? Unit tests. Author: Yanbo Liang Closes #13675 from yanboliang/spark-15957. 
--- .../apache/spark/ml/feature/RFormula.scala | 29 +++++++++++++++++-- .../spark/ml/feature/RFormulaSuite.scala | 27 ++++++++++++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2ee899bcca564..389898666eb8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.VectorUDT -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -104,6 +104,27 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** + * Force to index label whether it is numeric or string type. + * Usually we index label only when it is string type. + * If the formula was used by classification algorithms, + * we can force to index label even it is numeric type by setting this param with true. + * Default: false. + * @group param + */ + @Since("2.1.0") + val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel", + "Force to index label whether it is numeric or string") + setDefault(forceIndexLabel -> false) + + /** @group getParam */ + @Since("2.1.0") + def getForceIndexLabel: Boolean = $(forceIndexLabel) + + /** @group setParam */ + @Since("2.1.0") + def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) + /** Whether the formula specifies fitting an intercept. 
*/ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") @@ -167,8 +188,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) - if (dataset.schema.fieldNames.contains(resolvedFormula.label) && - dataset.schema(resolvedFormula.label).dataType == StringType) { + if ((dataset.schema.fieldNames.contains(resolvedFormula.label) && + dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) { encoderStages += new StringIndexer() .setInputCol(resolvedFormula.label) .setOutputCol($(labelCol)) @@ -181,6 +202,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { + require(!hasLabelCol(schema) || !$(forceIndexLabel), + "If label column already exists, forceIndexLabel can not be set with true.") if (hasLabelCol(schema)) { StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) } else { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 97c268f3d5c97..c664460d7d8bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -57,7 +57,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - test("label column already exists") { + test("label column already exists and forceIndexLabel was set with false") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) @@ -66,6 +66,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(resultSchema.toString == model.transform(original).schema.toString) } + test("label column already exists but forceIndexLabel was set with true") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y").setForceIndexLabel(true) + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.fit(original) + } + } + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, true), (2, false)).toDF("x", "y") @@ -137,6 +145,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("force to index label even it is numeric type") { + val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true) + val original = spark.createDataFrame( + Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = spark.createDataFrame( + Seq( + (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), + (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), + (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), + (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + } + test("attribute generation") { val formula = new 
RFormula().setFormula("id ~ a + b") val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) From 658c7147f5bf637f36e8c66b9207d94b1e7c74c5 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 11 Oct 2016 08:29:52 +0200 Subject: [PATCH 243/306] [SPARK-17808][PYSPARK] Upgraded version of Pyrolite to 4.13 ## What changes were proposed in this pull request? Upgraded to a newer version of Pyrolite which supports serialization of a BinaryType StructField for PySpark.SQL ## How was this patch tested? Added a unit test which fails with a raised ValueError when using the previous version of Pyrolite 4.9 and Python3 Author: Bryan Cutler Closes #15386 from BryanCutler/pyrolite-upgrade-SPARK-17808. --- core/pom.xml | 2 +- dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- python/pyspark/sql/tests.py | 8 ++++++++ 7 files changed, 14 insertions(+), 6 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 9a4f234953a23..205bbc588be09 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -320,7 +320,7 @@ net.razorvine pyrolite - 4.9 + 4.13 net.razorvine diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index f4f92c6d20c23..b30f8c347c0af 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -141,7 +141,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 3db013f1a7585..5b3a7651dd299 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -148,7 +148,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 71710109a16ac..e323efe30f64b 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -148,7 +148,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index cb30fda253c0a..77d97e5365b9f 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,7 +156,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9008aa80bc877..572edfa0cc29e 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,7 +157,7 @@ pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar py4j-0.10.3.jar -pyrolite-4.9.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7b6f9f0ef1c28..86c590dae34d7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1708,6 +1708,14 @@ def 
test_read_text_file_list(self): count = df.count() self.assertEquals(count, 4) + def test_BinaryType_serialization(self): + # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 + schema = StructType([StructField('mybytes', BinaryType())]) + data = [[bytearray(b'here is my data')], + [bytearray(b'and here is some more')]] + df = self.spark.createDataFrame(data, schema=schema) + df.collect() + class HiveSparkSubmitTests(SparkSubmitTests): From 7388ad94d717784a1837ac5a4a9b53219892d080 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Oct 2016 15:21:28 +0800 Subject: [PATCH 244/306] [SPARK-17338][SQL][FOLLOW-UP] add global temp view ## What changes were proposed in this pull request? address post hoc review comments for https://github.com/apache/spark/pull/14897 ## How was this patch tested? N/A Author: Wenchen Fan Closes #15424 from cloud-fan/global-temp-view. --- project/MimaExcludes.scala | 4 +++- python/pyspark/sql/catalog.py | 5 +++++ .../spark/sql/catalyst/catalog/SessionCatalog.scala | 8 ++++++-- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 9 ++------- .../scala/org/apache/spark/sql/catalog/Catalog.scala | 7 ++++++- .../org/apache/spark/sql/internal/CatalogImpl.scala | 4 ++-- 6 files changed, 24 insertions(+), 13 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e3d9a17469a35..ae72d37a0b61c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -57,7 +57,9 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), // [SPARK-17338][SQL] add global temp view - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView") ) } diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index df3bf4254d4d3..a36d02e0db134 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -169,6 +169,10 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** def dropTempView(self, viewName): """Drops the local temporary view with the given view name in the catalog. If the view has been cached before, then it will also be uncached. + Returns true if this view is dropped successfully, false otherwise. + + Note that, the return type of this method was None in Spark 2.0, but changed to Boolean + in Spark 2.1. >>> spark.createDataFrame([(1, 1)]).createTempView("my_table") >>> spark.table("my_table").collect() @@ -185,6 +189,7 @@ def dropTempView(self, viewName): def dropGlobalTempView(self, viewName): """Drops the global temporary view with the given view name in the catalog. If the view has been cached before, then it will also be uncached. + Returns true if this view is dropped successfully, false otherwise. 
>>> spark.createDataFrame([(1, 1)]).createGlobalTempView("my_table") >>> spark.table("global_temp.my_table").collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index e44e30ec648f6..5863c6a71cdf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -405,13 +405,17 @@ class SessionCatalog( /** * Drop a local temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. */ - def dropTempView(name: String): Unit = synchronized { - tempTables.remove(formatTableName(name)) + def dropTempView(name: String): Boolean = synchronized { + tempTables.remove(formatTableName(name)).isDefined } /** * Drop a global temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. */ def dropGlobalTempView(name: String): Boolean = { globalTempViewManager.remove(formatTableName(name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 30349ba3cb452..a7a84730a6fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2494,7 +2494,7 @@ class Dataset[T] private[sql]( * preserved database `_global_temp`, and we must use the qualified name to refer a global temp * view, e.g. `SELECT * FROM _global_temp.view1`. * - * @throws TempTableAlreadyExistsException if the view name already exists + * @throws AnalysisException if the view name already exists * * @group basic * @since 2.1.0 @@ -2508,12 +2508,7 @@ class Dataset[T] private[sql]( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { - val viewType = if (global) { - GlobalTempView - } else { - LocalTempView - } - + val viewType = if (global) GlobalTempView else LocalTempView CreateViewCommand( name = sparkSession.sessionState.sqlParser.parseTableIdentifier(viewName), userSpecifiedColumns = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 717fb291901bf..18cba8ce28b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -269,10 +269,14 @@ abstract class Catalog { * created it, i.e. it will be automatically dropped when the session terminates. It's not * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. * + * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean + * in Spark 2.1. + * * @param viewName the name of the view to be dropped. + * @return true if the view is dropped successfully, false otherwise. * @since 2.0.0 */ - def dropTempView(viewName: String): Unit + def dropTempView(viewName: String): Boolean /** * Drops the global temporary view with the given view name in the catalog. @@ -284,6 +288,7 @@ abstract class Catalog { * view, e.g. `SELECT * FROM _global_temp.view1`. * * @param viewName the name of the view to be dropped. + * @return true if the view is dropped successfully, false otherwise. 
* @since 2.1.0 */ def dropGlobalTempView(viewName: String): Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index c05bda3f1b526..f6c297e91b7c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -371,8 +371,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @group ddl_ops * @since 2.0.0 */ - override def dropTempView(viewName: String): Unit = { - sparkSession.sessionState.catalog.getTempView(viewName).foreach { tempView => + override def dropTempView(viewName: String): Boolean = { + sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) sessionCatalog.dropTempView(viewName) } From 3694ba48f0db0f47baea4b005cdeef3f454b7329 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Oct 2016 15:35:52 +0800 Subject: [PATCH 245/306] [SPARK-17864][SQL] Mark data type APIs as stable (not DeveloperApi) ## What changes were proposed in this pull request? The data type API has not been changed since Spark 1.3.0, and is ready for graduation. This patch marks them as stable APIs using the new InterfaceStability annotation. This patch also looks at the various files in the catalyst module (not the "package") and marks the remaining few classes appropriately as well. ## How was this patch tested? This is an annotation change. No functional changes. Author: Reynold Xin Closes #15426 from rxin/SPARK-17864. --- .../java/org/apache/spark/sql/RowFactory.java | 6 +++++ .../spark/sql/streaming/OutputMode.java | 2 ++ .../org/apache/spark/sql/types/DataTypes.java | 5 ++++ .../spark/sql/types/SQLUserDefinedType.java | 2 ++ .../apache/spark/sql/AnalysisException.scala | 9 ++++---- .../scala/org/apache/spark/sql/Encoder.scala | 3 ++- .../scala/org/apache/spark/sql/Encoders.scala | 3 ++- .../main/scala/org/apache/spark/sql/Row.scala | 10 ++++++-- .../spark/sql/types/AbstractDataType.scala | 7 +++--- .../apache/spark/sql/types/ArrayType.scala | 14 +++++++---- .../apache/spark/sql/types/BinaryType.scala | 10 ++++---- .../apache/spark/sql/types/BooleanType.scala | 12 ++++++---- .../org/apache/spark/sql/types/ByteType.scala | 12 +++++++--- .../sql/types/CalendarIntervalType.scala | 12 ++++++---- .../org/apache/spark/sql/types/DataType.scala | 11 ++++++--- .../org/apache/spark/sql/types/DateType.scala | 12 ++++++---- .../org/apache/spark/sql/types/Decimal.scala | 5 ++-- .../apache/spark/sql/types/DecimalType.scala | 14 +++++++---- .../apache/spark/sql/types/DoubleType.scala | 11 ++++++--- .../apache/spark/sql/types/FloatType.scala | 12 +++++++--- .../apache/spark/sql/types/IntegerType.scala | 11 ++++++--- .../org/apache/spark/sql/types/LongType.scala | 12 ++++++---- .../org/apache/spark/sql/types/MapType.scala | 10 ++++---- .../org/apache/spark/sql/types/Metadata.scala | 20 +++++++++------- .../org/apache/spark/sql/types/NullType.scala | 11 ++++++--- .../apache/spark/sql/types/ShortType.scala | 11 ++++++--- .../apache/spark/sql/types/StringType.scala | 11 ++++++--- .../apache/spark/sql/types/StructField.scala | 5 ++++ .../apache/spark/sql/types/StructType.scala | 23 +++++++++++-------- .../spark/sql/types/TimestampType.scala | 11 ++++++--- .../spark/sql/types/UserDefinedType.scala | 4 ---- 31 files changed, 207 insertions(+), 94 deletions(-) diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java index 5ed60fe78d116..2ce1fdcbf56ae 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -17,16 +17,22 @@ package org.apache.spark.sql; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.expressions.GenericRow; /** * A factory class used to construct {@link Row} objects. + * + * @since 1.3.0 */ +@InterfaceStability.Stable public class RowFactory { /** * Create a {@link Row} from the given arguments. Position i in the argument list becomes * position i in the created {@link Row} object. + * + * @since 1.3.0 */ public static Row create(Object ... values) { return new GenericRow(values); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 41e2582921198..49a18df2c72c0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming; import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.InternalOutputModes; /** @@ -29,6 +30,7 @@ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving public class OutputMode { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 747ab1809fc0a..0f8570fe470bd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -19,10 +19,15 @@ import java.util.*; +import org.apache.spark.annotation.InterfaceStability; + /** * To get/create specific data type, users should use singleton objects and factory methods * provided by this class. + * + * @since 1.3.0 */ +@InterfaceStability.Stable public class DataTypes { /** * Gets the StringType object. 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 110ed460cc8fa..1290614a3207d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -20,6 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.InterfaceStability; /** * ::DeveloperApi:: @@ -30,6 +31,7 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) +@InterfaceStability.Evolving public @interface SQLUserDefinedType { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 6911843999392..f3003306acc6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -// TODO: don't swallow original stack trace if it exists - /** - * :: DeveloperApi :: * Thrown when a query fails to analyze, usually because the query itself is invalid. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 501c1304dbedb..b9f8c46443021 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.annotation.implicitNotFound import scala.reflect.ClassTag -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.types._ @@ -67,6 +67,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving @implicitNotFound("Unable to find encoder for type stored in a Dataset. 
Primitive types " + "(Int, String, etc) and Product types (case classes) are supported by importing " + "spark.implicits._ Support for serializing other types will be added in future " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index e72f67c48a296..dc90659a676e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Modifier import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} @@ -36,6 +36,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving object Encoders { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e16850efbea5f..344dcb9bce62b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: @@ -117,8 +122,9 @@ object Row { * } * }}} * - * @group row + * @since 1.3.0 */ +@InterfaceStability.Stable trait Row extends Serializable { /** Number of elements in the Row. */ def size: Int = length @@ -351,7 +357,7 @@ trait Row extends Serializable { }.toMap } - override def toString(): String = s"[${this.mkString(",")}]" + override def toString: String = s"[${this.mkString(",")}]" /** * Make a copy of the current [[Row]] object. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1981fd8f0a1b5..76dbb7cf0aec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -131,10 +131,11 @@ protected[sql] abstract class AtomicType extends DataType { /** - * :: DeveloperApi :: * Numeric data types. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. 
In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 82a03b0afc002..5d70ef01373f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -21,9 +21,15 @@ import scala.math.Ordering import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.util.ArrayData +/** + * Companion object for ArrayType. + * + * @since 1.3.0 + */ +@InterfaceStability.Stable object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) @@ -37,9 +43,7 @@ object ArrayType extends AbstractDataType { override private[sql] def simpleString: String = "array" } - /** - * :: DeveloperApi :: * The data type for collections of multiple values. * Internally these are represented as columns that contain a ``scala.collection.Seq``. * @@ -51,8 +55,10 @@ object ArrayType extends AbstractDataType { * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { /** No-arg constructor for kryo. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index c40e140e8c5c6..a4a358a242c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,17 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.util.TypeUtils /** - * :: DeveloperApi :: * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. */ -@DeveloperApi +@InterfaceStability.Stable class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. 
@@ -54,5 +53,8 @@ class BinaryType private() extends AtomicType { private[spark] override def asNullable: BinaryType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 2d8ee3d9bc286..059f89f9cda32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. @@ -45,5 +46,8 @@ class BooleanType private() extends AtomicType { private[spark] override def asNullable: BooleanType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index d37130e27ba5a..bc6251f024e58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. @@ -48,4 +49,9 @@ class ByteType private() extends IntegralType { private[spark] override def asNullable: ByteType = this } + +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 3565f52c21f69..e121044288e5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -17,19 +17,19 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.DeveloperApi - +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type representing calendar time intervals. 
The calendar time interval is stored * internally in two components: number of months the number of microseconds. * * Note that calendar intervals are not comparable. * * Please use the singleton [[DataTypes.CalendarIntervalType]]. + * + * @since 1.5.0 */ -@DeveloperApi +@InterfaceStability.Stable class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 @@ -37,4 +37,8 @@ class CalendarIntervalType private() extends DataType { private[spark] override def asNullable: CalendarIntervalType = this } +/** + * @since 1.5.0 + */ +@InterfaceStability.Stable case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4fc65cbce15bd..312585df1516b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -22,15 +22,16 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The base type of all Spark SQL data types. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -94,6 +95,10 @@ abstract class DataType extends AbstractDataType { } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 2c966230e447e..8d0ecc051f4ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * A date type, supporting "0001-01-01" through "9999-12-31". * * Please use the singleton [[DataTypes.DateType]]. * * Internally, this is represented as the number of days from 1970-01-01. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. 
@@ -51,5 +52,8 @@ class DateType private() extends AtomicType { private[spark] override def asNullable: DateType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 70859052872dd..465fb83669a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -30,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi * - If decimalVal is set, it represents the whole decimal value * - Otherwise, the decimal value is longVal / (10 ** _scale) */ +@InterfaceStability.Unstable final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ @@ -185,7 +186,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def toString: String = toBigDecimal.toString() - @DeveloperApi def toDebugString: String = { if (decimalVal.ne(null)) { s"Decimal(expanded,$decimalVal,$precision,$scale})" @@ -380,6 +380,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } +@InterfaceStability.Unstable object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6500875f95e54..d7ca0cbeedcd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** - * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number * of digits on right side of dot). @@ -36,8 +35,10 @@ import org.apache.spark.sql.catalyst.expressions.Expression * The default precision and scale is (10, 0). * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (scale > precision) { @@ -101,7 +102,12 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } -/** Extra factory methods and pattern matchers for Decimals */ +/** + * Extra factory methods and pattern matchers for Decimals. 
+ * + * @since 1.3.0 + */ +@InterfaceStability.Stable object DecimalType extends AbstractDataType { import scala.math.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index e553f65f3c99d..c21ac0e43eee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -21,15 +21,16 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. @@ -51,4 +52,8 @@ class DoubleType private() extends FractionalType { private[spark] override def asNullable: DoubleType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index ae9aa9eefaf2a..c5bf8883bad93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -21,15 +21,16 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. 
@@ -51,4 +52,9 @@ class FloatType private() extends FractionalType { private[spark] override def asNullable: FloatType = this } + +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 38a7b8ee52651..724e59c0bcbf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. @@ -49,4 +50,8 @@ class IntegerType private() extends IntegralType { private[spark] override def asNullable: IntegerType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 88aff0c87755c..42285a9d0aa29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. @@ -48,5 +49,8 @@ class LongType private() extends IntegralType { private[spark] override def asNullable: LongType = this } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 178960929bd83..3a32aa43d1c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type for Maps. Keys in a map are not allowed to have `null` values. 
* * Please use [[DataTypes.createMapType()]] to create a specific instance. @@ -32,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. */ -@DeveloperApi +@InterfaceStability.Stable case class MapType( keyType: DataType, valueType: DataType, @@ -76,7 +75,10 @@ case class MapType( } } - +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 657bd86ce17d9..3aa4bf619f274 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -22,22 +22,22 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: - * * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and * Array[Metadata]. JSON is used for serialization. * * The default constructor is private. User should use either [[MetadataBuilder]] or - * [[Metadata.fromJson()]] to create Metadata instances. + * `Metadata.fromJson()` to create Metadata instances. * * @param map an immutable map that stores the data + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { @@ -114,6 +114,10 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object Metadata { private[this] val _empty = new Metadata(Map.empty) @@ -218,11 +222,11 @@ object Metadata { } /** - * :: DeveloperApi :: - * * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index aa84115c2e42c..bdf9a819d007b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability /** - * :: DeveloperApi :: * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class NullType private() extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. 
@@ -34,4 +35,8 @@ class NullType private() extends DataType { private[spark] override def asNullable: NullType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 486cf585284df..3fee299d578cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. @@ -48,4 +49,8 @@ class ShortType private() extends IntegralType { private[spark] override def asNullable: ShortType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 44a25361f31c4..5d5a6f52a305b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.unsafe.types.UTF8String /** - * :: DeveloperApi :: * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. @@ -45,5 +46,9 @@ class StringType private() extends AtomicType { private[spark] override def asNullable: StringType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object StringType extends StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index cb8bf616968e5..2c18fdcc497fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.apache.spark.annotation.InterfaceStability + /** * A field inside a StructType. * @param name The name of this field. 
@@ -27,7 +29,10 @@ import org.json4s.JsonDSL._ * @param nullable Indicates if values of this field can be `null` values. * @param metadata The metadata of this field. The metadata should be preserved during * transformation if the content of the column is not modified, e.g, in selection. + * + * @since 1.3.0 */ +@InterfaceStability.Stable case class StructField( name: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index dd4c88c4c43bc..0205c13aa986d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -23,14 +23,13 @@ import scala.util.Try import org.json4s.JsonDSL._ import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * A [[StructType]] object can be constructed by * {{{ * StructType(fields: Seq[StructField]) @@ -90,8 +89,10 @@ import org.apache.spark.util.Utils * val row = Row(Row(1, 2, true)) * // row: Row = [[1,2,true]] * }}} + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ @@ -138,7 +139,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType) */ def add(name: String, dataType: DataType): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + StructType(fields :+ StructField(name, dataType, nullable = true, Metadata.empty)) } /** @@ -150,7 +151,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType, true) */ def add(name: String, dataType: DataType, nullable: Boolean): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + StructType(fields :+ StructField(name, dataType, nullable, Metadata.empty)) } /** @@ -167,7 +168,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = { - StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + StructType(fields :+ StructField(name, dataType, nullable, metadata)) } /** @@ -347,7 +348,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { - case f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" } builder.append("struct<") builder.append(fieldTypes.mkString(", ")) @@ -393,6 +394,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable object StructType extends AbstractDataType { /** @@ -469,7 +474,7 @@ object StructType extends AbstractDataType 
{ nullable = leftNullable || rightNullable) } .orElse { - optionalMeta.putBoolean(metadataKeyForOptionalField, true) + optionalMeta.putBoolean(metadataKeyForOptionalField, value = true) Some(leftField.copy(metadata = optionalMeta.build())) } .foreach(newFields += _) @@ -479,7 +484,7 @@ object StructType extends AbstractDataType { rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach { f => - optionalMeta.putBoolean(metadataKeyForOptionalField, true) + optionalMeta.putBoolean(metadataKeyForOptionalField, value = true) newFields += f.copy(metadata = optionalMeta.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index 2be9b2d76c9fe..4540d8358acad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflectionLock /** - * :: DeveloperApi :: * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. + * + * @since 1.3.0 */ -@DeveloperApi +@InterfaceStability.Stable class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. @@ -48,4 +49,8 @@ class TimestampType private() extends AtomicType { private[spark] override def asNullable: TimestampType = this } +/** + * @since 1.3.0 + */ +@InterfaceStability.Stable case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 894631382f8ce..c33219c95b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -22,8 +22,6 @@ import java.util.Objects import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.DeveloperApi - /** * The data type for User Defined Types (UDTs). * @@ -96,12 +94,10 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa } /** - * :: DeveloperApi :: * The user defined type in Python. * * Note: This can only be accessed via Python UDF, or accessed as serialized object. */ -@DeveloperApi private[sql] class PythonUserDefinedType( val sqlType: DataType, override val pyUDT: String, From c8c090640ab73624841d0f4abcfd7409a0838725 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Oct 2016 16:06:40 +0800 Subject: [PATCH 246/306] [SPARK-17821][SQL] Support And and Or in Expression Canonicalize ## What changes were proposed in this pull request? Currently `Canonicalize` object doesn't support `And` and `Or`. So we can compare canonicalized form of predicates consistently. We should add the support. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #15388 from viirya/canonicalize-and-or. 
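For readers who want to see the effect, here is a minimal, illustrative Scala sketch using the Catalyst expression DSL (the attribute names and the `assert` are made up for the example; it mirrors the first reordering case added to `ExpressionSetSuite` below). Because `ExpressionSet` keys its entries by their canonicalized form, reordered `And`/`Or` children now collapse into a single entry:

```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.ExpressionSet

// Two integer attributes standing in for table columns.
val a = 'a.int
val b = 'b.int

// The predicates differ only in the order of the And children. With And/Or
// handled by Canonicalize they share one canonical form, so the set keeps a
// single entry; before this patch it kept both.
val predicates = ExpressionSet(Seq(a > b && a <= 10, a <= 10 && a > b))
assert(predicates.size == 1)
```

Non-deterministic children (e.g. `Rand`) are deliberately left in their original order, as the negative cases in the suite show.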
--- .../catalyst/expressions/Canonicalize.scala | 7 ++ .../expressions/ExpressionSetSuite.scala | 82 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 07ba7d5e4a849..e876450c73fde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -62,6 +62,13 @@ object Canonicalize extends { case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + case o: Or => + orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) }) + .reduce(Or) + case a: And => + orderCommutative(a, { case And(l, r) if l.deterministic && r.deterministic => Seq(l, r)}) + .reduce(And) + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 60939ee0eda5d..c587d4f632531 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -80,6 +80,88 @@ class ExpressionSetSuite extends SparkFunSuite { setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper) setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper) + // Reordering AND/OR expressions + setTest(1, aUpper > bUpper && aUpper <= 10, aUpper <= 10 && aUpper > bUpper) + setTest(1, + aUpper > bUpper && bUpper > 100 && aUpper <= 10, + bUpper > 100 && aUpper <= 10 && aUpper > bUpper) + + setTest(1, aUpper > bUpper || aUpper <= 10, aUpper <= 10 || aUpper > bUpper) + setTest(1, + aUpper > bUpper || bUpper > 100 || aUpper <= 10, + bUpper > 100 || aUpper <= 10 || aUpper > bUpper) + + setTest(1, + (aUpper <= 10 && aUpper > bUpper) || bUpper > 100, + bUpper > 100 || (aUpper <= 10 && aUpper > bUpper)) + + setTest(1, + aUpper >= bUpper || (aUpper > 10 && bUpper < 10), + (bUpper < 10 && aUpper > 10) || aUpper >= bUpper) + + // More complicated cases mixing AND/OR + // Three predicates in the following: + // (bUpper > 100) + // (aUpper < 100 && bUpper <= aUpper) + // (aUpper >= 10 && bUpper >= 50) + // They can be reordered and the sub-predicates contained in each of them can be reordered too. 
+ setTest(1, + (bUpper > 100) || (aUpper < 100 && bUpper <= aUpper) || (aUpper >= 10 && bUpper >= 50), + (aUpper >= 10 && bUpper >= 50) || (bUpper > 100) || (aUpper < 100 && bUpper <= aUpper), + (bUpper >= 50 && aUpper >= 10) || (bUpper <= aUpper && aUpper < 100) || (bUpper > 100)) + + // Two predicates in the following: + // (bUpper > 100 && aUpper < 100 && bUpper <= aUpper) + // (aUpper >= 10 && bUpper >= 50) + setTest(1, + (bUpper > 100 && aUpper < 100 && bUpper <= aUpper) || (aUpper >= 10 && bUpper >= 50), + (aUpper >= 10 && bUpper >= 50) || (aUpper < 100 && bUpper > 100 && bUpper <= aUpper), + (bUpper >= 50 && aUpper >= 10) || (bUpper <= aUpper && aUpper < 100 && bUpper > 100)) + + // Three predicates in the following: + // (aUpper >= 10) + // (bUpper <= 10 && aUpper === bUpper && aUpper < 100) + // (bUpper >= 100) + setTest(1, + (aUpper >= 10) || (bUpper <= 10 && aUpper === bUpper && aUpper < 100) || (bUpper >= 100), + (aUpper === bUpper && aUpper < 100 && bUpper <= 10) || (bUpper >= 100) || (aUpper >= 10), + (aUpper < 100 && bUpper <= 10 && aUpper === bUpper) || (aUpper >= 10) || (bUpper >= 100), + ((bUpper <= 10 && aUpper === bUpper) && aUpper < 100) || ((aUpper >= 10) || (bUpper >= 100))) + + // Don't reorder non-deterministic expression in AND/OR. + setTest(2, Rand(1L) > aUpper && aUpper <= 10, aUpper <= 10 && Rand(1L) > aUpper) + setTest(2, + aUpper > bUpper && bUpper > 100 && Rand(1L) > aUpper, + bUpper > 100 && Rand(1L) > aUpper && aUpper > bUpper) + + setTest(2, Rand(1L) > aUpper || aUpper <= 10, aUpper <= 10 || Rand(1L) > aUpper) + setTest(2, + aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, + aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) + + // Partial reorder case: we don't reorder non-deterministic expressions, + // but we can reorder sub-expressions in deterministic AND/OR expressions. + // There are two predicates: + // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. + // (aUpper === Rand(1L)) + setTest(1, + (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), + (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) + + // There are three predicates: + // (Rand(1L) > aUpper) + // (aUpper <= Rand(1L) && aUpper > bUpper) + // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. + setTest(1, + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) + + // Same predicates as above, but a negative case when we reorder non-deterministic + // expression in (aUpper <= Rand(1L) && aUpper > bUpper). + setTest(2, + Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), + Rand(1L) > aUpper || (aUpper > bUpper && aUpper <= Rand(1L)) || (aUpper > 10 && bUpper > 10)) + test("add to / remove from set") { val initialSet = ExpressionSet(aUpper + 1 :: Nil) From 75b9e351413dca0930e8545e6283874db09d8482 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 11 Oct 2016 10:53:07 -0700 Subject: [PATCH 247/306] [SPARK-17346][SQL][TESTS] Fix the flaky topic deletion in KafkaSourceStressSuite ## What changes were proposed in this pull request? A follow up Pr for SPARK-17346 to fix flaky `org.apache.spark.sql.kafka010.KafkaSourceStressSuite`. 
Test log: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.4/1855/testReport/junit/org.apache.spark.sql.kafka010/KafkaSourceStressSuite/_It_is_not_a_test_/ Looks like deleting the Kafka internal topic `__consumer_offsets` is flaky. This PR just simply ignores internal topics. ## How was this patch tested? Existing tests. Author: Shixiong Zhu Closes #15384 from zsxwing/SPARK-17346-flaky-test. --- .../org/apache/spark/sql/kafka010/KafkaSourceSuite.scala | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 6c03070398fca..c640b93b0a2ee 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata -import org.scalatest.BeforeAndAfter import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.streaming._ @@ -344,7 +343,7 @@ class KafkaSourceSuite extends KafkaSourceTest { } -class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { +class KafkaSourceStressSuite extends KafkaSourceTest { import testImplicits._ @@ -358,12 +357,6 @@ class KafkaSourceStressSuite extends KafkaSourceTest with BeforeAndAfter { start + Random.nextInt(start + end - 1) } - after { - for (topic <- testUtils.getAllTopicsAndPartitionSize().toMap.keys) { - testUtils.deleteTopic(topic) - } - } - test("stress test with multiple topics and partitions") { topics.foreach { topic => testUtils.createTopic(topic, partitions = nextInt(1, 6)) From 07508bd01d16f3331be167ff92770d19c8b1f46a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Oct 2016 11:43:24 -0700 Subject: [PATCH 248/306] [SPARK-17817][PYSPARK] PySpark RDD Repartitioning Results in Highly Skewed Partition Sizes ## What changes were proposed in this pull request? Quoted from JIRA description: Calling repartition on a PySpark RDD to increase the number of partitions results in highly skewed partition sizes, with most having 0 rows. The repartition method should evenly spread out the rows across the partitions, and this behavior is correctly seen on the Scala side. Please reference the following code for a reproducible example of this issue: num_partitions = 20000 a = sc.parallelize(range(int(1e6)), 2) # start with 2 even partitions l = a.repartition(num_partitions).glom().map(len).collect() # get length of each partition min(l), max(l), sum(l)/len(l), len(l) # skewed! In Scala's `repartition` code, we will distribute elements evenly across output partitions. However, the RDD from Python is serialized as a single binary data, so the distribution fails. We need to convert the RDD in Python to java object before repartitioning. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #15389 from viirya/pyspark-rdd-repartition. 
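As a point of comparison, the Scala-side behaviour that the fix falls back on can be checked directly from spark-shell. This is only an illustrative sketch (it assumes a running `SparkContext` bound to `sc`), analogous to the `test_repartition_no_skewed` test added below:

```
// Scala's repartition shuffles rows round-robin across the target partitions,
// so growing 2 partitions into 200 should leave none of them empty.
val rdd = sc.parallelize(1 to 1000000, 2)
val sizes = rdd.repartition(200).glom().map(_.length).collect()

assert(!sizes.contains(0))  // no empty partitions
println(s"min=${sizes.min}, max=${sizes.max}, avg=${sizes.sum / sizes.length}")
```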
--- python/pyspark/rdd.py | 13 ++++++++++--- python/pyspark/tests.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ed81eb16df3cd..0e2ae19ca39aa 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2017,8 +2017,7 @@ def repartition(self, numPartitions): >>> len(rdd.repartition(10).glom().collect()) 10 """ - jrdd = self._jrdd.repartition(numPartitions) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return self.coalesce(numPartitions, shuffle=True) def coalesce(self, numPartitions, shuffle=False): """ @@ -2029,7 +2028,15 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions, shuffle) + if shuffle: + # In Scala's repartition code, we will distribute elements evenly across output + # partitions. However, the RDD from Python is serialized as a single binary data, + # so the distribution fails and produces highly skewed partitions. We need to + # convert it to a RDD of java object before repartitioning. + data_java_rdd = self._to_java_object_rdd().coalesce(numPartitions, shuffle) + jrdd = self.ctx._jvm.SerDeUtil.javaToPython(data_java_rdd) + else: + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b0756911bfc10..3e0bd16d85ca4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -914,6 +914,16 @@ def test_repartitionAndSortWithinPartitions(self): self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10) From 23405f324a8089f86ebcbede9bb32944137508e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Oct 2016 12:41:35 -0700 Subject: [PATCH 249/306] [SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type ## What changes were proposed in this pull request? Fix SparkR ```spark.naiveBayes``` error when response variable of dataset is numeric type. See details and how to reproduce this bug at [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153). ## How was this patch tested? Add unit test. Author: Yanbo Liang Closes #15431 from yanboliang/spark-15153-2. 
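The one-line change in `NaiveBayesWrapper` below forces `RFormula` to index the label column even when it is already numeric. A rough Scala sketch of what that option does (the column names and data are made up; it assumes a `SparkSession` bound to `spark`, as in spark-shell):

```
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.RFormula

// A numeric (0.0 / 1.0) response column, analogous to NumericSurvived in the R test.
val df = spark.createDataFrame(Seq(
  (0.0, 18.0, 1.0),
  (1.0, 30.0, 0.0),
  (1.0, 25.0, 1.0),
  (0.0, 40.0, 0.0)
)).toDF("label_num", "age", "first_class")

val prepared = new RFormula()
  .setFormula("label_num ~ .")
  .setForceIndexLabel(true)  // index the label even though it is already numeric
  .fit(df)
  .transform(df)             // adds "features" and an indexed "label" column

val model = new NaiveBayes().fit(prepared)  // default label/features columns
```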
--- R/pkg/inst/tests/testthat/test_mllib.R | 10 ++++++++++ .../org/apache/spark/ml/r/NaiveBayesWrapper.scala | 1 + 2 files changed, 11 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a1eaaf20916a2..c99315726a22c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -481,6 +481,16 @@ test_that("spark.naiveBayes", { expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) expect_equal(as.character(predict(m, t1[1, ])), "Yes") } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) }) test_that("spark.survreg", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index d1a39fea76ef8..4fdab2dd94655 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema From 5b77e66dd6a128c5992ab3bde418613f84be7009 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 11 Oct 2016 14:56:26 -0700 Subject: [PATCH 250/306] [SPARK-17387][PYSPARK] Creating SparkContext() from python without spark-submit ignores user conf ## What changes were proposed in this pull request? The root cause that we would ignore SparkConf when launching JVM is that SparkConf require JVM to be created first. https://github.com/apache/spark/blob/master/python/pyspark/conf.py#L106 In this PR, I would defer the launching of JVM until SparkContext is created so that we can pass SparkConf to JVM correctly. ## How was this patch tested? Use the example code in the description of SPARK-17387, ``` $ SPARK_HOME=$PWD PYTHONPATH=python:python/lib/py4j-0.10.3-src.zip python Python 2.7.12 (default, Jul 1 2016, 15:12:24) [GCC 5.4.0 20160609] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> from pyspark import SparkContext >>> from pyspark import SparkConf >>> conf = SparkConf().set("spark.driver.memory", "4g") >>> sc = SparkContext(conf=conf) ``` And verify the spark.driver.memory is correctly picked up. ``` ...op/ -Xmx4g org.apache.spark.deploy.SparkSubmit --conf spark.driver.memory=4g pyspark-shell ``` Author: Jeff Zhang Closes #14959 from zjffdu/SPARK-17387. 
--- python/pyspark/conf.py | 71 +++++++++++++++++++++++++--------- python/pyspark/context.py | 16 ++++++-- python/pyspark/java_gateway.py | 13 ++++++- 3 files changed, 75 insertions(+), 25 deletions(-) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 924da3eecf214..64b6f238e9c32 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -52,6 +52,14 @@ >>> sorted(conf.getAll(), key=lambda p: p[0]) [(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +>>> conf._jconf.setExecutorEnv("VAR5", "value5") +JavaObject id... +>>> print(conf.toDebugString()) +spark.executorEnv.VAR1=value1 +spark.executorEnv.VAR3=value3 +spark.executorEnv.VAR4=value4 +spark.executorEnv.VAR5=value5 +spark.home=/path """ __all__ = ['SparkConf'] @@ -101,13 +109,24 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): self._jconf = _jconf else: from pyspark.context import SparkContext - SparkContext._ensure_initialized() _jvm = _jvm or SparkContext._jvm - self._jconf = _jvm.SparkConf(loadDefaults) + + if _jvm is not None: + # JVM is created, so create self._jconf directly through JVM + self._jconf = _jvm.SparkConf(loadDefaults) + self._conf = None + else: + # JVM is not created, so store data in self._conf first + self._jconf = None + self._conf = {} def set(self, key, value): """Set a configuration property.""" - self._jconf.set(key, unicode(value)) + # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet. + if self._jconf is not None: + self._jconf.set(key, unicode(value)) + else: + self._conf[key] = unicode(value) return self def setIfMissing(self, key, value): @@ -118,17 +137,17 @@ def setIfMissing(self, key, value): def setMaster(self, value): """Set master URL to connect to.""" - self._jconf.setMaster(value) + self.set("spark.master", value) return self def setAppName(self, value): """Set application name.""" - self._jconf.setAppName(value) + self.set("spark.app.name", value) return self def setSparkHome(self, value): """Set path where Spark is installed on worker nodes.""" - self._jconf.setSparkHome(value) + self.set("spark.home", value) return self def setExecutorEnv(self, key=None, value=None, pairs=None): @@ -136,10 +155,10 @@ def setExecutorEnv(self, key=None, value=None, pairs=None): if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") elif key is not None: - self._jconf.setExecutorEnv(key, value) + self.set("spark.executorEnv." + key, value) elif pairs is not None: for (k, v) in pairs: - self._jconf.setExecutorEnv(k, v) + self.set("spark.executorEnv." 
+ k, v) return self def setAll(self, pairs): @@ -149,35 +168,49 @@ def setAll(self, pairs): :param pairs: list of key-value pairs to set """ for (k, v) in pairs: - self._jconf.set(k, v) + self.set(k, v) return self def get(self, key, defaultValue=None): """Get the configured value for some key, or return a default otherwise.""" if defaultValue is None: # Py4J doesn't call the right get() if we pass None - if not self._jconf.contains(key): - return None - return self._jconf.get(key) + if self._jconf is not None: + if not self._jconf.contains(key): + return None + return self._jconf.get(key) + else: + if key not in self._conf: + return None + return self._conf[key] else: - return self._jconf.get(key, defaultValue) + if self._jconf is not None: + return self._jconf.get(key, defaultValue) + else: + return self._conf.get(key, defaultValue) def getAll(self): """Get all values as a list of key-value pairs.""" - pairs = [] - for elem in self._jconf.getAll(): - pairs.append((elem._1(), elem._2())) - return pairs + if self._jconf is not None: + return [(elem._1(), elem._2()) for elem in self._jconf.getAll()] + else: + return self._conf.items() def contains(self, key): """Does this configuration contain a given key?""" - return self._jconf.contains(key) + if self._jconf is not None: + return self._jconf.contains(key) + else: + return key in self._conf def toDebugString(self): """ Returns a printable version of the configuration, as a list of key=value pairs, one per line. """ - return self._jconf.toDebugString() + if self._jconf is not None: + return self._jconf.toDebugString() + else: + return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items()) def _test(): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a3dd1950a522f..1b2e199c395be 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -109,7 +109,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ValueError:... """ self._callsite = first_spark_call() or CallSite(None, None, None) - SparkContext._ensure_initialized(self, gateway=gateway) + SparkContext._ensure_initialized(self, gateway=gateway, conf=conf) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls) @@ -121,7 +121,15 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls): self.environment = environment or {} - self._conf = conf or SparkConf(_jvm=self._jvm) + # java gateway must have been launched at this point. + if conf is not None and conf._jconf is not None: + # conf has been initialized in JVM properly, so use conf directly. This represent the + # scenario that JVM has been launched before SparkConf is created (e.g. SparkContext is + # created and then stopped, and we create a new SparkConf and new SparkContext again) + self._conf = conf + else: + self._conf = SparkConf(_jvm=SparkContext._jvm) + self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer if batchSize == 0: @@ -232,14 +240,14 @@ def _initialize_context(self, jconf): return self._jvm.JavaSparkContext(jconf) @classmethod - def _ensure_initialized(cls, instance=None, gateway=None): + def _ensure_initialized(cls, instance=None, gateway=None, conf=None): """ Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. 
""" with SparkContext._lock: if not SparkContext._gateway: - SparkContext._gateway = gateway or launch_gateway() + SparkContext._gateway = gateway or launch_gateway(conf) SparkContext._jvm = SparkContext._gateway.jvm if instance: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index f76cadcf62438..c1cf843d84388 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -32,7 +32,12 @@ from pyspark.serializers import read_int -def launch_gateway(): +def launch_gateway(conf=None): + """ + launch jvm gateway + :param conf: spark configuration passed to spark-submit + :return: + """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: @@ -41,13 +46,17 @@ def launch_gateway(): # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" + command = [os.path.join(SPARK_HOME, script)] + if conf: + for k, v in conf.getAll(): + command += ['--conf', '%s=%s' % (k, v)] submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", submit_args ]) - command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) + command = command + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) From b9a147181d5e38d9abed0c7215f4c5cb695f579c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Oct 2016 20:27:08 -0700 Subject: [PATCH 251/306] [SPARK-17720][SQL] introduce static SQL conf ## What changes were proposed in this pull request? SQLConf is session-scoped and mutable. However, we do have the requirement for a static SQL conf, which is global and immutable, e.g. the `schemaStringThreshold` in `HiveExternalCatalog`, the flag to enable/disable hive support, the global temp view database in https://github.com/apache/spark/pull/14897. Actually we've already implemented static SQL conf implicitly via `SparkConf`, this PR just make it explicit and expose it to users, so that they can see the config value via SQL command or `SparkSession.conf`, and forbid users to set/unset static SQL conf. ## How was this patch tested? new tests in SQLConfSuite Author: Wenchen Fan Closes #15295 from cloud-fan/global-conf. 
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../spark/internal/config/package.scala | 14 ----- python/pyspark/sql/session.py | 2 +- .../scala/org/apache/spark/repl/Main.scala | 2 +- .../org/apache/spark/repl/ReplSuite.scala | 2 +- .../sql/catalyst/catalog/SessionCatalog.scala | 3 +- .../org/apache/spark/sql/RuntimeConfig.scala | 11 +++- .../org/apache/spark/sql/SparkSession.scala | 8 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 56 +++++++++++++------ .../spark/sql/internal/SharedState.scala | 1 + .../sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/internal/SQLConfSuite.scala | 23 +++++++- .../spark/sql/hive/HiveExternalCatalog.scala | 7 +-- .../org/apache/spark/sql/hive/HiveUtils.scala | 3 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 45 ++++++++------- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- 18 files changed, 111 insertions(+), 78 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6d8cfad5c1f93..61554248ee8f8 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2609,7 +2609,7 @@ test_that("enableHiveSupport on SparkSession", { unsetHiveContext() # if we are still here, it must be built with hive conf <- callJMethod(sparkSession, "conf") - value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "") + value <- callJMethod(conf, "get", "spark.sql.catalogImplementation") expect_equal(value, "hive") }) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0896e68eca7dc..5a710158db89f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -91,20 +91,6 @@ package object config { .toSequence .createWithDefault(Nil) - // Note: This is a SQL config but needs to be in core because the REPL depends on it - private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation") - .internal() - .stringConf - .checkValues(Set("hive", "in-memory")) - .createWithDefault("in-memory") - - // Note: This is a SQL config but needs to be in core because it's cross-session and can not put - // in SQLConf. 
- private[spark] val GLOBAL_TEMP_DATABASE = ConfigBuilder("spark.sql.globalTempDatabase") - .internal() - .stringConf - .createWithDefault("global_temp") - private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") .intConf diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 8418abf99c8d5..1e40b9c39fc4f 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -176,7 +176,7 @@ def getOrCreate(self): sc._conf.set(key, value) session = SparkSession(sc) for key, value in self._options.items(): - session.conf.set(key, value) + session._jsparkSession.sessionState().conf().setConfString(key, value) for key, value in self._options.items(): session.sparkContext._conf.set(key, value) return session diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 5dfe18ad49822..fec4d49379591 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -22,9 +22,9 @@ import java.io.File import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils object Main extends Logging { diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f7d7a4f041315..9262e938c2a60 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.StringEscapeUtils import org.apache.log4j.{Level, LogManager} import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.internal.config._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.Utils class ReplSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 5863c6a71cdf9..fe41c41a6eb20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -25,7 +25,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.GLOBAL_TEMP_DATABASE import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -63,7 +62,7 @@ class SessionCatalog( conf: CatalystConf) { this( externalCatalog, - new GlobalTempViewManager(GLOBAL_TEMP_DATABASE.defaultValueString), + new GlobalTempViewManager("global_temp"), DummyFunctionResourceLoader, functionRegistry, conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index c2baa74ed7d2e..9108d19d0a0c2 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} /** @@ -38,6 +38,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: String): Unit = { + requireNonStaticConf(key) sqlConf.setConfString(key, value) } @@ -47,6 +48,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: Boolean): Unit = { + requireNonStaticConf(key) set(key, value.toString) } @@ -56,6 +58,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def set(key: String, value: Long): Unit = { + requireNonStaticConf(key) set(key, value.toString) } @@ -124,6 +127,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * @since 2.0.0 */ def unset(key: String): Unit = { + requireNonStaticConf(key) sqlConf.unsetConf(key) } @@ -134,4 +138,9 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.contains(key) } + private def requireNonStaticConf(key: String): Unit = { + if (StaticSQLConf.globalConfKeys.contains(key)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $key") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d26eea507284c..137c426b4b88d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog @@ -41,6 +40,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, LongType, StructType} @@ -812,7 +812,7 @@ object SparkSession { // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { logWarning("Use an existing SparkSession, some configuration may not take effect.") } @@ -824,7 +824,7 @@ object SparkSession { // If the current thread does not have an active session, get it from the global session. 
session = defaultSession.get() if ((session ne null) && !session.sparkContext.isStopped) { - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } if (options.nonEmpty) { logWarning("Use an existing SparkSession, some configuration may not take effect.") } @@ -850,7 +850,7 @@ object SparkSession { sc } session = new SparkSession(sparkContext) - options.foreach { case (k, v) => session.conf.set(k, v) } + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 7d8ea03a27910..9de6510c634b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -28,11 +28,11 @@ import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.command.ShowTablesCommand +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ private[sql] object SQLUtils extends Logging { @@ -64,7 +64,7 @@ private[sql] object SQLUtils extends Logging { spark: SparkSession, sparkConfigMap: JMap[Object, Object]): Unit = { for ((name, value) <- sparkConfigMap.asScala) { - spark.conf.set(name.toString, value.toString) + spark.sessionState.conf.setConfString(name.toString, value.toString) } for ((name, value) <- sparkConfigMap.asScala) { spark.sparkContext.conf.set(name.toString, value.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fecdf792fd14a..8cbfc4c7628f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -41,7 +41,7 @@ object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( new java.util.HashMap[String, ConfigEntry[_]]()) - private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { + private[sql] def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { require(!sqlConfEntries.containsKey(entry.key), s"Duplicate SQLConfigEntry. ${entry.key} has been registered") sqlConfEntries.put(entry.key, entry) @@ -326,18 +326,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - // This is used to control the when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters). We will split the JSON string of a schema - // to its length exceeds the threshold. 
- val SCHEMA_STRING_LENGTH_THRESHOLD = - SQLConfigBuilder("spark.sql.sources.schemaStringLengthThreshold") - .doc("The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.") - .internal() - .intConf - .createWithDefault(4000) - val PARTITION_COLUMN_TYPE_INFERENCE = SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled") .doc("When true, automatically infer the data types for partitioned columns.") @@ -736,10 +724,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) - // Do not use a value larger than 4000 as the default value of this property. - // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = @@ -886,3 +870,41 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { } } +/** + * Static SQL configuration is a cross-session, immutable Spark configuration. External users can + * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. + */ +object StaticSQLConf { + val globalConfKeys = java.util.Collections.synchronizedSet(new java.util.HashSet[String]()) + + private def buildConf(key: String): ConfigBuilder = { + ConfigBuilder(key).onCreate { entry => + globalConfKeys.add(entry.key) + SQLConf.register(entry) + } + } + + val CATALOG_IMPLEMENTATION = buildConf("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") + + val GLOBAL_TEMP_DATABASE = buildConf("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + + // This is used to control when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default + // value of this property). We will split the JSON string of a schema to its length exceeds the + // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, + // that's why this conf has to be a static SQL conf. 
+ val SCHEMA_STRING_LENGTH_THRESHOLD = buildConf("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index c555a43cd2581..c6083b372a2db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, GlobalTempViewManager, InMemoryCatalog} import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.util.{MutableURLClassLoader, Utils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 19885156cc722..097dc2441351f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} @@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 3c60b233c2b04..f545de0e10a6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -19,11 +19,14 @@ package org.apache.spark.sql.internal import org.apache.hadoop.fs.Path -import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -250,4 +253,22 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } } } + + test("global SQL conf comes from SparkConf") { + val newSession = SparkSession.builder() + .config(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000") + .getOrCreate() + + 
assert(newSession.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD.key) == "2000") + checkAnswer( + newSession.sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}"), + Row(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000")) + } + + test("cannot set/unset global SQL conf") { + val e1 = intercept[AnalysisException](sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}=10")) + assert(e1.message.contains("Cannot modify the value of a static config")) + val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key)) + assert(e2.message.contains("Cannot modify the value of a static config")) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 261cc6feff090..e1c0cad907b98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD import org.apache.spark.sql.types.{DataType, StructType} @@ -201,11 +202,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Serialized JSON schema string may be too long to be stored into a single metastore table // property. In this case, we split the JSON string and store each part as a separate table // property. - // TODO: the threshold should be set by `spark.sql.sources.schemaStringLengthThreshold`, - // however the current SQLConf is session isolated, which is not applicable to external - // catalog. We should re-enable this conf instead of hard code the value here, after we have - // global SQLConf. - val threshold = 4000 + val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) val schemaJsonString = tableDefinition.schema.json // Split the JSON string. 
val parts = schemaJsonString.grouped(threshold).toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 39d71e164bf51..a5ef8723c8b6f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions @@ -36,11 +35,11 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 163f210802b53..6eb571b91ffab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -40,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.internal.{SharedState, SQLConf} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 51670649ad1d4..0477122fc6a27 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -699,28 +699,27 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("SPARK-6024 wide schema support") { - withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { - withTable("wide_schema") { - withTempDir { tempDir => - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType))) - - val tableDesc = CatalogTable( - identifier = TableIdentifier("wide_schema"), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> tempDir.getCanonicalPath) - ), - schema = schema, - provider = Some("json") - ) - spark.sessionState.catalog.createTable(tableDesc, ignoreIfExists = false) - - sessionState.refreshTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) - } + assert(spark.sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) == 4000) + withTable("wide_schema") { + withTempDir { tempDir => + // We will need 80 splits for this schema if the threshold is 4000. 
+ val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType))) + + val tableDesc = CatalogTable( + identifier = TableIdentifier("wide_schema"), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy( + properties = Map("path" -> tempDir.getCanonicalPath) + ), + schema = schema, + provider = Some("json") + ) + spark.sessionState.catalog.createTable(tableDesc, ignoreIfExists = false) + + sessionState.refreshTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 751e976c7b908..8bff6de008fdb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} @@ -32,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils class HiveDDLSuite From 299eb04ba05038c7dbb3ecf74a35d4bbfa456643 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 11 Oct 2016 22:31:21 -0700 Subject: [PATCH 252/306] Fix hadoop.version in building-spark.md Couple of mvn build examples use `-Dhadoop.version=VERSION` instead of actual version number Author: Alexander Pivovarov Closes #15440 from apivovarov/patch-1. --- docs/building-spark.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index da7eeb8348378..f5acee6b90059 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -91,13 +91,13 @@ Examples: ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package # Apache Hadoop 2.4.X or 2.5.X - ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package # Apache Hadoop 2.6.X ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=VERSION -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -DskipTests clean package # Different versions of HDFS and YARN. ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package From b512f04f8e546843d5a3f35dcc6b675b5f4f5bc0 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 11 Oct 2016 22:36:57 -0700 Subject: [PATCH 253/306] [SPARK-17880][DOC] The url linking to `AccumulatorV2` in the document is incorrect. ## What changes were proposed in this pull request? 
In `programming-guide.md`, the url which links to `AccumulatorV2` says `api/scala/index.html#org.apache.spark.AccumulatorV2` but `api/scala/index.html#org.apache.spark.util.AccumulatorV2` is correct. ## How was this patch tested? manual test. Author: Kousuke Saruta Closes #15439 from sarutak/SPARK-17880. --- docs/programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 74d5ee1ca6b3f..20b4bee0f58e1 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1373,7 +1373,7 @@ res2: Long = 10 {% endhighlight %} While this code used the built-in support for accumulators of type Long, programmers can also -create their own types by subclassing [AccumulatorV2](api/scala/index.html#org.apache.spark.AccumulatorV2). +create their own types by subclassing [AccumulatorV2](api/scala/index.html#org.apache.spark.util.AccumulatorV2). The AccumulatorV2 abstract class has several methods which need to override: `reset` for resetting the accumulator to zero, and `add` for add anothor value into the accumulator, `merge` for merging another same-type accumulator into this one. Other methods need to override can refer to scala API document. For example, supposing we had a `MyVector` class representing mathematical vectors, we could write: From c264ef9b1918256a5018c7a42a1a2b42308ea3f7 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 12 Oct 2016 00:40:47 -0700 Subject: [PATCH 254/306] [SPARK-17853][STREAMING][KAFKA][DOC] make it clear that reusing group.id is bad ## What changes were proposed in this pull request? Documentation fix to make it clear that reusing group id for different streams is super duper bad, just like it is with the underlying Kafka consumer. ## How was this patch tested? I built jekyll doc and made sure it looked ok. Author: cody koeninger Closes #15442 from koeninger/SPARK-17853. --- docs/streaming-kafka-0-10-integration.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index 44c39e39446de..456b8453383db 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -27,7 +27,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea "bootstrap.servers" -> "localhost:9092,anotherhost:9092", "key.deserializer" -> classOf[StringDeserializer], "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "example", + "group.id" -> "use_a_separate_group_id_for_each_stream", "auto.offset.reset" -> "latest", "enable.auto.commit" -> (false: java.lang.Boolean) ) @@ -48,7 +48,7 @@ Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javad For possible kafkaParams, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). -Note that enable.auto.commit is disabled, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. +Note that the example sets enable.auto.commit to false, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. ### LocationStrategies The new Kafka consumer API will pre-fetch messages into buffers. Therefore it is important for performance reasons that the Spark integration keep cached consumers on executors (rather than recreating them for each batch), and prefer to schedule partitions on the host locations that have the appropriate consumers. 
@@ -57,6 +57,9 @@ In most cases, you should use `LocationStrategies.PreferConsistent` as shown abo The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` +The cache is keyed by topicpartition and group.id, so use a **separate** `group.id` for each call to `createDirectStream`. + + ### ConsumerStrategies The new Kafka consumer API has a number of different ways to specify topics, some of which require considerable post-object-instantiation setup. `ConsumerStrategies` provides an abstraction that allows Spark to obtain properly configured consumers even after restart from checkpoint. From 8d33e1e5bfde6d2d1270058e49772427383312b3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 12 Oct 2016 10:00:53 +0100 Subject: [PATCH 255/306] [SPARK-11560][MLLIB] Optimize KMeans implementation / remove 'runs' ## What changes were proposed in this pull request? This is a revival of https://github.com/apache/spark/pull/14948 and related to https://github.com/apache/spark/pull/14937. This removes the 'runs' parameter, which has already been disabled, from the K-means implementation and further deprecates API methods that involve it. This also happens to resolve the issue that K-means should not return duplicate centers, meaning that it may return less than k centroids if not enough data is available. ## How was this patch tested? Existing tests Author: Sean Owen Closes #15342 from srowen/SPARK-11560. --- .../spark/mllib/clustering/KMeans.scala | 296 ++++++++---------- 1 file changed, 132 insertions(+), 164 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 23141aaf42b49..68a7b3b6763af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -43,18 +43,17 @@ import org.apache.spark.util.random.XORShiftRandom class KMeans private ( private var k: Int, private var maxIterations: Int, - private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, private var seed: Long) extends Serializable with Logging { /** - * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, + * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. */ @Since("0.8.0") - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). @@ -112,15 +111,17 @@ class KMeans private ( * This function has no effect since Spark 2.0.0. */ @Since("1.4.0") + @deprecated("This has no effect and always returns 1", "2.1.0") def getRuns: Int = { logWarning("Getting number of runs has no effect since Spark 2.0.0.") - runs + 1 } /** * This function has no effect since Spark 2.0.0. 
*/ @Since("0.8.0") + @deprecated("This has no effect", "2.1.0") def setRuns(runs: Int): this.type = { logWarning("Setting number of runs has no effect since Spark 2.0.0.") this @@ -239,17 +240,9 @@ class KMeans private ( val initStartTime = System.nanoTime() - // Only one run is allowed when initialModel is given - val numRuns = if (initialModel.nonEmpty) { - if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") - 1 - } else { - runs - } - val centers = initialModel match { case Some(kMeansCenters) => - Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) @@ -258,89 +251,62 @@ class KMeans private ( } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 - logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + - " seconds.") + logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.") - val active = Array.fill(numRuns)(true) - val costs = Array.fill(numRuns)(0.0) - - var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) + var converged = false + var cost = 0.0 var iteration = 0 val iterationStartTime = System.nanoTime() - instr.foreach(_.logNumFeatures(centers(0)(0).vector.size)) + instr.foreach(_.logNumFeatures(centers.head.vector.size)) - // Execute iterations of Lloyd's algorithm until all runs have converged - while (iteration < maxIterations && !activeRuns.isEmpty) { - type WeightedPoint = (Vector, Long) - def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = { - axpy(1.0, x._1, y._1) - (y._1, x._2 + y._2) - } - - val activeCenters = activeRuns.map(r => centers(r)).toArray - val costAccums = activeRuns.map(_ => sc.doubleAccumulator) - - val bcActiveCenters = sc.broadcast(activeCenters) + // Execute iterations of Lloyd's algorithm until converged + while (iteration < maxIterations && !converged) { + val costAccum = sc.doubleAccumulator + val bcCenters = sc.broadcast(centers) // Find the sum and count of points mapping to each center val totalContribs = data.mapPartitions { points => - val thisActiveCenters = bcActiveCenters.value - val runs = thisActiveCenters.length - val k = thisActiveCenters(0).length - val dims = thisActiveCenters(0)(0).vector.size + val thisCenters = bcCenters.value + val dims = thisCenters.head.vector.size - val sums = Array.fill(runs, k)(Vectors.zeros(dims)) - val counts = Array.fill(runs, k)(0L) + val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) + val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - (0 until runs).foreach { i => - val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point) - costAccums(i).add(cost) - val sum = sums(i)(bestCenter) - axpy(1.0, point.vector, sum) - counts(i)(bestCenter) += 1 - } + val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + costAccum.add(cost) + val sum = sums(bestCenter) + axpy(1.0, point.vector, sum) + counts(bestCenter) += 1 } - val contribs = for (i <- 0 until runs; j <- 0 until k) yield { - ((i, j), (sums(i)(j), counts(i)(j))) - } - contribs.iterator - }.reduceByKey(mergeContribs).collectAsMap() - - bcActiveCenters.destroy(blocking = false) - - // Update the cluster centers and costs for each active run - for ((run, i) <- activeRuns.zipWithIndex) { - var changed = false - var j = 0 - while (j < k) { - val (sum, count) = totalContribs((i, j)) - if (count != 0) { - scal(1.0 / 
count, sum) - val newCenter = new VectorWithNorm(sum) - if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { - changed = true - } - centers(run)(j) = newCenter - } - j += 1 - } - if (!changed) { - active(run) = false - logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") + counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator + }.reduceByKey { case ((sum1, count1), (sum2, count2)) => + axpy(1.0, sum2, sum1) + (sum1, count1 + count2) + }.collectAsMap() + + bcCenters.destroy(blocking = false) + + // Update the cluster centers and costs + converged = true + totalContribs.foreach { case (j, (sum, count)) => + scal(1.0 / count, sum) + val newCenter = new VectorWithNorm(sum) + if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + converged = false } - costs(run) = costAccums(i).value + centers(j) = newCenter } - activeRuns = activeRuns.filter(active(_)) + cost = costAccum.value iteration += 1 } val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9 - logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.") + logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.") if (iteration == maxIterations) { logInfo(s"KMeans reached the max number of iterations: $maxIterations.") @@ -348,59 +314,43 @@ class KMeans private ( logInfo(s"KMeans converged in $iteration iterations.") } - val (minCost, bestRun) = costs.zipWithIndex.min + logInfo(s"The cost is $cost.") - logInfo(s"The cost for the best run is $minCost.") - - new KMeansModel(centers(bestRun).map(_.vector)) + new KMeansModel(centers.map(_.vector)) } /** - * Initialize `runs` sets of cluster centers at random. + * Initialize a set of cluster centers at random. */ - private def initRandom(data: RDD[VectorWithNorm]) - : Array[Array[VectorWithNorm]] = { - // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq - Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => - new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) - }.toArray) + private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense) } /** - * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. + * Initialize a set of cluster centers using the k-means|| algorithm by Bahmani et al. * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries - * to find with dissimilar cluster centers by starting with a random center and then doing + * to find dissimilar cluster centers by starting with a random center and then doing * passes where more centers are chosen with probability proportional to their squared distance * to the current cluster set. It results in a provable approximation to an optimal clustering. * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private def initKMeansParallel(data: RDD[VectorWithNorm]) - : Array[Array[VectorWithNorm]] = { + private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Initialize empty centers and point costs. 
- val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) - var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity)) + var costs = data.map(_ => Double.PositiveInfinity) - // Initialize each run's first center to a random point. + // Initialize the first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() - val sample = data.takeSample(true, runs, seed).toSeq + val sample = data.takeSample(false, 1, seed) // Could be empty if data is empty; fail with a better message early: - require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data") - val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) - - /** Merges new centers to centers. */ - def mergeNewCenters(): Unit = { - var r = 0 - while (r < runs) { - centers(r) ++= newCenters(r) - newCenters(r).clear() - r += 1 - } - } + require(sample.nonEmpty, s"No samples available from $data") + + val centers = ArrayBuffer[VectorWithNorm]() + var newCenters = Seq(sample.head.toDense) + centers ++= newCenters - // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's centers. Note that only distances between points + // On each step, sample 2 * k points on average with probability proportional + // to their squared distance from the centers. Note that only distances between points // and new centers are computed in each iteration. var step = 0 var bcNewCentersList = ArrayBuffer[Broadcast[_]]() @@ -409,74 +359,39 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - Array.tabulate(runs) { r => - math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) - } - }.persist(StorageLevel.MEMORY_AND_DISK) - val sumCosts = costs - .aggregate(new Array[Double](runs))( - seqOp = (s, v) => { - // s += v - var r = 0 - while (r < runs) { - s(r) += v(r) - r += 1 - } - s - }, - combOp = (s0, s1) => { - // s0 += s1 - var r = 0 - while (r < runs) { - s0(r) += s1(r) - r += 1 - } - s0 - } - ) + math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + }.persist(StorageLevel.MEMORY_AND_DISK) + val sumCosts = costs.sum() bcNewCenters.unpersist(blocking = false) preCosts.unpersist(blocking = false) - val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - pointsWithCosts.flatMap { case (p, c) => - val rs = (0 until runs).filter { r => - rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - } - if (rs.nonEmpty) Some((p, rs)) else None - } + pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1) }.collect() - mergeNewCenters() - chosen.foreach { case (p, rs) => - rs.foreach(newCenters(_) += p.toDense) - } + newCenters = chosen.map(_.toDense) + centers ++= newCenters step += 1 } - mergeNewCenters() costs.unpersist(blocking = false) bcNewCentersList.foreach(_.destroy(false)) - // Finally, we might have a set of more than k candidate centers for each run; weigh each - // candidate by the number of points in the dataset mapping to it and run a local k-means++ - // on the weighted centers to pick just k of them - val bcCenters = data.context.broadcast(centers) - val weightMap = data.flatMap { p => - Iterator.tabulate(runs) { r => - ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) - } - }.reduceByKey(_ + 
_).collectAsMap() + if (centers.size == k) { + centers.toArray + } else { + // Finally, we might have a set of more or less than k candidate centers; weight each + // candidate by the number of points in the dataset mapping to it and run a local k-means++ + // on the weighted centers to pick k of them + val bcCenters = data.context.broadcast(centers) + val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() - bcCenters.destroy(blocking = false) + bcCenters.destroy(blocking = false) - val finalCenters = (0 until runs).par.map { r => - val myCenters = centers(r).toArray - val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray - LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) + val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray + LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30) } - - finalCenters.toArray } } @@ -493,6 +408,52 @@ object KMeans { @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" + /** + * Trains a k-means model using the given set of parameters. + * + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. + */ + @Since("2.1.0") + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. + * + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + */ + @Since("2.1.0") + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + initializationMode: String): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initializationMode) + .run(data) + } + /** * Trains a k-means model using the given set of parameters. * @@ -506,6 +467,7 @@ object KMeans { * on system time. */ @Since("1.3.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, @@ -531,6 +493,7 @@ object KMeans { * "k-means||". (default: "k-means||") */ @Since("0.8.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, @@ -551,19 +514,24 @@ object KMeans { data: RDD[Vector], k: Int, maxIterations: Int): KMeansModel = { - train(data, k, maxIterations, 1, K_MEANS_PARALLEL) + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .run(data) } /** * Trains a k-means model using specified parameters and the default values for unspecified. 
*/ @Since("0.8.0") + @deprecated("Use train method without 'runs'", "2.1.0") def train( data: RDD[Vector], k: Int, maxIterations: Int, runs: Int): KMeansModel = { - train(data, k, maxIterations, runs, K_MEANS_PARALLEL) + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .run(data) } /** From 8880fd13ef2b581f9c7190e7e3e6d24bc11b4ef7 Mon Sep 17 00:00:00 2001 From: Bijay Pathak Date: Wed, 12 Oct 2016 10:09:49 -0700 Subject: [PATCH 256/306] [SPARK-14761][SQL] Reject invalid join methods when join columns are not specified in PySpark DataFrame join. ## What changes were proposed in this pull request? In PySpark, the invalid join type will not throw error for the following join: ```df1.join(df2, how='not-a-valid-join-type')``` The signature of the join is: ```def join(self, other, on=None, how=None):``` The existing code completely ignores the `how` parameter when `on` is `None`. This patch will process the arguments passed to join and pass in to JVM Spark SQL Analyzer, which will validate the join type passed. ## How was this patch tested? Used manual and existing test suites. Author: Bijay Pathak Closes #15409 from bkpathak/SPARK-14761. --- python/pyspark/sql/dataframe.py | 31 +++++++++++++++---------------- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 14e80ea4615ef..ce277eb204d13 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -661,25 +661,24 @@ def join(self, other, on=None, how=None): if on is not None and not isinstance(on, list): on = [on] - if on is None or len(on) == 0: - jdf = self._jdf.crossJoin(other._jdf) - elif isinstance(on[0], basestring): - if how is None: - jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + if on is not None: + if isinstance(on[0], basestring): + on = self._jseq(on) else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, self._jseq(on), how) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + on = on._jc + + if on is None and how is None: + jdf = self._jdf.crossJoin(other._jdf) else: - assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] if how is None: - jdf = self._jdf.join(other._jdf, on._jc, "inner") - else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, on._jc, how) + how = "inner" + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) @since(1.6) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 86c590dae34d7..61674a8a7ed65 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1508,6 +1508,12 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.schema.simpleString(), "struct") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + # Regression test for invalid join methods when on is None, Spark-14761 + def test_invalid_join_method(self): + df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) + df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) + self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + def 
test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") From d5580ebaa086b9feb72d5428f24c5b60cd7da745 Mon Sep 17 00:00:00 2001 From: prigarg Date: Wed, 12 Oct 2016 10:14:45 -0700 Subject: [PATCH 257/306] [SPARK-17884][SQL] To resolve Null pointer exception when casting from empty string to interval type. ## What changes were proposed in this pull request? This change adds a check in castToInterval method of Cast expression , such that if converted value is null , then isNull variable should be set to true. Earlier, the expression Cast(Literal(), CalendarIntervalType) was throwing NullPointerException because of the above mentioned reason. ## How was this patch tested? Added test case in CastSuite.scala jira entry for detail: https://issues.apache.org/jira/browse/SPARK-17884 Author: prigarg Closes #15449 from priyankagargnitk/SPARK-17884. --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 7 ++++++- .../apache/spark/sql/catalyst/expressions/CastSuite.scala | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1314c416510dc..58fd65f62ffe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -657,7 +657,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = CalendarInterval.fromString($c.toString());" + s"""$evPrim = CalendarInterval.fromString($c.toString()); + if(${evPrim} == null) { + ${evNull} = true; + } + """.stripMargin + } private[this] def decimalToTimestampCode(d: String): String = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5c35baacef2fa..b748595fc4f2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -767,6 +767,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval + checkEvaluation(Cast(Literal(""), CalendarIntervalType), null) checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( From 5cc503f4fe9737a4c7947a80eecac053780606df Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 12 Oct 2016 10:32:38 -0700 Subject: [PATCH 258/306] [SPARK-17790][SPARKR] Support for parallelizing R data.frame larger than 2GB ## What changes were proposed in this pull request? If the R data structure that is being parallelized is larger than `INT_MAX` we use files to transfer data to JVM. The serialization protocol mimics Python pickling. This allows us to simply call `PythonRDD.readRDDFromFile` to create the RDD. I tested this on my MacBook. Following code works with this patch: ```R intMax <- .Machine$integer.max largeVec <- 1:intMax rdd <- SparkR:::parallelize(sc, largeVec, 2) ``` ## How was this patch tested? 
* [x] Unit tests Author: Hossein Closes #15375 from falaki/SPARK-17790. --- R/pkg/R/context.R | 45 ++++++++++++++++++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 11 +++++ .../apache/spark/api/r/RBackendHandler.scala | 2 +- .../scala/org/apache/spark/api/r/RRDD.scala | 13 ++++++ 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index fe2f3e3d10a9b..438d77a388f0e 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,6 +87,10 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' will write it to disk and send the file name to JVM. Also to make sure each slice is not +#' larger than that limit, number of slices may be increased. +#' #' @param sc SparkContext to use #' @param coll collection to parallelize #' @param numSlices number of partitions to create in the RDD @@ -120,6 +124,11 @@ parallelize <- function(sc, coll, numSlices = 1) { coll <- as.list(coll) } + sizeLimit <- getMaxAllocationLimit(sc) + objectSize <- object.size(coll) + + # For large objects we make sure the size of each slice is also smaller than sizeLimit + numSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) if (numSlices > length(coll)) numSlices <- length(coll) @@ -130,12 +139,44 @@ parallelize <- function(sc, coll, numSlices = 1) { # 2-tuples of raws serializedSlices <- lapply(slices, serialize, connection = NULL) - jrdd <- callJStatic("org.apache.spark.api.r.RRDD", - "createRDDFromArray", sc, serializedSlices) + # The PRC backend cannot handle arguments larger than 2GB (INT_MAX) + # If serialized data is safely less than that threshold we send it over the PRC channel. 
+ # Otherwise, we write it to a file and send the file name + if (objectSize < sizeLimit) { + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices) + } else { + fileName <- writeToTempFile(serializedSlices) + jrdd <- tryCatch(callJStatic( + "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)), + finally = { + file.remove(fileName) + }) + } RDD(jrdd, "byte") } +getMaxAllocationLimit <- function(sc) { + conf <- callJMethod(sc, "getConf") + as.numeric( + callJMethod(conf, + "get", + "spark.r.maxAllocationLimit", + toString(.Machine$integer.max / 10) # Default to a safe value: 200MB + )) +} + +writeToTempFile <- function(serializedSlices) { + fileName <- tempfile() + conn <- file(fileName, "wb") + for (slice in serializedSlices) { + writeBin(as.integer(length(slice)), conn, endian = "big") + writeBin(slice, conn, endian = "big") + } + close(conn) + fileName +} + #' Include this specified package on all workers #' #' This function can be used to include a package on all workers before the diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 61554248ee8f8..af81d0586e0a6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -208,6 +208,17 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) +test_that("createDataFrame uses files for large objects", { + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value + conf <- callJMethod(sparkSession, "conf") + callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") + df <- createDataFrame(iris) + + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) + expect_equal(dim(df), dim(iris)) +}) + test_that("read/write csv as DataFrame", { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 7d5348266bf6e..1422ef888fd4a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -168,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed") + logError(s"$methodName on $objId failed", e) writeInt(dos, -1) // Writing the error message of the cause for the exception. This will be returned // to user in the R process. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 59c8429c80172..a1a5eb8cf55e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.PythonRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -140,4 +141,16 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } + + /** + * Create an RRDD given a temporary file name. 
This is used to create RRDD when parallelize is + * called on large R objects. + * + * @param fileName name of temporary file on driver machine + * @param parallelism number of slices defaults to 4 + */ + def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int): + JavaRDD[Array[Byte]] = { + PythonRDD.readRDDFromFile(jsc, fileName, parallelism) + } } From f8062b63fc5e07a6bf4c153a74a966602865fa6e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 12 Oct 2016 11:14:03 -0700 Subject: [PATCH 259/306] [SPARK-17840][DOCS] Add some pointers for wiki/CONTRIBUTING.md in README.md and some warnings in PULL_REQUEST_TEMPLATE ## What changes were proposed in this pull request? Link to contributing wiki in PR template, README.md ## How was this patch tested? Doc-only change, tested by Jekyll Author: Sean Owen Closes #15429 from srowen/SPARK-17840. --- .github/PULL_REQUEST_TEMPLATE | 4 +--- README.md | 5 +++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 989e95ccd0135..0e41cf1826453 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -2,11 +2,9 @@ (Please fill in changes proposed in this fix) - ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) - - (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) +Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. diff --git a/README.md b/README.md index c77c429e577cd..dd7d0e22495b3 100644 --- a/README.md +++ b/README.md @@ -97,3 +97,8 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. + +## Contributing + +Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) +wiki for information on how to get started contributing to the project. From eb69335cdbce54f943ae6168aed39687f40e53ed Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 12 Oct 2016 11:59:01 -0700 Subject: [PATCH 260/306] [BUILD] Closing stale PRs Closes #15303 Closes #15078 Closes #15080 Closes #15135 Closes #14565 Closes #12355 Closes #15404 Author: Sean Owen Closes #15451 from srowen/CloseStalePRs. From 47776e7c0c68590fe446cef910900b1aaead06f9 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 12 Oct 2016 13:51:53 -0700 Subject: [PATCH 261/306] [SPARK-17850][CORE] Add a flag to ignore corrupt files ## What changes were proposed in this pull request? Add a flag to ignore corrupt files. For Spark core, the configuration is `spark.files.ignoreCorruptFiles`. For Spark SQL, it's `spark.sql.files.ignoreCorruptFiles`. ## How was this patch tested? The added unit tests Author: Shixiong Zhu Closes #15422 from zsxwing/SPARK-17850. 
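
For illustration only (not part of the patch): a minimal sketch of how the two new flags could be enabled, assuming an otherwise default `SparkSession`; the application name and input path below are placeholders.

```scala
// Usage sketch -- assumes a local SparkSession; the input path is a placeholder.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("ignore-corrupt-files-sketch")
  .config("spark.files.ignoreCorruptFiles", "true")      // core reads (HadoopRDD / NewHadoopRDD)
  .config("spark.sql.files.ignoreCorruptFiles", "true")  // SQL file sources (FileScanRDD)
  .getOrCreate()

// With the flags enabled, a truncated .gz file no longer fails the job; the
// rows that could be read before hitting the corruption are still returned.
val lineCount = spark.read.text("/tmp/input/*.gz").count()
```

Both options default to `false`, so jobs keep the existing fail-fast behavior unless they opt in.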
--- .../spark/internal/config/package.scala | 5 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 8 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 10 ++- .../scala/org/apache/spark/FileSuite.scala | 62 ++++++++++++++++++- .../execution/datasources/FileScanRDD.scala | 30 ++++++++- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../datasources/FileSourceStrategySuite.scala | 37 ++++++++++- 7 files changed, 153 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5a710158db89f..517fc3e9e9c77 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -156,4 +156,9 @@ package object config { .doc("Port to use for the block managed on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) + private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupt files and contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4640b5dc2f654..e1cf3938de098 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.EOFException +import java.io.IOException import java.text.SimpleDateFormat import java.util.Date @@ -43,6 +43,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -139,6 +140,8 @@ class HadoopRDD[K, V]( private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -253,8 +256,7 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { - case eof: EOFException => - finished = true + case e: IOException if ignoreCorruptFiles => finished = true } if (!finished) { inputMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 1c7aec919bdc4..baf31fb658870 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import java.io.IOException import java.text.SimpleDateFormat import java.util.Date @@ -33,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -85,6 +87,8 @@ class NewHadoopRDD[K, V]( private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -179,7 +183,11 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { - finished = !reader.nextKeyValue + try { + finished = !reader.nextKeyValue + } catch { + case e: IOException if ignoreCorruptFiles => finished = true + } if (finished) { // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 993834f8d7d42..cc52bb1d23cd5 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark -import java.io.{File, FileWriter} +import java.io._ +import java.util.zip.GZIPOutputStream import scala.io.Source @@ -29,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.input.PortableDataStream +import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -541,4 +543,62 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { }.collect() assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) } + + test("spark.files.ignoreCorruptFiles should work both HadoopRDD and NewHadoopRDD") { + val inputFile = File.createTempFile("input-", ".gz") + try { + // Create a corrupt gzip file + val byteOutput = new ByteArrayOutputStream() + val gzip = new GZIPOutputStream(byteOutput) + try { + gzip.write(Array[Byte](1, 2, 3, 4)) + } finally { + gzip.close() + } + val bytes = byteOutput.toByteArray + val o = new FileOutputStream(inputFile) + try { + // It's corrupt since we only write half of bytes into the 
file. + o.write(bytes.take(bytes.length / 2)) + } finally { + o.close() + } + + // Reading a corrupt gzip file should throw EOFException + sc = new SparkContext("local", "test") + // Test HadoopRDD + var e = intercept[SparkException] { + sc.textFile(inputFile.toURI.toString).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + // Test NewHadoopRDD + e = intercept[SparkException] { + sc.newAPIHadoopFile( + inputFile.toURI.toString, + classOf[NewTextInputFormat], + classOf[LongWritable], + classOf[Text]).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + sc.stop() + + val conf = new SparkConf().set(IGNORE_CORRUPT_FILES, true) + sc = new SparkContext("local", "test", conf) + // Test HadoopRDD + assert(sc.textFile(inputFile.toURI.toString).collect().isEmpty) + // Test NewHadoopRDD + assert { + sc.newAPIHadoopFile( + inputFile.toURI.toString, + classOf[NewTextInputFormat], + classOf[LongWritable], + classOf[Text]).collect().isEmpty + } + } finally { + inputFile.delete() + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index c66da3a83198d..89944570df662 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.IOException + import scala.collection.mutable import org.apache.spark.{Partition => RDDPartition, TaskContext} @@ -25,6 +27,7 @@ import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.util.NextIterator /** * A part (i.e. 
"block") of a single file that should be read, along with partition column values @@ -62,6 +65,8 @@ class FileScanRDD( @transient val filePartitions: Seq[FilePartition]) extends RDD[InternalRow](sparkSession.sparkContext, Nil) { + private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { private val inputMetrics = context.taskMetrics().inputMetrics @@ -119,7 +124,30 @@ class FileScanRDD( InputFileNameHolder.setInputFileName(currentFile.filePath) try { - currentIterator = readFunction(currentFile) + if (ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + private val internalIter = readFunction(currentFile) + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null + } + } catch { + case e: IOException => + finished = true + null + } + } + + override def close(): Unit = {} + } + } else { + currentIterator = readFunction(currentFile) + } } catch { case e: java.io.FileNotFoundException => throw new java.io.FileNotFoundException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8cbfc4c7628f7..9e7c1ec211893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -576,6 +576,12 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupt files and contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -743,6 +749,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def warehousePath: String = new Path(getConf(WAREHOUSE_PATH)).toString + def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 45411fa0656cd..c5deb31fec183 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.datasources -import java.io.File +import java.io._ import java.util.concurrent.atomic.AtomicInteger +import java.util.zip.GZIPOutputStream import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} @@ -441,6 +442,40 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("spark.files.ignoreCorruptFiles should work in SQL") { + val inputFile = File.createTempFile("input-", ".gz") + try { + // Create a corrupt gzip file + val byteOutput = new ByteArrayOutputStream() + val gzip = new GZIPOutputStream(byteOutput) + try { + gzip.write(Array[Byte](1, 2, 3, 
4)) + } finally { + gzip.close() + } + val bytes = byteOutput.toByteArray + val o = new FileOutputStream(inputFile) + try { + // It's corrupt since we only write half of bytes into the file. + o.write(bytes.take(bytes.length / 2)) + } finally { + o.close() + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val e = intercept[SparkException] { + spark.read.text(inputFile.toURI.toString).collect() + } + assert(e.getCause.isInstanceOf[EOFException]) + assert(e.getCause.getMessage === "Unexpected end of input stream") + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + assert(spark.read.text(inputFile.toURI.toString).collect().isEmpty) + } + } finally { + inputFile.delete() + } + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = From 9ce7d3e542e786c62f047c13f3001e178f76e06a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 12 Oct 2016 16:43:03 -0500 Subject: [PATCH 262/306] [SPARK-17675][CORE] Expand Blacklist for TaskSets ## What changes were proposed in this pull request? This is a step along the way to SPARK-8425. To enable incremental review, the first step proposed here is to expand the blacklisting within tasksets. In particular, this will enable blacklisting for * (task, executor) pairs (this already exists via an undocumented config) * (task, node) * (taskset, executor) * (taskset, node) Adding (task, node) is critical to making spark fault-tolerant of one-bad disk in a cluster, without requiring careful tuning of "spark.task.maxFailures". The other additions are also important to avoid many misleading task failures and long scheduling delays when there is one bad node on a large cluster. Note that some of the code changes here aren't really required for just this -- they put pieces in place for SPARK-8425 even though they are not used yet (eg. the `BlacklistTracker` helper is a little out of place, `TaskSetBlacklist` holds onto a little more info than it needs to for just this change, and `ExecutorFailuresInTaskSet` is more complex than it needs to be). ## How was this patch tested? Added unit tests, run tests via jenkins. Author: Imran Rashid Author: mwws Closes #15249 from squito/taskset_blacklist_only. 
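
For illustration only (not part of the patch): a minimal `SparkConf` sketch turning the expanded blacklisting on. The key names and values mirror the new `spark.blacklist.*` entries added below; the application name is a placeholder.

```scala
// Usage sketch -- values shown are the defaults of the new config entries in
// org.apache.spark.internal.config; only spark.blacklist.enabled is non-default here.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("taskset-blacklist-sketch")
  .set("spark.blacklist.enabled", "true")
  // Per-task limits: attempts of a single task allowed on one executor / node
  // before that executor / node is blacklisted for the task.
  .set("spark.blacklist.task.maxTaskAttemptsPerExecutor", "1")
  .set("spark.blacklist.task.maxTaskAttemptsPerNode", "2")
  // Per-stage limits: distinct failed tasks allowed on one executor, and
  // blacklisted executors allowed on one node, before the executor / node is
  // blacklisted for the whole taskset.
  .set("spark.blacklist.stage.maxFailedTasksPerExecutor", "2")
  .set("spark.blacklist.stage.maxFailedExecutorsPerNode", "2")
```

As described above, this step only blacklists within a taskset; application-level blacklisting is deferred to SPARK-8425.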
--- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../org/apache/spark/TaskEndReason.scala | 11 + .../spark/internal/config/package.scala | 45 +++ .../spark/scheduler/BlacklistTracker.scala | 114 ++++++++ .../scheduler/ExecutorFailuresInTaskSet.scala | 50 ++++ .../spark/scheduler/TaskSchedulerImpl.scala | 31 +- .../spark/scheduler/TaskSetBlacklist.scala | 128 ++++++++ .../spark/scheduler/TaskSetManager.scala | 276 +++++++++--------- .../scheduler/BlacklistIntegrationSuite.scala | 52 ++-- .../scheduler/BlacklistTrackerSuite.scala | 81 +++++ .../scheduler/SchedulerIntegrationSuite.scala | 4 +- .../scheduler/TaskSchedulerImplSuite.scala | 22 +- .../scheduler/TaskSetBlacklistSuite.scala | 163 +++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 131 ++++++++- .../KryoSerializerDistributedSuite.scala | 4 +- docs/configuration.md | 43 +++ .../sql/execution/ui/SQLListenerSuite.scala | 3 +- 17 files changed, 964 insertions(+), 198 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 51a699f41d15d..c9c342df82c97 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -636,7 +636,9 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more.") + DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", + "Please use the new blacklisting options, spark.blacklist.*") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 42690844f9610..7ca3c103dbf5b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -92,6 +92,16 @@ case class FetchFailed( s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + s"message=\n$message\n)" } + + /** + * Fetch failures lead to a different failure handling path: (1) we don't abort the stage after + * 4 task failures, instead we immediately go back to the stage which generated the map output, + * and regenerate the missing data. (2) we don't count fetch failures for blacklisting, since + * presumably its not the fault of the executor where the task ran, but the executor which + * stored the data. This is especially important because we we might rack up a bunch of + * fetch-failures in rapid succession, on all nodes of the cluster, due to one bad node. 
+ */ + override def countTowardsTaskFailures: Boolean = false } /** @@ -204,6 +214,7 @@ case object TaskResultLost extends TaskFailedReason { @DeveloperApi case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" + override def countTowardsTaskFailures: Boolean = false } /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 517fc3e9e9c77..497ca92c7bc60 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.internal +import java.util.concurrent.TimeUnit + import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.Utils @@ -91,6 +93,49 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val MAX_TASK_FAILURES = + ConfigBuilder("spark.task.maxFailures") + .intConf + .createWithDefault(4) + + // Blacklist confs + private[spark] val BLACKLIST_ENABLED = + ConfigBuilder("spark.blacklist.enabled") + .booleanConf + .createOptional + + private[spark] val MAX_TASK_ATTEMPTS_PER_EXECUTOR = + ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerExecutor") + .intConf + .createWithDefault(1) + + private[spark] val MAX_TASK_ATTEMPTS_PER_NODE = + ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerNode") + .intConf + .createWithDefault(2) + + private[spark] val MAX_FAILURES_PER_EXEC_STAGE = + ConfigBuilder("spark.blacklist.stage.maxFailedTasksPerExecutor") + .intConf + .createWithDefault(2) + + private[spark] val MAX_FAILED_EXEC_PER_NODE_STAGE = + ConfigBuilder("spark.blacklist.stage.maxFailedExecutorsPerNode") + .intConf + .createWithDefault(2) + + private[spark] val BLACKLIST_TIMEOUT_CONF = + ConfigBuilder("spark.blacklist.timeout") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + private[spark] val BLACKLIST_LEGACY_TIMEOUT_CONF = + ConfigBuilder("spark.scheduler.executorTaskBlacklistTime") + .internal() + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + // End blacklist confs + private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") .intConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala new file mode 100644 index 0000000000000..fca4c6d37e446 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.util.Utils + +private[scheduler] object BlacklistTracker extends Logging { + + private val DEFAULT_TIMEOUT = "1h" + + /** + * Returns true if the blacklist is enabled, based on checking the configuration in the following + * order: + * 1. Is it specifically enabled or disabled? + * 2. Is it enabled via the legacy timeout conf? + * 3. Default is off + */ + def isBlacklistEnabled(conf: SparkConf): Boolean = { + conf.get(config.BLACKLIST_ENABLED) match { + case Some(enabled) => + enabled + case None => + // if they've got a non-zero setting for the legacy conf, always enable the blacklist, + // otherwise, use the default. + val legacyKey = config.BLACKLIST_LEGACY_TIMEOUT_CONF.key + conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).exists { legacyTimeout => + if (legacyTimeout == 0) { + logWarning(s"Turning off blacklisting due to legacy configuration: $legacyKey == 0") + false + } else { + logWarning(s"Turning on blacklisting due to legacy configuration: $legacyKey > 0") + true + } + } + } + } + + def getBlacklistTimeout(conf: SparkConf): Long = { + conf.get(config.BLACKLIST_TIMEOUT_CONF).getOrElse { + conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).getOrElse { + Utils.timeStringAsMs(DEFAULT_TIMEOUT) + } + } + } + + /** + * Verify that blacklist configurations are consistent; if not, throw an exception. Should only + * be called if blacklisting is enabled. + * + * The configuration for the blacklist is expected to adhere to a few invariants. Default + * values follow these rules of course, but users may unwittingly change one configuration + * without making the corresponding adjustment elsewhere. This ensures we fail-fast when + * there are such misconfigurations. + */ + def validateBlacklistConfs(conf: SparkConf): Unit = { + + def mustBePos(k: String, v: String): Unit = { + throw new IllegalArgumentException(s"$k was $v, but must be > 0.") + } + + Seq( + config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, + config.MAX_TASK_ATTEMPTS_PER_NODE, + config.MAX_FAILURES_PER_EXEC_STAGE, + config.MAX_FAILED_EXEC_PER_NODE_STAGE + ).foreach { config => + val v = conf.get(config) + if (v <= 0) { + mustBePos(config.key, v.toString) + } + } + + val timeout = getBlacklistTimeout(conf) + if (timeout <= 0) { + // first, figure out where the timeout came from, to include the right conf in the message. + conf.get(config.BLACKLIST_TIMEOUT_CONF) match { + case Some(t) => + mustBePos(config.BLACKLIST_TIMEOUT_CONF.key, timeout.toString) + case None => + mustBePos(config.BLACKLIST_LEGACY_TIMEOUT_CONF.key, timeout.toString) + } + } + + val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES) + val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) + + if (maxNodeAttempts >= maxTaskFailures) { + throw new IllegalArgumentException(s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " + + s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " + + s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + + s"Spark will not be robust to one bad node. 
Decrease " + + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala new file mode 100644 index 0000000000000..20ab27d127aba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler + +import scala.collection.mutable.HashMap + +/** + * Small helper for tracking failed tasks for blacklisting purposes. Info on all failures on one + * executor, within one task set. + */ +private[scheduler] class ExecutorFailuresInTaskSet(val node: String) { + /** + * Mapping from index of the tasks in the taskset, to the number of times it has failed on this + * executor. + */ + val taskToFailureCount = HashMap[Int, Int]() + + def updateWithFailure(taskIndex: Int): Unit = { + val prevFailureCount = taskToFailureCount.getOrElse(taskIndex, 0) + taskToFailureCount(taskIndex) = prevFailureCount + 1 + } + + def numUniqueTasksWithFailures: Int = taskToFailureCount.size + + /** + * Return the number of times this executor has failed on the given task index. 
+ */ + def getNumTaskFailures(index: Int): Int = { + taskToFailureCount.getOrElse(index, 0) + } + + override def toString(): String = { + s"numUniqueTasksWithFailures = $numUniqueTasksWithFailures; " + + s"tasksToFailureCount = $taskToFailureCount" + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0ad4730fe20a6..3e3f1ad031e66 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,14 +22,14 @@ import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.Set +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.scheduler.local.LocalSchedulerBackend @@ -57,7 +57,7 @@ private[spark] class TaskSchedulerImpl( isLocal: Boolean = false) extends TaskScheduler with Logging { - def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4)) + def this(sc: SparkContext) = this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) val conf = sc.conf @@ -100,7 +100,7 @@ private[spark] class TaskSchedulerImpl( // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - protected val executorsByHost = new HashMap[String, HashSet[String]] + protected val hostToExecutors = new HashMap[String, HashSet[String]] protected val hostsByRack = new HashMap[String, HashSet[String]] @@ -243,8 +243,8 @@ private[spark] class TaskSchedulerImpl( } } manager.parent.removeSchedulable(manager) - logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" - .format(manager.taskSet.id, manager.parent.name)) + logInfo(s"Removed TaskSet ${manager.taskSet.id}, whose tasks have all completed, from pool" + + s" ${manager.parent.name}") } private def resourceOfferSingleTaskSet( @@ -291,11 +291,11 @@ private[spark] class TaskSchedulerImpl( // Also track if new executor is added var newExecAvail = false for (o <- offers) { - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() + if (!hostToExecutors.contains(o.host)) { + hostToExecutors(o.host) = new HashSet[String]() } if (!executorIdToTaskCount.contains(o.executorId)) { - executorsByHost(o.host) += o.executorId + hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host executorIdToTaskCount(o.executorId) = 0 @@ -334,7 +334,7 @@ private[spark] class TaskSchedulerImpl( } while (launchedTaskAtCurrentMaxLocality) } if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(executorIdToHost.keys) + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) } } @@ -542,10 +542,10 @@ private[spark] class TaskSchedulerImpl( executorIdToTaskCount -= executorId val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val execs = hostToExecutors.getOrElse(host, 
new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + hostToExecutors -= host for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) { hosts -= host if (hosts.isEmpty) { @@ -565,11 +565,11 @@ private[spark] class TaskSchedulerImpl( } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) + hostToExecutors.get(host).map(_.toSet) } def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) + hostToExecutors.contains(host) } def hasHostAliveOnRack(rack: String): Boolean = synchronized { @@ -662,5 +662,4 @@ private[spark] object TaskSchedulerImpl { retval.toList } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala new file mode 100644 index 0000000000000..f4b0f55b7686a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler + +import scala.collection.mutable.{HashMap, HashSet} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.util.Clock + +/** + * Handles blacklisting executors and nodes within a taskset. This includes blacklisting specific + * (task, executor) / (task, nodes) pairs, and also completely blacklisting executors and nodes + * for the entire taskset. + * + * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in + * [[TaskSetManager]] this class is designed only to be called from code with a lock on the + * TaskScheduler (e.g. its event handlers). It should not be called from other threads. + */ +private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) + extends Logging { + + private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) + private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) + private val MAX_FAILURES_PER_EXEC_STAGE = conf.get(config.MAX_FAILURES_PER_EXEC_STAGE) + private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE) + + /** + * A map from each executor to the task failures on that executor. + */ + val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]() + + /** + * Map from node to all executors on it with failures. Needed because we want to know about + * executors on a node even after they have died. (We don't want to bother tracking the + * node -> execs mapping in the usual case when there aren't any failures). 
+ */ + private val nodeToExecsWithFailures = new HashMap[String, HashSet[String]]() + private val nodeToBlacklistedTaskIndexes = new HashMap[String, HashSet[Int]]() + private val blacklistedExecs = new HashSet[String]() + private val blacklistedNodes = new HashSet[String]() + + /** + * Return true if this executor is blacklisted for the given task. This does *not* + * need to return true if the executor is blacklisted for the entire stage. + * That is to keep this method as fast as possible in the inner-loop of the + * scheduler, where those filters will have already been applied. + */ + def isExecutorBlacklistedForTask(executorId: String, index: Int): Boolean = { + execToFailures.get(executorId).exists { execFailures => + execFailures.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR + } + } + + def isNodeBlacklistedForTask(node: String, index: Int): Boolean = { + nodeToBlacklistedTaskIndexes.get(node).exists(_.contains(index)) + } + + /** + * Return true if this executor is blacklisted for the given stage. Completely ignores + * anything to do with the node the executor is on. That + * is to keep this method as fast as possible in the inner-loop of the scheduler, where those + * filters will already have been applied. + */ + def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = { + blacklistedExecs.contains(executorId) + } + + def isNodeBlacklistedForTaskSet(node: String): Boolean = { + blacklistedNodes.contains(node) + } + + private[scheduler] def updateBlacklistForFailedTask( + host: String, + exec: String, + index: Int): Unit = { + val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) + execFailures.updateWithFailure(index) + + // check if this task has also failed on other executors on the same host -- if its gone + // over the limit, blacklist this task from the entire host. + val execsWithFailuresOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet()) + execsWithFailuresOnNode += exec + val failuresOnHost = execsWithFailuresOnNode.toIterator.flatMap { exec => + execToFailures.get(exec).map { failures => + // We count task attempts here, not the number of unique executors with failures. This is + // because jobs are aborted based on the number task attempts; if we counted unique + // executors, it would be hard to config to ensure that you try another + // node before hitting the max number of task failures. + failures.getNumTaskFailures(index) + } + }.sum + if (failuresOnHost >= MAX_TASK_ATTEMPTS_PER_NODE) { + nodeToBlacklistedTaskIndexes.getOrElseUpdate(host, new HashSet()) += index + } + + // Check if enough tasks have failed on the executor to blacklist it for the entire stage. + if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) { + if (blacklistedExecs.add(exec)) { + logInfo(s"Blacklisting executor ${exec} for stage $stageId") + // This executor has been pushed into the blacklist for this stage. Let's check if it + // pushes the whole node into the blacklist. 
+ val blacklistedExecutorsOnNode = + execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) + if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + if (blacklistedNodes.add(host)) { + logInfo(s"Blacklisting ${host} for stage $stageId") + } + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 226bed284a40a..9491bc7a0497e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -22,9 +22,7 @@ import java.nio.ByteBuffer import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.{max, min} import scala.util.control.NonFatal @@ -53,19 +51,9 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = new SystemClock()) - extends Schedulable with Logging { + clock: Clock = new SystemClock()) extends Schedulable with Logging { - val conf = sched.sc.conf - - /* - * Sometimes if an executor is dead or in an otherwise invalid state, the driver - * does not realize right away leading to repeated task failures. If enabled, - * this temporarily prevents a task from re-launching on an executor where - * it just failed. - */ - private val EXECUTOR_TASK_BLACKLIST_TIMEOUT = - conf.getLong("spark.scheduler.executorTaskBlacklistTime", 0L) + private val conf = sched.sc.conf // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) @@ -83,8 +71,6 @@ private[spark] class TaskSetManager( val copiesRunning = new Array[Int](numTasks) val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) - // key is taskId (aka TaskInfo.index), value is a Map of executor id to when it failed - private val failedExecutors = new HashMap[Int, HashMap[String, Long]]() val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksSuccessful = 0 @@ -98,6 +84,14 @@ private[spark] class TaskSetManager( var totalResultSize = 0L var calculatedTasks = 0 + private val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { + if (BlacklistTracker.isBlacklistEnabled(conf)) { + Some(new TaskSetBlacklist(conf, stageId, clock)) + } else { + None + } + } + val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size @@ -245,12 +239,15 @@ private[spark] class TaskSetManager( * This method also cleans up any tasks in the list that have already * been launched, since we want that to happen lazily. 
*/ - private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = { + private def dequeueTaskFromList( + execId: String, + host: String, + list: ArrayBuffer[Int]): Option[Int] = { var indexOffset = list.size while (indexOffset > 0) { indexOffset -= 1 val index = list(indexOffset) - if (!executorIsBlacklisted(execId, index)) { + if (!isTaskBlacklistedOnExecOrNode(index, execId, host)) { // This should almost always be list.trimEnd(1) to remove tail list.remove(indexOffset) if (copiesRunning(index) == 0 && !successful(index)) { @@ -266,19 +263,11 @@ private[spark] class TaskSetManager( taskAttempts(taskIndex).exists(_.host == host) } - /** - * Is this re-execution of a failed task on an executor it already failed in before - * EXECUTOR_TASK_BLACKLIST_TIMEOUT has elapsed ? - */ - private[scheduler] def executorIsBlacklisted(execId: String, taskId: Int): Boolean = { - if (failedExecutors.contains(taskId)) { - val failed = failedExecutors.get(taskId).get - - return failed.contains(execId) && - clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + private def isTaskBlacklistedOnExecOrNode(index: Int, execId: String, host: String): Boolean = { + taskSetBlacklistHelperOpt.exists { blacklist => + blacklist.isNodeBlacklistedForTask(host, index) || + blacklist.isExecutorBlacklistedForTask(execId, index) } - - false } /** @@ -292,8 +281,10 @@ private[spark] class TaskSetManager( { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set - def canRunOnHost(index: Int): Boolean = - !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index) + def canRunOnHost(index: Int): Boolean = { + !hasAttemptOnHost(index, host) && + !isTaskBlacklistedOnExecOrNode(index, execId, host) + } if (!speculatableTasks.isEmpty) { // Check for process-local tasks; note that tasks can be process-local @@ -366,19 +357,19 @@ private[spark] class TaskSetManager( private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value) : Option[(Int, TaskLocality.Value, Boolean)] = { - for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) { + for (index <- dequeueTaskFromList(execId, host, getPendingTasksForExecutor(execId))) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { - for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) { + for (index <- dequeueTaskFromList(execId, host, getPendingTasksForHost(host))) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) { // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic - for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) { + for (index <- dequeueTaskFromList(execId, host, pendingTasksWithNoPrefs)) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } } @@ -386,14 +377,14 @@ private[spark] class TaskSetManager( if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) - index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack)) + index <- dequeueTaskFromList(execId, host, getPendingTasksForRack(rack)) } { return Some((index, TaskLocality.RACK_LOCAL, false)) } } if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { - for (index <- dequeueTaskFromList(execId, allPendingTasks)) { + for (index <- dequeueTaskFromList(execId, host, allPendingTasks)) { return Some((index, 
TaskLocality.ANY, false)) } } @@ -421,7 +412,11 @@ private[spark] class TaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (!isZombie) { + val offerBlacklisted = taskSetBlacklistHelperOpt.exists { blacklist => + blacklist.isNodeBlacklistedForTaskSet(host) || + blacklist.isExecutorBlacklistedForTaskSet(execId) + } + if (!isZombie && !offerBlacklisted) { val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -434,60 +429,59 @@ private[spark] class TaskSetManager( } } - dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Do various bookkeeping - copiesRunning(index) += 1 - val attemptNum = taskAttempts(index).size - val info = new TaskInfo(taskId, index, attemptNum, curTime, - execId, host, taskLocality, speculative) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - // NO_PREF will not affect the variables related to delay scheduling - if (maxLocality != TaskLocality.NO_PREF) { - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - } - // Serialize and return the task - val startTime = clock.getTimeMillis() - val serializedTask: ByteBuffer = try { - Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) - } catch { - // If the task cannot be serialized, then there's no point to re-attempt the task, - // as it will always fail. So just abort the whole task-set. - case NonFatal(e) => - val msg = s"Failed to serialize task $taskId, not attempting to retry it." - logError(msg, e) - abort(s"$msg Exception during serialization: $e") - throw new TaskNotSerializableException(e) - } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && - !emittedTaskSizeWarning) { - emittedTaskSizeWarning = true - logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + - s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") - } - addRunningTask(taskId) - - // We used to log the time it takes to serialize the task, but task size is already - // a good proxy to task serialization time. 
- // val timeTaken = clock.getTime() - startTime - val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + - s" $taskLocality, ${serializedTask.limit} bytes)") - - sched.dagScheduler.taskStarted(task, info) - return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, - taskName, index, serializedTask)) - case _ => + dequeueTask(execId, host, allowedLocality).map { case ((index, taskLocality, speculative)) => + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Do various bookkeeping + copiesRunning(index) += 1 + val attemptNum = taskAttempts(index).size + val info = new TaskInfo(taskId, index, attemptNum, curTime, + execId, host, taskLocality, speculative) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + // NO_PREF will not affect the variables related to delay scheduling + if (maxLocality != TaskLocality.NO_PREF) { + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + } + // Serialize and return the task + val startTime = clock.getTimeMillis() + val serializedTask: ByteBuffer = try { + Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + } catch { + // If the task cannot be serialized, then there's no point to re-attempt the task, + // as it will always fail. So just abort the whole task-set. + case NonFatal(e) => + val msg = s"Failed to serialize task $taskId, not attempting to retry it." + logError(msg, e) + abort(s"$msg Exception during serialization: $e") + throw new TaskNotSerializableException(e) + } + if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && + !emittedTaskSizeWarning) { + emittedTaskSizeWarning = true + logWarning(s"Stage ${task.stageId} contains a task of very large size " + + s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") + } + addRunningTask(taskId) + + // We used to log the time it takes to serialize the task, but task size is already + // a good proxy to task serialization time. + // val timeTaken = clock.getTime() - startTime + val taskName = s"task ${info.id} in stage ${taskSet.id}" + logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + + s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") + + sched.dagScheduler.taskStarted(task, info) + new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, + taskName, index, serializedTask) } + } else { + None } - None } private def maybeFinishTaskSet() { @@ -589,37 +583,56 @@ private[spark] class TaskSetManager( * the hang as quickly as we could have, but we'll always detect the hang eventually, and the * method is faster in the typical case. In the worst case, this method can take * O(maxTaskFailures + numTasks) time, but it will be faster when there haven't been any task - * failures (this is because the method picks on unscheduled task, and then iterates through each - * executor until it finds one that the task hasn't failed on already). + * failures (this is because the method picks one unscheduled task, and then iterates through each + * executor until it finds one that the task isn't blacklisted on). 
*/ - private[scheduler] def abortIfCompletelyBlacklisted(executors: Iterable[String]): Unit = { - - val pendingTask: Option[Int] = { - // usually this will just take the last pending task, but because of the lazy removal - // from each list, we may need to go deeper in the list. We poll from the end because - // failed tasks are put back at the end of allPendingTasks, so we're more likely to find - // an unschedulable task this way. - val indexOffset = allPendingTasks.lastIndexWhere { indexInTaskSet => - copiesRunning(indexInTaskSet) == 0 && !successful(indexInTaskSet) - } - if (indexOffset == -1) { - None - } else { - Some(allPendingTasks(indexOffset)) - } - } + private[scheduler] def abortIfCompletelyBlacklisted( + hostToExecutors: HashMap[String, HashSet[String]]): Unit = { + taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + // Only look for unschedulable tasks when at least one executor has registered. Otherwise, + // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. + if (hostToExecutors.nonEmpty) { + // find any task that needs to be scheduled + val pendingTask: Option[Int] = { + // usually this will just take the last pending task, but because of the lazy removal + // from each list, we may need to go deeper in the list. We poll from the end because + // failed tasks are put back at the end of allPendingTasks, so we're more likely to find + // an unschedulable task this way. + val indexOffset = allPendingTasks.lastIndexWhere { indexInTaskSet => + copiesRunning(indexInTaskSet) == 0 && !successful(indexInTaskSet) + } + if (indexOffset == -1) { + None + } else { + Some(allPendingTasks(indexOffset)) + } + } - // If no executors have registered yet, don't abort the stage, just wait. We probably - // got here because a task set was added before the executors registered. - if (executors.nonEmpty) { - // take any task that needs to be scheduled, and see if we can find some executor it *could* - // run on - pendingTask.foreach { taskId => - if (executors.forall(executorIsBlacklisted(_, taskId))) { - val execs = executors.toIndexedSeq.sorted.mkString("(", ",", ")") - val partition = tasks(taskId).partitionId - abort(s"Aborting ${taskSet} because task $taskId (partition $partition)" + - s" has already failed on executors $execs, and no other executors are available.") + pendingTask.foreach { indexInTaskSet => + // try to find some executor this task can run on. Its possible that some *other* + // task isn't schedulable anywhere, but we will discover that in some later call, + // when that unschedulable task is the last task remaining. + val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) => + // Check if the task can run on the node + val nodeBlacklisted = + taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || + taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) + if (nodeBlacklisted) { + true + } else { + // Check if the task can run on any of the executors + execsOnHost.forall { exec => + taskSetBlacklist.isExecutorBlacklistedForTaskSet(exec) || + taskSetBlacklist.isExecutorBlacklistedForTask(exec, indexInTaskSet) + } + } + } + if (blacklistedEverywhere) { + val partition = tasks(indexInTaskSet).partitionId + abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + + s"cannot run anywhere due to node and executor blacklist. 
Blacklisting behavior " + + s"can be configured via spark.blacklist.*.") + } } } } @@ -677,8 +690,9 @@ private[spark] class TaskSetManager( } if (!successful(index)) { tasksSuccessful += 1 - logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( - info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks)) + logInfo(s"Finished task ${info.id} in stage ${taskSet.id} (TID ${info.taskId}) in" + + s" ${info.duration} ms on ${info.host} (executor ${info.executorId})" + + s" ($tasksSuccessful/$numTasks)") // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { @@ -688,7 +702,6 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } - failedExecutors.remove(index) maybeFinishTaskSet() } @@ -706,8 +719,8 @@ private[spark] class TaskSetManager( val index = info.index copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty - val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + - reason.toErrorString + val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}," + + s" executor ${info.executorId}): ${reason.toErrorString}" val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) @@ -715,7 +728,6 @@ private[spark] class TaskSetManager( successful(index) = true tasksSuccessful += 1 } - // Not adding to failed executors for FetchFailed. isZombie = true None @@ -751,8 +763,8 @@ private[spark] class TaskSetManager( logWarning(failureReason) } else { logInfo( - s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + - s"${ef.className} (${ef.description}) [duplicate $dupCount]") + s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on ${info.host}, executor" + + s" ${info.executorId}: ${ef.className} (${ef.description}) [duplicate $dupCount]") } ef.exception @@ -766,9 +778,7 @@ private[spark] class TaskSetManager( logWarning(failureReason) None } - // always add to failed executors - failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
- put(info.executorId, clock.getTimeMillis()) + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (successful(index)) { @@ -780,7 +790,9 @@ private[spark] class TaskSetManager( addPendingTask(index) } - if (!isZombie && state != TaskState.KILLED && reason.countTowardsTaskFailures) { + if (!isZombie && reason.countTowardsTaskFailures) { + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index)) assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index 14c8b664d4d8b..f6015cd51c2bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.scheduler -import scala.concurrent.Await import scala.concurrent.duration._ import org.apache.spark._ +import org.apache.spark.internal.config class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{ @@ -42,7 +42,10 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM // Test demonstrating the issue -- without a config change, the scheduler keeps scheduling // according to locality preferences, and so the job fails - testScheduler("If preferred node is bad, without blacklist job will fail") { + testScheduler("If preferred node is bad, without blacklist job will fail", + extraConfs = Seq( + config.BLACKLIST_ENABLED.key -> "false" + )) { val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) withBackend(badHostBackend _) { val jobFuture = submit(rdd, (0 until 10).toArray) @@ -51,37 +54,38 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM assertDataStructuresEmpty(noFailure = false) } - // even with the blacklist turned on, if maxTaskFailures is not more than the number - // of executors on the bad node, then locality preferences will lead to us cycling through - // the executors on the bad node, and still failing the job testScheduler( - "With blacklist on, job will still fail if there are too many bad executors on bad host", + "With default settings, job can succeed despite multiple bad executors on node", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - ("spark.scheduler.executorTaskBlacklistTime", "10000000") + config.BLACKLIST_ENABLED.key -> "true", + config.MAX_TASK_FAILURES.key -> "4", + "spark.testing.nHosts" -> "2", + "spark.testing.nExecutorsPerHost" -> "5", + "spark.testing.nCoresPerExecutor" -> "10" ) ) { - val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) + // To reliably reproduce the failure that would occur without blacklisting, we have to use 1 + // task. That way, we ensure this 1 task gets rotated through enough bad executors on the host + // to fail the taskSet, before we have a bunch of different tasks fail in the executors so we + // blacklist them. + // But the point here is -- without blacklisting, we would never schedule anything on the good + // host-1 before we hit too many failures trying our preferred host-0. 
+ val rdd = new MockRDDWithLocalityPrefs(sc, 1, Nil, badHost) withBackend(badHostBackend _) { - val jobFuture = submit(rdd, (0 until 10).toArray) + val jobFuture = submit(rdd, (0 until 1).toArray) awaitJobTermination(jobFuture, duration) } - assertDataStructuresEmpty(noFailure = false) + assertDataStructuresEmpty(noFailure = true) } - // Here we run with the blacklist on, and maxTaskFailures high enough that we'll eventually - // schedule on a good node and succeed the job + // Here we run with the blacklist on, and the default config takes care of having this + // robust to one bad node. testScheduler( "Bad node with multiple executors, job will still succeed with the right confs", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - ("spark.scheduler.executorTaskBlacklistTime", "10000000"), - // this has to be higher than the number of executors on the bad host - ("spark.task.maxFailures", "5"), + config.BLACKLIST_ENABLED.key -> "true", // just to avoid this test taking too long - ("spark.locality.wait", "10ms") + "spark.locality.wait" -> "10ms" ) ) { val rdd = new MockRDDWithLocalityPrefs(sc, 10, Nil, badHost) @@ -98,9 +102,7 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM testScheduler( "SPARK-15865 Progress with fewer executors than maxTaskFailures", extraConfs = Seq( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000", + config.BLACKLIST_ENABLED.key -> "true", "spark.testing.nHosts" -> "2", "spark.testing.nExecutorsPerHost" -> "1", "spark.testing.nCoresPerExecutor" -> "1" @@ -112,9 +114,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - Await.ready(jobFuture, duration) + awaitJobTermination(jobFuture, duration) val pattern = ("Aborting TaskSet 0.0 because task .* " + - "already failed on executors \\(.*\\), and no other executors are available").r + "cannot run anywhere due to node and executor blacklist").r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in ${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..b2e7ec5df015c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config + +class BlacklistTrackerSuite extends SparkFunSuite { + + test("blacklist still respects legacy configs") { + val conf = new SparkConf().setMaster("local") + assert(!BlacklistTracker.isBlacklistEnabled(conf)) + conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 5000L) + assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(5000 === BlacklistTracker.getBlacklistTimeout(conf)) + // the new conf takes precedence, though + conf.set(config.BLACKLIST_TIMEOUT_CONF, 1000L) + assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) + + // if you explicitly set the legacy conf to 0, that also would disable blacklisting + conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 0L) + assert(!BlacklistTracker.isBlacklistEnabled(conf)) + // but again, the new conf takes precendence + conf.set(config.BLACKLIST_ENABLED, true) + assert(BlacklistTracker.isBlacklistEnabled(conf)) + assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) + } + + test("check blacklist configuration invariants") { + val conf = new SparkConf().setMaster("yarn-cluster") + Seq( + (2, 2), + (2, 3) + ).foreach { case (maxTaskFailures, maxNodeAttempts) => + conf.set(config.MAX_TASK_FAILURES, maxTaskFailures) + conf.set(config.MAX_TASK_ATTEMPTS_PER_NODE.key, maxNodeAttempts.toString) + val excMsg = intercept[IllegalArgumentException] { + BlacklistTracker.validateBlacklistConfs(conf) + }.getMessage() + assert(excMsg === s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " + + s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " + + s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " + + s"Spark will not be robust to one bad node. 
Decrease " + + s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " + + s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}") + } + + conf.remove(config.MAX_TASK_FAILURES) + conf.remove(config.MAX_TASK_ATTEMPTS_PER_NODE) + + Seq( + config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, + config.MAX_TASK_ATTEMPTS_PER_NODE, + config.MAX_FAILURES_PER_EXEC_STAGE, + config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.BLACKLIST_TIMEOUT_CONF + ).foreach { config => + conf.set(config.key, "0") + val excMsg = intercept[IllegalArgumentException] { + BlacklistTracker.validateBlacklistConfs(conf) + }.getMessage() + assert(excMsg.contains(s"${config.key} was 0, but must be > 0.")) + conf.remove(config) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 5cd548bbc72d9..c28aa06623a60 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -620,9 +620,9 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } + assertDataStructuresEmpty() assert(results === (0 until 10).map { idx => idx -> (42 + idx) }.toMap) assert(stageToAttempts === Map(0 -> Set(0, 1), 1 -> Set(0, 1))) - assertDataStructuresEmpty() } testScheduler("job failure after 4 attempts") { @@ -634,7 +634,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) - failure.getMessage.contains("test task failure") + assert(failure.getMessage.contains("test task failure")) } assertDataStructuresEmpty(noFailure = false) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 61787b54f824f..f5f1947661d9a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import org.scalatest.BeforeAndAfterEach import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.Logging class FakeSchedulerBackend extends SchedulerBackend { @@ -32,7 +33,6 @@ class FakeSchedulerBackend extends SchedulerBackend { class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach with Logging { - var failedTaskSetException: Option[Throwable] = None var failedTaskSetReason: String = null var failedTaskSet = false @@ -60,10 +60,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") confs.foreach { case (k, v) => - sc.conf.set(k, v) + conf.set(k, v) } + sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
@@ -287,9 +288,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // schedulable on another executor. However, that executor may fail later on, leaving the // first task with no place to run. val taskScheduler = setupScheduler( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000" + config.BLACKLIST_ENABLED.key -> "true" ) val taskSet = FakeTask.createTaskSet(2) @@ -328,8 +327,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason == s"Aborting TaskSet 0.0 because task $idx (partition $idx) has " + - s"already failed on executors (executor0), and no other executors are available.") + assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + + s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + + s"configured via spark.blacklist.*.") } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { @@ -339,9 +339,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // available and not bail on the job val taskScheduler = setupScheduler( - // set this to something much longer than the test duration so that executors don't get - // removed from the blacklist during the test - "spark.scheduler.executorTaskBlacklistTime" -> "10000000" + config.BLACKLIST_ENABLED.key -> "true" ) val taskSet = FakeTask.createTaskSet(2, (0 until 2).map { _ => Seq(TaskLocation("host0")) }: _*) @@ -377,7 +375,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskScheduler = setupScheduler() taskScheduler.submitTasks(FakeTask.createTaskSet(2, 0, - (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2"))}: _* + (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2")) }: _* )) val taskDescs = taskScheduler.resourceOffers(IndexedSeq( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala new file mode 100644 index 0000000000000..8c902af5685ff --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.{ManualClock, SystemClock} + +class TaskSetBlacklistSuite extends SparkFunSuite { + + test("Blacklisting tasks, executors, and nodes") { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + val clock = new ManualClock + + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) + clock.setTime(0) + // We will mark task 0 & 1 failed on both executor 1 & 2. + // We should blacklist all executors on that host, for all tasks for the stage. Note the API + // will return false for isExecutorBacklistedForTaskSet even when the node is blacklisted, so + // the executor is implicitly blacklisted (this makes sense with how the scheduler uses the + // blacklist) + + // First, mark task 0 as failed on exec1. + // task 0 should be blacklisted on exec1, and nowhere else + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + for { + executor <- (1 to 4).map(_.toString) + index <- 0 until 10 + } { + val shouldBeBlacklisted = (executor == "exec1" && index == 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) + } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to + // blacklisting the entire node. + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider + // range of inputs. + for { + executor <- (1 to 4).map(e => s"exec$e") + index <- 0 until 10 + } { + withClue(s"exec = $executor; index = $index") { + val badExec = (executor == "exec1" || executor == "exec2") + val badIndex = (index == 0 || index == 1) + assert( + // this ignores whether the executor is blacklisted entirely for the taskset -- that is + // intentional, it keeps it fast and is sufficient for usage in the scheduler. 
+ taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + } + } + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + val execToFailures = taskSetBlacklist.execToFailures + assert(execToFailures.keySet === Set("exec1", "exec2")) + + Seq("exec1", "exec2").foreach { exec => + assert( + execToFailures(exec).taskToFailureCount === Map( + 0 -> 1, + 1 -> 1 + ) + ) + } + } + + test("multiple attempts for the same task count once") { + // Make sure that for blacklisting tasks, the node counts task attempts, not executors. But for + // stage-level blacklisting, we count unique tasks. The reason for this difference is, with + // task-attempt blacklisting, we want to make it easy to configure so that you ensure a node + // is blacklisted before the taskset is completely aborted because of spark.task.maxFailures. + // But with stage-blacklisting, we want to make sure we're not just counting one bad task + // that has failed many times. + + val conf = new SparkConf().setMaster("local").setAppName("test") + .set(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, 2) + .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) + .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) + .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + // Fail a task twice on hostA, exec:1 + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) + assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail the same task once more on hostA, exec:2 + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, + // so its blacklisted + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are + // blacklisted for the taskset, so blacklist the whole node. 
+ taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + } + + test("only blacklist nodes for the task set when all the blacklisted executors are all on " + + "same host") { + // we blacklist executors on two different hosts within one taskSet -- make sure that doesn't + // lead to any node blacklisting + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + + taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 7d6ad08036cb4..69edcf3347243 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.Mockito.{mock, verify} import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.Logging import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -103,7 +104,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val host = executorIdToHost.get(execId) assert(host != None) val hostId = host.get - val executorsOnHost = executorsByHost(hostId) + val executorsOnHost = hostToExecutors(hostId) executorsOnHost -= execId for (rack <- getRackForHost(hostId); hosts <- hostsByRack.get(rack)) { hosts -= hostId @@ -125,7 +126,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex def addExecutor(execId: String, host: String) { executors.put(execId, host) - val executorsOnHost = executorsByHost.getOrElseUpdate(host, new mutable.HashSet[String]) + val executorsOnHost = hostToExecutors.getOrElseUpdate(host, new mutable.HashSet[String]) executorsOnHost += execId executorIdToHost += execId -> host for (rack <- getRackForHost(host)) { @@ -411,7 +412,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg test("executors should be blacklisted after task failure, in spite of locality preferences") { val rescheduleDelay = 300L val conf = new SparkConf(). - set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString). + set(config.BLACKLIST_ENABLED, true). + set(config.BLACKLIST_TIMEOUT_CONF, rescheduleDelay). 
// don't wait to jump locality levels in this test set("spark.locality.wait", "0") @@ -475,19 +477,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) } - // After reschedule delay, scheduling on exec1 should be possible. + // Despite advancing beyond the time for expiring executors from within the blacklist, + // we *never* expire from *within* the stage blacklist clock.advance(rescheduleDelay) { val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) - assert(offerResult.isDefined, "Expect resource offer to return a task") + assert(offerResult.isEmpty) + } + { + val offerResult = manager.resourceOffer("exec3", "host3", ANY) + assert(offerResult.isDefined) assert(offerResult.get.index === 0) - assert(offerResult.get.executorId === "exec1") + assert(offerResult.get.executorId === "exec3") - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec3", "host3", ANY).isEmpty) - // Cause exec1 to fail : failure 4 + // Cause exec3 to fail : failure 4 manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) } @@ -859,6 +866,114 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(sched.endedTasks(3) === Success) } + test("Killing speculative tasks does not count towards aborting the taskset") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(5) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.6") + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 5 tasks to start + val tasks = new ArrayBuffer[TaskDescription]() + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + tasks += task + } + assert(sched.startedTasks.toSet === (0 until 5).toSet) + // Complete 3 tasks and leave 2 tasks in running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + def runningTaskForIndex(index: Int): TaskDescription = { + tasks.find { task => + task.index == index && !sched.endedTasks.contains(task.taskId) + }.getOrElse { + throw new RuntimeException(s"couldn't find index $index in " + + s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" + + s" ${sched.endedTasks.keys}") + } + } + + // have each of the running tasks fail 3 times (not enough to abort the stage) + (0 until 3).foreach { attempt => + Seq(3, 4).foreach { index => + val task = runningTaskForIndex(index) + logInfo(s"failing task $task") + val endReason = ExceptionFailure("a", "b", Array(), "c", None) + manager.handleFailedTask(task.taskId, TaskState.FAILED, endReason) + sched.endedTasks(task.taskId) = endReason + assert(!manager.isZombie) + val nextTask = manager.resourceOffer(s"exec2", s"host2", NO_PREF) + assert(nextTask.isDefined, s"no 
offer for attempt $attempt of $index") + tasks += nextTask.get + } + } + + // we can't be sure which one of our running tasks will get another speculative copy + val originalTasks = Seq(3, 4).map { index => index -> runningTaskForIndex(index) }.toMap + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val speculativeTask = taskOption5.get + assert(speculativeTask.index === 3 || speculativeTask.index === 4) + assert(speculativeTask.taskId === 11) + assert(speculativeTask.executorId === "exec1") + assert(speculativeTask.attemptNumber === 4) + sched.backend = mock(classOf[SchedulerBackend]) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) + // Verify that it kills other running attempt + val origTask = originalTasks(speculativeTask.index) + verify(sched.backend).killTask(origTask.taskId, "exec2", true) + // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be + // killed, so the FakeTaskScheduler is only told about the successful completion + // of the speculated task. + assert(sched.endedTasks(3) === Success) + // also because the scheduler is a mock, our manager isn't notified about the task killed event, + // so we do that manually + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage + assert(manager.tasksSuccessful === 4) + assert(!manager.isZombie) + + // now run another speculative task + val taskOpt6 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOpt6.isDefined) + val speculativeTask2 = taskOpt6.get + assert(speculativeTask2.index === 3 || speculativeTask2.index === 4) + assert(speculativeTask2.index !== speculativeTask.index) + assert(speculativeTask2.attemptNumber === 4) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(speculativeTask2.taskId, + createTaskResult(3, accumUpdatesByTask(3))) + // Verify that it kills other running attempt + val origTask2 = originalTasks(speculativeTask2.index) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + assert(manager.tasksSuccessful === 5) + assert(manager.isZombie) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index c1484b0afa85f..46aa9c37986cc 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.serializer import com.esotericsoftware.kryo.Kryo import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.serializer.KryoDistributedTest._ import org.apache.spark.util.Utils @@ -29,7 +30,8 @@ class KryoSerializerDistributedSuite extends 
SparkFunSuite with LocalSparkContex val conf = new SparkConf(false) .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - .set("spark.task.maxFailures", "1") + .set(config.MAX_TASK_FAILURES, 1) + .set(config.BLACKLIST_ENABLED, false) val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/docs/configuration.md b/docs/configuration.md index 82ce232b336d9..373e22d71a872 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1245,6 +1245,49 @@ Apart from these, the following properties are also available, and may be useful The interval length for the scheduler to revive the worker resource offers to run tasks. + + spark.blacklist.enabled + + false + + + If set to "true", prevent Spark from scheduling tasks on executors that have been blacklisted + due to too many task failures. The blacklisting algorithm can be further controlled by the + other "spark.blacklist" configuration options. + + + + spark.blacklist.task.maxTaskAttemptsPerExecutor + 1 + + (Experimental) For a given task, how many times it can be retried on one executor before the + executor is blacklisted for that task. + + + + spark.blacklist.task.maxTaskAttemptsPerNode + 2 + + (Experimental) For a given task, how many times it can be retried on one node, before the entire + node is blacklisted for that task. + + + + spark.blacklist.stage.maxFailedTasksPerExecutor + 2 + + (Experimental) How many different tasks must fail on one executor, within one stage, before the + executor is blacklisted for that stage. + + + + spark.blacklist.stage.maxFailedExecutorsPerNode + 2 + + (Experimental) How many different executors are marked as blacklisted for a given stage, before + the entire node is marked as failed for the stage. + + spark.speculation false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 6e60b0e4fad15..19b6d2603129c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} @@ -446,7 +447,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { val conf = new SparkConf() .setMaster("local") .setAppName("test") - .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { From f9a56a153e0579283160519065c7f3620d12da3e Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 12 Oct 2016 15:22:06 -0700 Subject: [PATCH 263/306] [SPARK-17782][STREAMING][KAFKA] alternative eliminate race condition of poll twice ## What changes were proposed in this pull request? Alternative approach to https://github.com/apache/spark/pull/15387 Author: cody koeninger Closes #15401 from koeninger/SPARK-17782-alt. 
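For reference, a minimal sketch (not part of the patch) of the pattern this change applies inside `ConsumerStrategy`: after the initial `poll(0)` obtains an assignment and any explicit starting offsets are seeked to, every assigned partition is paused so that a later `poll(0)` cannot buffer records and silently advance the consumer position. The helper name and the `offsets` map below are made up for illustration.

```
import org.apache.kafka.clients.consumer.Consumer
import org.apache.kafka.common.TopicPartition

// Illustrative only: seek to the requested starting offsets, then pause all
// assigned partitions so subsequent poll(0) calls do not move the position.
def seekAndPause[K, V](consumer: Consumer[K, V], offsets: Map[TopicPartition, Long]): Unit = {
  consumer.poll(0)  // obtains the assignment, but may also have fetched records
  offsets.foreach { case (tp, off) => consumer.seek(tp, off) }
  consumer.pause(consumer.assignment())
}
```

The `paranoidPoll` half of the change guards the other direction: if `poll(0)` did return records despite the pause, each affected partition is seeked back to the smallest returned offset so the position stays where the stream expects it.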
--- .../streaming/kafka010/ConsumerStrategy.scala | 4 ++++ .../kafka010/DirectKafkaInputDStream.scala | 23 +++++++++++++++++-- .../kafka010/DirectKafkaStreamSuite.scala | 12 ++++++---- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 60255fc655e5f..778c06ea16a2b 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -104,6 +104,8 @@ private case class Subscribe[K, V]( toSeek.asScala.foreach { case (topicPartition, offset) => consumer.seek(topicPartition, offset) } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) } consumer @@ -154,6 +156,8 @@ private case class SubscribePattern[K, V]( toSeek.asScala.foreach { case (topicPartition, offset) => consumer.seek(topicPartition, offset) } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) } consumer diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 13827f68f2cb5..432537ebf05b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -161,12 +161,31 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } + /** + * The concern here is that poll might consume messages despite being paused, + * which would throw off consumer position. Fix position if this happens. + */ + private def paranoidPoll(c: Consumer[K, V]): Unit = { + val msgs = c.poll(0) + if (!msgs.isEmpty) { + // position should be minimum offset per topicpartition + msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) => + val tp = new TopicPartition(m.topic, m.partition) + val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset) + acc + (tp -> off) + }.foreach { case (tp, off) => + logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate") + c.seek(tp, off) + } + } + } + /** * Returns the latest (highest) available offsets, taking new partitions into account. 
*/ protected def latestOffsets(): Map[TopicPartition, Long] = { val c = consumer - c.poll(0) + paranoidPoll(c) val parts = c.assignment().asScala // make sure new partitions are reflected in currentOffsets @@ -223,7 +242,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( override def start(): Unit = { val c = consumer - c.poll(0) + paranoidPoll(c) if (currentOffsets.isEmpty) { currentOffsets = c.assignment().asScala.map { tp => tp -> c.position(tp) diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index e04f35eceb1b4..02aec43c3b34f 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -159,17 +159,19 @@ class DirectKafkaStreamSuite } test("pattern based subscription") { - val topics = List("pat1", "pat2", "advanced3") - // Should match 2 out of 3 topics + val topics = List("pat1", "pat2", "pat3", "advanced3") + // Should match 3 out of 4 topics val pat = """pat\d""".r.pattern val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => kafkaTestUtils.createTopic(t) kafkaTestUtils.sendMessages(t, data) } - val offsets = Map(new TopicPartition("pat2", 0) -> 3L) - // 2 matching topics, one of which starts 3 messages later - val expectedTotal = (data.values.sum * 2) - 3 + val offsets = Map( + new TopicPartition("pat2", 0) -> 3L, + new TopicPartition("pat3", 0) -> 4L) + // 3 matching topics, two of which start a total of 7 messages later + val expectedTotal = (data.values.sum * 3) - 7 val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") ssc = new StreamingContext(sparkConf, Milliseconds(1000)) From 6f20a92ca30f9c367009c4556939ea4de4284cb9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Oct 2016 16:45:10 -0700 Subject: [PATCH 264/306] [SPARK-17845] [SQL] More self-evident window function frame boundary API ## What changes were proposed in this pull request? This patch improves the window function frame boundary API to make it more obvious to read and to use. The two high level changes are: 1. Create Window.currentRow, Window.unboundedPreceding, Window.unboundedFollowing to indicate the special values in frame boundaries. These methods map to the special integral values so we are not breaking backward compatibility here. This change makes the frame boundaries more self-evident (instead of Long.MinValue, it becomes Window.unboundedPreceding). 2. In Python, for any value less than or equal to JVM's Long.MinValue, treat it as Window.unboundedPreceding. For any value larger than or equal to JVM's Long.MaxValue, treat it as Window.unboundedFollowing. Before this change, if the user specifies any value that is less than Long.MinValue but not -sys.maxsize (e.g. -sys.maxsize + 1), the number we pass over to the JVM would overflow, resulting in a frame that does not make sense. Code example required to specify a frame before this patch: ``` Window.rowsBetween(-Long.MinValue, 0) ``` While the above code should still work, the new way is more obvious to read: ``` Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) ``` ## How was this patch tested? 
- Updated DataFrameWindowSuite (for Scala/Java) - Updated test_window_functions_cumulative_sum (for Python) - Renamed DataFrameWindowSuite DataFrameWindowFunctionsSuite to better reflect its purpose Author: Reynold Xin Closes #15438 from rxin/SPARK-17845. --- python/pyspark/sql/tests.py | 25 +++++- python/pyspark/sql/window.py | 89 +++++++++++++------ .../apache/spark/sql/expressions/Window.scala | 62 +++++++++++-- .../spark/sql/expressions/WindowSpec.scala | 24 +++-- ...la => DataFrameWindowFunctionsSuite.scala} | 11 ++- 5 files changed, 160 insertions(+), 51 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{DataFrameWindowSuite.scala => DataFrameWindowFunctionsSuite.scala} (97%) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 61674a8a7ed65..51d5e7ab0568e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1876,12 +1876,35 @@ def test_window_functions_without_partitionBy(self): def test_window_functions_cumulative_sum(self): df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) from pyspark.sql import functions as F - sel = df.select(df.key, F.sum(df.value).over(Window.rowsBetween(-sys.maxsize, 0))) + + # Test cumulative sum + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) rs = sorted(sel.collect()) expected = [("one", 1), ("two", 3)] for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow + frame_end = Window.unboundedFollowing + 1 + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) + rs = sorted(sel.collect()) + expected = [("one", 3), ("two", 2)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 87e9a988987ea..c345e623f1cb1 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -36,8 +36,8 @@ class Window(object): For example: - >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + >>> # ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow) >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) @@ -46,6 +46,16 @@ class Window(object): .. 
versionadded:: 1.4 """ + + _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 + _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 + + unboundedPreceding = _JAVA_MIN_LONG + + unboundedFollowing = _JAVA_MAX_LONG + + currentRow = 0 + @staticmethod @since(1.4) def partitionBy(*cols): @@ -77,15 +87,21 @@ def rowsBetween(start, end): For example, "0" means "current row", while "-1" means the row before the current row, and "5" means the fifth row after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. """ - if start <= -sys.maxsize: - start = WindowSpec._JAVA_MIN_LONG - if end >= sys.maxsize: - end = WindowSpec._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end) return WindowSpec(jspec) @@ -101,15 +117,21 @@ def rangeBetween(start, end): "0" means "current row", while "-1" means one off before the current row, and "5" means the five off after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. """ - if start <= -sys.maxsize: - start = WindowSpec._JAVA_MIN_LONG - if end >= sys.maxsize: - end = WindowSpec._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -127,9 +149,6 @@ class WindowSpec(object): .. versionadded:: 1.4 """ - _JAVA_MAX_LONG = (1 << 63) - 1 - _JAVA_MIN_LONG = - (1 << 63) - def __init__(self, jspec): self._jspec = jspec @@ -160,15 +179,21 @@ def rowsBetween(self, start, end): For example, "0" means "current row", while "-1" means the row before the current row, and "5" means the fifth row after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). 
+ The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. """ - if start <= -sys.maxsize: - start = self._JAVA_MIN_LONG - if end >= sys.maxsize: - end = self._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing return WindowSpec(self._jspec.rowsBetween(start, end)) @since(1.4) @@ -180,15 +205,21 @@ def rangeBetween(self, start, end): "0" means "current row", while "-1" means one off before the current row, and "5" means the five off after the current row. + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + :param start: boundary start, inclusive. - The frame is unbounded if this is ``-sys.maxsize`` (or lower). + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. :param end: boundary end, inclusive. - The frame is unbounded if this is ``sys.maxsize`` (or higher). + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. """ - if start <= -sys.maxsize: - start = self._JAVA_MIN_LONG - if end >= sys.maxsize: - end = self._JAVA_MAX_LONG + if start <= Window._JAVA_MIN_LONG: + start = Window.unboundedPreceding + if end >= Window._JAVA_MAX_LONG: + end = Window.unboundedFollowing return WindowSpec(self._jspec.rangeBetween(start, end)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index e8a0c5f43fe46..3c1f6e897ea62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._ * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) @@ -74,6 +75,41 @@ object Window { spec.orderBy(cols : _*) } + /** + * Value representing the last row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL. + * This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) + * }}} + * + * @since 2.1.0 + */ + def unboundedPreceding: Long = Long.MinValue + + /** + * Value representing the last row in the partition, equivalent to "UNBOUNDED FOLLOWING" in SQL. + * This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + * }}} + * + * @since 2.1.0 + */ + def unboundedFollowing: Long = Long.MaxValue + + /** + * Value representing the current row. 
This can be used to specify the frame boundaries: + * + * {{{ + * Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) + * }}} + * + * @since 2.1.0 + */ + def currentRow: Long = 0 + /** * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). @@ -82,10 +118,14 @@ object Window { * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -101,10 +141,14 @@ object Window { * while "-1" means one off before the current row, and "5" means the five off after the * current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 82bc8f152d6ea..8ebed399bf2d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -86,10 +86,14 @@ class WindowSpec private[sql]( * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. 
@@ -104,10 +108,14 @@ class WindowSpec private[sql]( * while "-1" means one off before the current row, and "5" means the five off after the * current row. * - * @param start boundary start, inclusive. - * The frame is unbounded if this is the minimum long value. - * @param end boundary end, inclusive. - * The frame is unbounded if this is the maximum long value. + * We recommend users use [[Window.unboundedPreceding]], [[Window.unboundedFollowing]], + * and [[Window.currentRow]] to specify special boundary values, rather than using integral + * values directly. + * + * @param start boundary start, inclusive. The frame is unbounded if this is + * the minimum long value ([[Window.unboundedPreceding]]). + * @param end boundary end, inclusive. The frame is unbounded if this is the + * maximum long value ([[Window.unboundedFollowing]]). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 5bc386f291043..1255c49104718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DataType, LongType, StructType} /** * Window function testing for DataFrame API. */ -class DataFrameWindowSuite extends QueryTest with SharedSQLContext { +class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("reuse window partitionBy") { @@ -54,7 +54,8 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") // Running (cumulative) sum checkAnswer( - df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0))), + df.select('key, sum("value").over( + Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), Row("one", 1) :: Row("two", 3) :: Nil ) } @@ -156,9 +157,11 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { df.select( $"key", last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + Window.partitionBy($"value").orderBy($"key") + .rowsBetween(Window.currentRow, Window.unboundedFollowing)), last("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + Window.partitionBy($"value").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.currentRow)), last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), Row(4, 4, 4, 4))) From 0d4a695279c514c76aa0e9288c70ac7aaef91b03 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 12 Oct 2016 19:52:57 -0700 Subject: [PATCH 265/306] [SPARK-17745][ML][PYSPARK] update NB python api - add weight col parameter ## What changes were proposed in this pull request? update python api for NaiveBayes: add weight col parameter. ## How was this patch tested? doctests added. Author: WeichenXu Closes #15406 from WeichenXu123/nb_python_update. 
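The new Python parameter mirrors the existing `setWeightCol` setter on the Scala estimator. A hedged Scala sketch of the equivalent usage follows; the data and column names are made up to mirror the doctest in this patch, and `spark` is assumed to be an active `SparkSession`.

```
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.Vectors

// Toy data for illustration: (label, weight, features).
val df = spark.createDataFrame(Seq(
  (0.0, 0.1, Vectors.dense(0.0, 0.0)),
  (0.0, 0.5, Vectors.dense(0.0, 1.0)),
  (1.0, 1.0, Vectors.dense(1.0, 0.0))
)).toDF("label", "weight", "features")

val model = new NaiveBayes()
  .setSmoothing(1.0)
  .setModelType("multinomial")
  .setWeightCol("weight")  // per-row weights are used when aggregating label/feature counts
  .fit(df)
```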
--- python/pyspark/ml/classification.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ea60fab029582..3f763a10d4066 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -981,7 +981,7 @@ def trees(self): @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. `Multinomial NB @@ -995,23 +995,23 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([ - ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), - ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), - ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) - >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight") >>> model = nb.fit(df) >>> model.pi - DenseVector([-0.51..., -0.91...]) + DenseVector([-0.81..., -0.58...]) >>> model.theta - DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1) >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() >>> result = model.transform(test0).head() >>> result.prediction 1.0 >>> result.probability - DenseVector([0.42..., 0.57...]) + DenseVector([0.32..., 0.67...]) >>> result.rawPrediction - DenseVector([-1.60..., -1.32...]) + DenseVector([-1.72..., -0.99...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -1045,11 +1045,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + modelType="multinomial", thresholds=None, weightCol=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -1062,11 +1062,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + 
modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs From 21cb59f1cd137d96b2596f1abe691b544581cf59 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 12 Oct 2016 19:56:40 -0700 Subject: [PATCH 266/306] [SPARK-17835][ML][MLLIB] Optimize NaiveBayes mllib wrapper to eliminate extra pass on data ## What changes were proposed in this pull request? [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077) copied the ```NaiveBayes``` implementation from mllib to ml and left mllib as a wrapper. However, there are some difference between mllib and ml to handle labels: * mllib allow input labels as {-1, +1}, however, ml assumes the input labels in range [0, numClasses). * mllib ```NaiveBayesModel``` expose ```labels``` but ml did not due to the assumption mention above. During the copy in [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077), we use ```val labels = data.map(_.label).distinct().collect().sorted``` to get the distinct labels firstly, and then encode the labels for training. It involves extra Spark job compared with the original implementation. Since ```NaiveBayes``` only do one pass aggregation during training, adding another one seems less efficient. We can get the labels in a single pass along with ```NaiveBayes``` training and send them to MLlib side. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15402 from yanboliang/spark-17835. --- .../spark/ml/classification/NaiveBayes.scala | 46 +++++++++++++++---- .../mllib/classification/NaiveBayes.scala | 15 +++--- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index e565a6fd3ece2..994ed993c99df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - val numClasses = getNumClasses(dataset) + /** + * ml assumes input labels in range [0, numClasses). But this implementation + * is also called by mllib NaiveBayes which allows other kinds of input labels + * such as {-1, +1}. Here we use this parameter to switch between different processing logic. + * It should be removed when we remove mllib NaiveBayes. + */ + private[spark] var isML: Boolean = true - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".train() called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } + private[spark] def setIsML(isML: Boolean): this.type = { + this.isML = isML + this + } - val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + if (isML) { + val numClasses = getNumClasses(dataset) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." 
+ + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + } val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") ( } } + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) // Aggregates term frequencies per label. @@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum + val labelArray = new Array[Double](numLabels) val piArray = new Array[Double](numLabels) val thetaArray = new Array[Double](numLabels * numFeatures) @@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") ( val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => + labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) @@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - new NaiveBayesModel(uid, pi, theta) + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") @@ -239,6 +254,19 @@ class NaiveBayesModel private[ml] ( import NaiveBayes.{Bernoulli, Multinomial} + /** + * mllib NaiveBayes is a wrapper of ml implementation currently. + * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels, + * both of which are different from ml, so we should store the labels sequentially + * to be called by mllib. This should be removed when we remove mllib NaiveBayes. + */ + private[spark] var oldLabels: Array[Double] = null + + private[spark] def setOldLabels(labels: Array[Double]): this.type = { + this.oldLabels = labels + this + } + /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 32d6968a4e85f..33561be4b5bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,15 +364,10 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) + .setIsML(false) - val labels = data.map(_.label).distinct().collect().sorted - - // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be - // in range [0, numClasses). 
- val dataset = data.map { - case LabeledPoint(label, features) => - (labels.indexOf(label).toDouble, features.asML) - }.toDF("label", "features") + val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } + .toDF("label", "features") val newModel = nb.fit(dataset) @@ -383,7 +378,9 @@ class NaiveBayes private ( theta(i)(j) = v } - new NaiveBayesModel(labels, pi, theta, modelType) + require(newModel.oldLabels != null, + "The underlying ML NaiveBayes training does not produce labels.") + new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } } From edeb51a39d76d64196d7635f52be1b42c7ec4341 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Oct 2016 21:40:45 -0700 Subject: [PATCH 267/306] [SPARK-17876] Write StructuredStreaming WAL to a stream instead of materializing all at once ## What changes were proposed in this pull request? The CompactibleFileStreamLog materializes the whole metadata log in memory as a String. This can cause issues when there are lots of files that are being committed, especially during a compaction batch. You may come across stacktraces that look like: ``` java.lang.OutOfMemoryError: Requested array size exceeds VM limit at java.lang.StringCoding.encode(StringCoding.java:350) at java.lang.String.getBytes(String.java:941) at org.apache.spark.sql.execution.streaming.FileStreamSinkLog.serialize(FileStreamSinkLog.scala:127) ``` The safer way is to write to an output stream so that we don't have to materialize a huge string. ## How was this patch tested? Existing unit tests Author: Burak Yavuz Closes #15437 from brkyvz/ser-to-stream. --- .../streaming/CompactibleFileStreamLog.scala | 22 +++++++++----- .../execution/streaming/HDFSMetadataLog.scala | 29 ++++++++++--------- .../streaming/FileStreamSinkLogSuite.scala | 14 +++++---- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 027b5bbfab8d6..c14feea91ed7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.streaming -import java.io.IOException +import java.io.{InputStream, IOException, OutputStream} import java.nio.charset.StandardCharsets.UTF_8 +import scala.io.{Source => IOSource} import scala.reflect.ClassTag import org.apache.hadoop.fs.{Path, PathFilter} @@ -93,20 +94,25 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( } } - override def serialize(logData: Array[T]): Array[Byte] = { - (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8) + override def serialize(logData: Array[T], out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(metadataLogVersion.getBytes(UTF_8)) + logData.foreach { data => + out.write('\n') + out.write(serializeData(data).getBytes(UTF_8)) + } } - override def deserialize(bytes: Array[Byte]): Array[T] = { - val lines = new String(bytes, UTF_8).split("\n") - if (lines.length == 0) { + override def deserialize(in: InputStream): Array[T] = { + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines(0) + val version 
= lines.next() if (version != metadataLogVersion) { throw new IllegalStateException(s"Unknown log version: ${version}") } - lines.slice(1, lines.length).map(deserializeData) + lines.map(deserializeData).toArray } override def add(batchId: Long, logs: Array[T]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 39a0f3341389c..c7235320fd6bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming -import java.io.{FileNotFoundException, IOException} -import java.nio.ByteBuffer +import java.io.{FileNotFoundException, InputStream, IOException, OutputStream} import java.util.{ConcurrentModificationException, EnumSet, UUID} import scala.reflect.ClassTag @@ -29,7 +28,6 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.SparkSession import org.apache.spark.util.UninterruptibleThread @@ -88,12 +86,16 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) } } - protected def serialize(metadata: T): Array[Byte] = { - JavaUtils.bufferToArray(serializer.serialize(metadata)) + protected def serialize(metadata: T, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + val outStream = serializer.serializeStream(out) + outStream.writeObject(metadata) } - protected def deserialize(bytes: Array[Byte]): T = { - serializer.deserialize[T](ByteBuffer.wrap(bytes)) + protected def deserialize(in: InputStream): T = { + // called inside a try-finally where the underlying stream is closed in the caller + val inStream = serializer.deserializeStream(in) + inStream.readObject[T]() } /** @@ -114,7 +116,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) // Only write metadata when the batch has not yet been written Thread.currentThread match { case ut: UninterruptibleThread => - ut.runUninterruptibly { writeBatch(batchId, serialize(metadata)) } + ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } case _ => throw new IllegalStateException( "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") @@ -129,7 +131,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. 
*/ - private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = { + private def writeBatch(batchId: Long, metadata: T, writer: (T, OutputStream) => Unit): Unit = { // Use nextId to create a temp file var nextId = 0 while (true) { @@ -137,9 +139,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) try { val output = fileManager.create(tempPath) try { - output.write(bytes) + writer(metadata, output) } finally { - output.close() + IOUtils.closeQuietly(output) } try { // Try to commit the batch @@ -193,10 +195,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) if (fileManager.exists(batchMetadataFile)) { val input = fileManager.open(batchMetadataFile) try { - val bytes = IOUtils.toByteArray(input) - Some(deserialize(bytes)) + Some(deserialize(input)) } finally { - input.close() + IOUtils.closeQuietly(input) } } else { logDebug(s"Unable to find batch $batchMetadataFile") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 41a8cc2400dff..e1bc674a28071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.charset.StandardCharsets.UTF_8 import org.apache.spark.SparkFunSuite @@ -133,9 +134,12 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin // scalastyle:on - assert(expected === new String(sinkLog.serialize(logs), UTF_8)) - - assert(VERSION === new String(sinkLog.serialize(Array()), UTF_8)) + val baos = new ByteArrayOutputStream() + sinkLog.serialize(logs, baos) + assert(expected === baos.toString(UTF_8.name())) + baos.reset() + sinkLog.serialize(Array(), baos) + assert(VERSION === baos.toString(UTF_8.name())) } } @@ -174,9 +178,9 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { blockSize = 30000L, action = FileStreamSinkLog.ADD_ACTION)) - assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8))) + assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(VERSION.getBytes(UTF_8))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) } } From 064d6650e93ed6515a1309079c361e20404843cc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 13 Oct 2016 13:27:57 +0800 Subject: [PATCH 268/306] [SPARK-17866][SPARK-17867][SQL] Fix Dataset.dropduplicates ## What changes were proposed in this pull request? Two issues regarding Dataset.dropduplicates: 1. Dataset.dropDuplicates should consider the columns with same column name We find and get the first resolved attribute from output with the given column name in `Dataset.dropDuplicates`. When we have the more than one columns with the same name. Other columns are put into aggregation columns, instead of grouping columns. 2. 
Dataset.dropDuplicates should not change the output of child plan We create new `Alias` with new exprId in `Dataset.dropDuplicates` now. However it causes problem when we want to select the columns as follows: val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() // ds("_2") will cause analysis exception ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int]) Because the two issues are both related to `Dataset.dropduplicates` and the code changes are not big, so submitting them together as one PR. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #15427 from viirya/fix-dropduplicates. --- .../scala/org/apache/spark/sql/Dataset.scala | 16 ++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a7a84730a6fd9..e59a483075c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1892,17 +1892,25 @@ class Dataset[T] private[sql]( def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = colNames.map { colName => - allColumns.find(col => resolver(col.name, colName)).getOrElse( + val groupCols = colNames.flatMap { colName => + // It is possibly there are more than one columns with the same name, + // so we call filter instead of find. + val cols = allColumns.filter(col => resolver(col.name, colName)) + if (cols.isEmpty) { throw new AnalysisException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")) + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + cols } val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() + // Removing duplicate rows should not change output attributes. We should keep + // the original exprId of the attribute. Otherwise, to select a column in original + // dataset will cause analysis exception due to unresolved attribute. + Alias(new First(attr).toAggregateExpression(), attr.name)(exprId = attr.exprId) } } Aggregate(groupCols, aggCols, logicalPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3243f352a5337..5fce9b4fe97ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -872,6 +872,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 1), ("a", 2), ("b", 1)) } + test("dropDuplicates: columns with same column name") { + val ds1 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + val ds2 = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + // The dataset joined has two columns of the same name "_2". 
+ val joined = ds1.join(ds2, "_1").select(ds1("_2").as[Int], ds2("_2").as[Int]) + checkDataset( + joined.dropDuplicates(), + (1, 2), (1, 1), (2, 1), (2, 2)) + } + + test("dropDuplicates should not change child plan output") { + val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + checkDataset( + ds.dropDuplicates("_1").select(ds("_1").as[String], ds("_2").as[Int]), + ("a", 1), ("b", 1)) + } + test("SPARK-16097: Encoders.tuple should handle null object correctly") { val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING) val data = Seq((("a", "b"), "c"), (null, "d")) From 7222a25a11790fa9d9d1428c84b6f827a785c9e8 Mon Sep 17 00:00:00 2001 From: buzhihuojie Date: Wed, 12 Oct 2016 22:51:54 -0700 Subject: [PATCH 269/306] minor doc fix for Row.scala ## What changes were proposed in this pull request? minor doc fix for "getAnyValAs" in class Row ## How was this patch tested? None. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: buzhihuojie Closes #15452 from david-weiluo-ren/minorDocFixForRow. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 344dcb9bce62b..65f91429648c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -462,7 +462,7 @@ trait Row extends Serializable { def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) /** - * Returns the value of a given fieldName. + * Returns the value at position i. * * @throws UnsupportedOperationException when schema is not defined. * @throws ClassCastException when data type does not match. From 6f2fa6c54a11caccd446d5560d2014c645fcf7cc Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 13 Oct 2016 03:24:37 -0400 Subject: [PATCH 270/306] [SPARK-11272][WEB UI] Add support for downloading event logs from HistoryServer UI ## What changes were proposed in this pull request? This is a reworked PR based on feedback in #9238 after it was closed and not reopened. As suggested in that PR I've only added the download feature. This functionality already exists in the api and this allows easier access to download event logs to share with others. I've attached a screenshot of the committed version, but I will also include alternate options with screen shots in the comments below. I'm personally not sure which option is best. ## How was this patch tested? Manual testing ![screen shot 2016-10-07 at 6 11 12 pm](https://cloud.githubusercontent.com/assets/13952758/19209213/832fe48e-8cba-11e6-9840-749b1be4d399.png) Author: Alex Bozarth Closes #15400 from ajbozarth/spark11272. 
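For context (not part of this patch): the new Download link points at the history server's existing REST endpoint for event logs, so the same zip can also be fetched programmatically. A minimal Scala sketch, assuming a history server on the default port 18080 and the documented `/api/v1/applications/<app-id>/logs` route; the helper name and application id below are illustrative:

```
import java.io.{BufferedInputStream, FileOutputStream}
import java.net.URL

// Illustrative helper: stream the event-log zip served by the history server's REST API
// into a local file. The endpoint and port are assumptions based on the documented API,
// not something introduced by this patch.
def downloadEventLogs(historyServer: String, appId: String, dest: String): Unit = {
  val in = new BufferedInputStream(
    new URL(s"$historyServer/api/v1/applications/$appId/logs").openStream())
  val out = new FileOutputStream(dest)
  try {
    Iterator.continually(in.read()).takeWhile(_ != -1).foreach(b => out.write(b))
  } finally {
    in.close()
    out.close()
  }
}

// downloadEventLogs("http://localhost:18080", "app-20161013120000-0000", "eventLogs.zip")
```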
--- .../org/apache/spark/ui/static/historypage-template.html | 7 ++++++- .../resources/org/apache/spark/ui/static/historypage.js | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index a2b3826dd324b..1fd6ef4a71253 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -59,7 +59,11 @@ Last Updated - + + + Event Log + + {{#applications}} @@ -73,6 +77,7 @@ {{duration}} {{sparkUser}} {{lastUpdated}} + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index c8094005c65dd..2a32e18672a22 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -133,6 +133,7 @@ $(document).ready(function() { {name: 'sixth', type: "title-numeric"}, {name: 'seventh'}, {name: 'eighth'}, + {name: 'ninth'}, ], "autoWidth": false, "order": [[ 4, "desc" ]] From db8784feaa605adcbd37af4bc8b7146479b631f8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Oct 2016 03:26:29 -0400 Subject: [PATCH 271/306] [SPARK-17899][SQL] add a debug mode to keep raw table properties in HiveExternalCatalog ## What changes were proposed in this pull request? Currently `HiveExternalCatalog` will filter out the Spark SQL internal table properties, e.g. `spark.sql.sources.provider`, `spark.sql.sources.schema`, etc. This is reasonable for external users as they don't want to see these internal properties in `DESC TABLE`. However, as a Spark developer, sometimes we do wanna see the raw table properties. This PR adds a new internal SQL conf, `spark.sql.debug`, to enable debug mode and keep these raw table properties. This config can also be used in similar places where we wanna retain debug information in the future. ## How was this patch tested? new test in MetastoreDataSourcesSuite Author: Wenchen Fan Closes #15458 from cloud-fan/debug. 
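A minimal sketch of how the new flag might be used (`spark.sql.debug` is the conf added here; the table name and the rest of the setup are illustrative). Because it is a static SQL conf, it has to be set on the `SparkConf` before the session is created rather than via `spark.conf.set`:

```
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Static conf: must be set before the SparkSession is built.
val conf = new SparkConf()
  .setMaster("local[*]")
  .setAppName("sql-debug-mode-demo")
  .set("spark.sql.debug", "true")

val spark = SparkSession.builder()
  .config(conf)
  .enableHiveSupport()   // the raw-property filtering lives in HiveExternalCatalog
  .getOrCreate()

spark.sql("CREATE TABLE demo_tbl(i INT) USING json")
// With debug mode on, internal entries such as spark.sql.sources.provider should stay
// visible in the detailed table output instead of being filtered out.
spark.sql("DESC EXTENDED demo_tbl").show(numRows = 100, truncate = false)
```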
--- .../apache/spark/sql/internal/SQLConf.scala | 5 ++++ .../spark/sql/internal/SQLConfSuite.scala | 24 +++++++++++-------- .../spark/sql/hive/HiveExternalCatalog.scala | 9 +++++-- .../sql/hive/MetastoreDataSourcesSuite.scala | 17 ++++++++++++- 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9e7c1ec211893..192083e2ea5f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -915,4 +915,9 @@ object StaticSQLConf { .internal() .intConf .createWithDefault(4000) + + val DEBUG_MODE = buildConf("spark.sql.debug") + .internal() + .booleanConf + .createWithDefault(false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index f545de0e10a6b..df640ffab91de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.internal import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.internal.StaticSQLConf._ @@ -254,18 +255,21 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } } - test("global SQL conf comes from SparkConf") { - val newSession = SparkSession.builder() - .config(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000") - .getOrCreate() - - assert(newSession.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD.key) == "2000") - checkAnswer( - newSession.sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}"), - Row(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000")) + test("static SQL conf comes from SparkConf") { + val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) + try { + sparkContext.conf.set(SCHEMA_STRING_LENGTH_THRESHOLD, 2000) + val newSession = new SparkSession(sparkContext) + assert(newSession.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) == 2000) + checkAnswer( + newSession.sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}"), + Row(SCHEMA_STRING_LENGTH_THRESHOLD.key, "2000")) + } finally { + sparkContext.conf.set(SCHEMA_STRING_LENGTH_THRESHOLD, previousValue) + } } - test("cannot set/unset global SQL conf") { + test("cannot set/unset static SQL conf") { val e1 = intercept[AnalysisException](sql(s"SET ${SCHEMA_STRING_LENGTH_THRESHOLD.key}=10")) assert(e1.message.contains("Cannot modify the value of a static config")) val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index e1c0cad907b98..ed189724a2dbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe -import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD +import 
org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.types.{DataType, StructType} @@ -461,13 +461,18 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } else { table.storage } + val tableProps = if (conf.get(DEBUG_MODE)) { + table.properties + } else { + getOriginalTableProperties(table) + } table.copy( storage = storage, schema = getSchemaFromTableProperties(table), provider = Some(provider), partitionColumnNames = getPartitionColumnsFromTableProperties(table), bucketSpec = getBucketSpecFromTableProperties(table), - properties = getOriginalTableProperties(table)) + properties = tableProps) } getOrElse { table.copy(provider = Some("hive")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 0477122fc6a27..7cc6179d44977 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} @@ -31,7 +32,7 @@ import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf.SCHEMA_STRING_LENGTH_THRESHOLD +import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1324,4 +1325,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv hiveClient.dropTable("default", "t", ignoreIfNotExists = true, purge = true) } } + + test("should keep data source entries in table properties when debug mode is on") { + val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) + try { + sparkSession.sparkContext.conf.set(DEBUG_MODE, true) + val newSession = sparkSession.newSession() + newSession.sql("CREATE TABLE abc(i int) USING json") + val tableMeta = newSession.sessionState.catalog.getTableMetadata(TableIdentifier("abc")) + assert(tableMeta.properties(DATASOURCE_SCHEMA_NUMPARTS).toInt == 1) + assert(tableMeta.properties(DATASOURCE_PROVIDER) == "json") + } finally { + sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue) + } + } } From 7bf8a4049866b2ec7fdf0406b1ad0c3a12488645 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 13 Oct 2016 03:29:14 -0400 Subject: [PATCH 272/306] [SPARK-17686][CORE] Support printing out scala and java version with spark-submit --version command ## What changes were proposed in this pull request? In our universal gateway service we need to specify different jars to Spark according to scala version. For now only after launching Spark application can we know which version of Scala it depends on. It makes hard for us to support different Scala + Spark versions to pick the right jars. So here propose to print out Scala version according to Spark version in "spark-submit --version", so that user could leverage this output to make the choice without needing to launching application. ## How was this patch tested? 
Manually verified in local environment. Author: jerryshao Closes #15456 from jerryshao/SPARK-17686. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 80611658a1640..5c052286099f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import scala.util.Properties import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path @@ -47,7 +48,6 @@ import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} - /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone and Mesos cluster modes. @@ -104,6 +104,8 @@ object SparkSubmit { /___/ .__/\_,_/_/ /_/\_\ version %s /_/ """.format(SPARK_VERSION)) + printStream.println("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) printStream.println("Branch %s".format(SPARK_BRANCH)) printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) printStream.println("Revision %s".format(SPARK_REVISION)) From 0a8e51a5e4cfd3275eff12e9fbbeb3fb487990aa Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 13 Oct 2016 21:36:39 +0800 Subject: [PATCH 273/306] [SPARK-17657][SQL] Disallow Users to Change Table Type ### What changes were proposed in this pull request? Hive allows users to change the table type from `Managed` to `External` or from `External` to `Managed` by altering table's property `EXTERNAL`. See the JIRA: https://issues.apache.org/jira/browse/HIVE-1329 So far, Spark SQL does not correctly support it, although users can do it. Many assumptions are broken in the implementation. Thus, this PR is to disallow users to change it. In addition, we also do not allow users to set the property `EXTERNAL` when creating a table. ### How was this patch tested? Added test cases Author: gatorsmile Closes #15230 from gatorsmile/alterTableSetExternal. --- .../spark/sql/hive/HiveExternalCatalog.scala | 5 +++ .../sql/hive/execution/HiveDDLSuite.scala | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ed189724a2dbd..237b829da882f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -112,6 +112,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat s"as table property keys may not start with '$DATASOURCE_PREFIX' or '$STATISTICS_PREFIX':" + s" ${invalidKeys.mkString("[", ", ", "]")}") } + // External users are not allowed to set/switch the table type. In Hive metastore, the table + // type can be switched by changing the value of a case-sensitive table property `EXTERNAL`. 
+ if (table.properties.contains("EXTERNAL")) { + throw new AnalysisException("Cannot set or change the preserved property key: 'EXTERNAL'") + } } // -------------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 8bff6de008fdb..3d1712e4354c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -315,6 +315,38 @@ class HiveDDLSuite assert(message.contains("Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) } + test("create table - SET TBLPROPERTIES EXTERNAL to TRUE") { + val tabName = "tab1" + withTable(tabName) { + val message = intercept[AnalysisException] { + sql(s"CREATE TABLE $tabName (height INT, length INT) TBLPROPERTIES('EXTERNAL'='TRUE')") + }.getMessage + assert(message.contains("Cannot set or change the preserved property key: 'EXTERNAL'")) + } + } + + test("alter table - SET TBLPROPERTIES EXTERNAL to TRUE") { + val tabName = "tab1" + withTable(tabName) { + val catalog = spark.sessionState.catalog + sql(s"CREATE TABLE $tabName (height INT, length INT)") + assert( + catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) + val message = intercept[AnalysisException] { + sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('EXTERNAL' = 'TRUE')") + }.getMessage + assert(message.contains("Cannot set or change the preserved property key: 'EXTERNAL'")) + // The table type is not changed to external + assert( + catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) + // The table property is case sensitive. Thus, external is allowed + sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('external' = 'TRUE')") + // The table type is not changed to external + assert( + catalog.getTableMetadata(TableIdentifier(tabName)).tableType == CatalogTableType.MANAGED) + } + } + test("alter views and alter table - misuse") { val tabName = "tab1" withTable(tabName) { From 04d417a7ca8ef694658b26fb697a035717414731 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Oct 2016 11:12:30 -0700 Subject: [PATCH 274/306] [SPARK-17830][SQL] Annotate remaining SQL APIs with InterfaceStability ## What changes were proposed in this pull request? This patch annotates all the remaining APIs in SQL (excluding streaming) with InterfaceStability. ## How was this patch tested? N/A - just annotation change. Author: Reynold Xin Closes #15457 from rxin/SPARK-17830-2. 
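For readers who have not seen the annotation before, a minimal sketch of how it is applied; the class below is illustrative and not part of this patch, and the three nested levels are the ones used throughout the diff:

```
import org.apache.spark.annotation.InterfaceStability

// Roughly: Stable marks public APIs kept compatible across releases, Evolving marks
// APIs that may still change between releases, and Unstable marks surfaces that can
// change at any time (e.g. the streaming provider traits annotated in this patch).
@InterfaceStability.Evolving
class MyConnectorOptions(val path: String, val batchSize: Int)
```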
--- .../java/org/apache/spark/sql/SaveMode.java | 3 +++ .../org/apache/spark/sql/api/java/UDF1.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF10.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF11.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF12.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF13.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF14.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF15.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF16.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF17.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF18.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF19.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF2.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF20.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF21.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF22.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF3.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF4.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF5.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF6.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF7.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF8.java | 8 +++--- .../org/apache/spark/sql/api/java/UDF9.java | 8 +++--- .../spark/sql/expressions/javalang/typed.java | 2 ++ .../apache/spark/sql/catalog/Catalog.scala | 9 ++++++- .../apache/spark/sql/catalog/interface.scala | 5 ++++ .../spark/sql/expressions/Aggregator.scala | 3 ++- .../sql/expressions/UserDefinedFunction.scala | 3 ++- .../apache/spark/sql/expressions/Window.scala | 4 ++- .../spark/sql/expressions/WindowSpec.scala | 7 ++--- .../sql/expressions/scalalang/typed.scala | 3 ++- .../apache/spark/sql/expressions/udaf.scala | 8 +++++- .../apache/spark/sql/jdbc/JdbcDialects.scala | 5 +++- .../apache/spark/sql/sources/filters.scala | 18 +++++++++++++ .../apache/spark/sql/sources/interfaces.scala | 26 +++++++++++++++++-- 35 files changed, 150 insertions(+), 122 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 9665c3c46f901..1c3c9794fb6bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -16,11 +16,14 @@ */ package org.apache.spark.sql; +import org.apache.spark.annotation.InterfaceStability; + /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. * * @since 1.3.0 */ +@InterfaceStability.Stable public enum SaveMode { /** * Append mode means that when saving a DataFrame to a data source, if data/table already exists, diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index ef959e35e1027..1460daf27dc20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 1 arguments. 
*/ +@InterfaceStability.Stable public interface UDF1 extends Serializable { - public R call(T1 t1) throws Exception; + R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java index 96ab3a96c3d5e..7c4f1e4897084 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 10 arguments. */ +@InterfaceStability.Stable public interface UDF10 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java index 58ae8edd6d817..26a05106aebd6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 11 arguments. */ +@InterfaceStability.Stable public interface UDF11 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java index d9da0f6eddd94..8ef7a99042025 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 12 arguments. 
*/ +@InterfaceStability.Stable public interface UDF12 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java index 095fc1a8076b5..5c3b2ec1222e2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 13 arguments. */ +@InterfaceStability.Stable public interface UDF13 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java index eb27eaa180086..97e744d843466 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 14 arguments. */ +@InterfaceStability.Stable public interface UDF14 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java index 1fbcff56332b6..7ddbf914fc11a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 15 arguments. 
*/ +@InterfaceStability.Stable public interface UDF15 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java index 1133561787a69..0ae5dc7195ad6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 16 arguments. */ +@InterfaceStability.Stable public interface UDF16 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java index dfae7922c9b63..03543a556c614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 17 arguments. */ +@InterfaceStability.Stable public interface UDF17 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java index e9d1c6d52d4ea..46740d3443916 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 18 arguments. 
*/ +@InterfaceStability.Stable public interface UDF18 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java index 46b9d2d3c9457..33fefd8ecaf1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 19 arguments. */ +@InterfaceStability.Stable public interface UDF19 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java index cd3fde8da419e..9822f19217d76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 2 arguments. */ +@InterfaceStability.Stable public interface UDF2 extends Serializable { - public R call(T1 t1, T2 t2) throws Exception; + R call(T1 t1, T2 t2) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java index 113d3d26be4a7..8c5e90182da1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 20 arguments. 
*/ +@InterfaceStability.Stable public interface UDF20 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java index 74118f2cf8da7..e3b09f5167cff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 21 arguments. */ +@InterfaceStability.Stable public interface UDF21 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java index 0e7cc40be45ec..dc6cfa9097bab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 22 arguments. */ +@InterfaceStability.Stable public interface UDF22 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java index 6a880f16be47a..7c264b69ba195 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 3 arguments. 
*/ +@InterfaceStability.Stable public interface UDF3 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3) throws Exception; + R call(T1 t1, T2 t2, T3 t3) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java index fcad2febb18e6..58df38fc3c911 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 4 arguments. */ +@InterfaceStability.Stable public interface UDF4 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java index ce0cef43a2144..4146f96e2eed5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 5 arguments. */ +@InterfaceStability.Stable public interface UDF5 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java index f56b806684e61..25d39654c1095 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 6 arguments. */ +@InterfaceStability.Stable public interface UDF6 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java index 25bd6d3241bd4..ce63b6a91adbb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 7 arguments. 
*/ +@InterfaceStability.Stable public interface UDF7 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java index a3b7ac5f94ce7..0e00209ef6b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 8 arguments. */ +@InterfaceStability.Stable public interface UDF8 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java index 205e72a1522fc..077981bb3e3ee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -19,14 +19,12 @@ import java.io.Serializable; -// ************************************************** -// THIS FILE IS AUTOGENERATED BY CODE IN -// org.apache.spark.sql.api.java.FunctionRegistration -// ************************************************** +import org.apache.spark.annotation.InterfaceStability; /** * A Spark SQL UDF that has 9 arguments. 
*/ +@InterfaceStability.Stable public interface UDF9 extends Serializable { - public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; + R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index 247e94b86c349..ec9c107b1c119 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.expressions.javalang; import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; @@ -34,6 +35,7 @@ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving public class typed { // Note: make sure to keep in sync with typed.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 18cba8ce28b4d..889b8a02784d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalog -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -27,6 +27,7 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ +@InterfaceStability.Stable abstract class Catalog { /** @@ -193,6 +194,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable(tableName: String, path: String): DataFrame /** @@ -203,6 +205,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable(tableName: String, path: String, source: String): DataFrame /** @@ -213,6 +216,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -227,6 +231,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -240,6 +245,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, @@ -255,6 +261,7 @@ abstract class Catalog { * @since 2.0.0 */ @Experimental + @InterfaceStability.Evolving def createExternalTable( tableName: String, source: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index 33032f07f7bea..c0c5ebc2ba2d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.DefinedByConstructorParams @@ -33,6 +34,7 @@ import 
org.apache.spark.sql.catalyst.DefinedByConstructorParams * @param locationUri path (in the form of a uri) to data files. * @since 2.0.0 */ +@InterfaceStability.Stable class Database( val name: String, @Nullable val description: String, @@ -59,6 +61,7 @@ class Database( * @param isTemporary whether the table is a temporary table. * @since 2.0.0 */ +@InterfaceStability.Stable class Table( val name: String, @Nullable val database: String, @@ -90,6 +93,7 @@ class Table( * @param isBucket whether the column is a bucket column. * @since 2.0.0 */ +@InterfaceStability.Stable class Column( val name: String, @Nullable val description: String, @@ -122,6 +126,7 @@ class Column( * @param isTemporary whether the function is a temporary function or not. * @since 2.0.0 */ +@InterfaceStability.Stable class Function( val name: String, @Nullable val database: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 51179a528c503..eea98414003ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} @@ -51,6 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @Experimental +@InterfaceStability.Evolving abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 49fdec57558e8..2e0e937e4aff7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.Column import org.apache.spark.sql.functions @@ -40,6 +40,7 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Experimental +@InterfaceStability.Evolving case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 3c1f6e897ea62..07ef60183f6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ * @since 1.4.0 */ @Experimental +@InterfaceStability.Evolving object Window { /** @@ -177,4 +178,5 @@ object Window { * 
@since 1.4.0 */ @Experimental +@InterfaceStability.Evolving class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8ebed399bf2d0..18778c8d1c294 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{catalyst, Column} +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ /** @@ -30,10 +30,11 @@ import org.apache.spark.sql.catalyst.expressions._ * @since 1.4.0 */ @Experimental +@InterfaceStability.Evolving class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - frame: catalyst.expressions.WindowFrame) { + frame: WindowFrame) { /** * Defines the partitioning columns in a [[WindowSpec]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 60d7b7d0894d0..aa71cb9e3bc85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate._ @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.aggregate._ * @since 2.0.0 */ @Experimental +@InterfaceStability.Evolving // scalastyle:off object typed { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 5417a0e481158..ef7c09c72b82d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -26,8 +26,11 @@ import org.apache.spark.sql.types._ /** * :: Experimental :: * The base class for implementing user-defined aggregate functions (UDAF). + * + * @since 1.5.0 */ @Experimental +@InterfaceStability.Evolving abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -136,8 +139,11 @@ abstract class UserDefinedAggregateFunction extends Serializable { * A [[Row]] representing a mutable aggregation buffer. * * This is not meant to be extended outside of Spark. + * + * @since 1.5.0 */ @Experimental +@InterfaceStability.Evolving abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8dd4b8f662713..dec316be7aea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.sql.types._ /** @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi +@InterfaceStability.Evolving case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** @@ -53,6 +54,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi +@InterfaceStability.Evolving abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. @@ -142,6 +144,7 @@ abstract class JdbcDialect extends Serializable { * sure to register your dialects first. */ @DeveloperApi +@InterfaceStability.Evolving object JdbcDialects { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 13c0766219a8e..e0494dfd9343b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import org.apache.spark.annotation.InterfaceStability + //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -26,6 +28,7 @@ package org.apache.spark.sql.sources * * @since 1.3.0 */ +@InterfaceStability.Stable abstract class Filter { /** * List of columns that are referenced by this filter. 
@@ -45,6 +48,7 @@ abstract class Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -56,6 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * * @since 1.5.0 */ +@InterfaceStability.Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -66,6 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -76,6 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -86,6 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -96,6 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -105,6 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class In(attribute: String, values: Array[Any]) extends Filter { override def hashCode(): Int = { var h = attribute.hashCode @@ -131,6 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -140,6 +151,7 @@ case class IsNull(attribute: String) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -149,6 +161,7 @@ case class IsNotNull(attribute: String) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -158,6 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -167,6 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ +@InterfaceStability.Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references } @@ -177,6 +192,7 @@ case class Not(child: Filter) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -187,6 +203,7 @@ case class 
StringStartsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -197,6 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ +@InterfaceStability.Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6484c782b5d15..3172d5ded9504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -37,6 +37,7 @@ import org.apache.spark.sql.types.StructType * @since 1.5.0 */ @DeveloperApi +@InterfaceStability.Evolving trait DataSourceRegister { /** @@ -68,6 +69,7 @@ trait DataSourceRegister { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -99,6 +101,7 @@ trait RelationProvider { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. @@ -114,17 +117,26 @@ trait SchemaRelationProvider { /** * ::Experimental:: * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. + * + * @since 2.0.0 */ @Experimental +@InterfaceStability.Unstable trait StreamSourceProvider { - /** Returns the name and schema of the source that can be used to continually read data. */ + /** + * Returns the name and schema of the source that can be used to continually read data. + * @since 2.0.0 + */ def sourceSchema( sqlContext: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) + /** + * @since 2.0.0 + */ def createSource( sqlContext: SQLContext, metadataPath: String, @@ -136,8 +148,11 @@ trait StreamSourceProvider { /** * ::Experimental:: * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. 
+ * + * @since 2.0.0 */ @Experimental +@InterfaceStability.Unstable trait StreamSinkProvider { def createSink( sqlContext: SQLContext, @@ -150,6 +165,7 @@ trait StreamSinkProvider { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait CreatableRelationProvider { /** * Save the DataFrame to the destination and return a relation with the given parameters based on @@ -186,6 +202,7 @@ trait CreatableRelationProvider { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -237,6 +254,7 @@ abstract class BaseRelation { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait TableScan { def buildScan(): RDD[Row] } @@ -249,6 +267,7 @@ trait TableScan { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -268,6 +287,7 @@ trait PrunedScan { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } @@ -291,6 +311,7 @@ trait PrunedFilteredScan { * @since 1.3.0 */ @DeveloperApi +@InterfaceStability.Evolving trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } @@ -306,6 +327,7 @@ trait InsertableRelation { * @since 1.3.0 */ @Experimental +@InterfaceStability.Unstable trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } From 84f149e414475c2e60863898992001c21cfc13b2 Mon Sep 17 00:00:00 2001 From: Pete Robbins Date: Thu, 13 Oct 2016 11:26:30 -0700 Subject: [PATCH 275/306] [SPARK-17827][SQL] maxColLength type should be Int for String and Binary ## What changes were proposed in this pull request? correct the expected type from Length function to be Int ## How was this patch tested? Test runs on little endian and big endian platforms Author: Pete Robbins Closes #15464 from robbinspg/SPARK-17827. --- .../spark/sql/catalyst/plans/logical/Statistics.scala | 4 ++-- .../org/apache/spark/sql/StatisticsColumnSuite.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 43455c989c0f4..f3e2147b8f974 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -98,7 +98,7 @@ case class StringColumnStat(statRow: InternalRow) { // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. val numNulls: Long = statRow.getLong(0) val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getLong(2) + val maxColLen: Long = statRow.getInt(2) val ndv: Long = statRow.getLong(3) } @@ -106,7 +106,7 @@ case class BinaryColumnStat(statRow: InternalRow) { // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. 
val numNulls: Long = statRow.getLong(0) val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getLong(2) + val maxColLen: Long = statRow.getInt(2) } case class BooleanColumnStat(statRow: InternalRow) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 0ee0547c45591..f1a201abd8da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -150,7 +150,7 @@ class StatisticsColumnSuite extends StatisticsTest { val colStat = ColumnStat(InternalRow( values.count(_.isEmpty).toLong, nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong, + nonNullValues.map(_.length).max.toInt, nonNullValues.distinct.length.toLong)) (f, colStat) } @@ -165,7 +165,7 @@ class StatisticsColumnSuite extends StatisticsTest { val colStat = ColumnStat(InternalRow( values.count(_.isEmpty).toLong, nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong)) + nonNullValues.map(_.length).max.toInt)) (f, colStat) } checkColStats(df, expectedColStatsSeq) @@ -255,10 +255,10 @@ class StatisticsColumnSuite extends StatisticsTest { doubleSeq.distinct.length.toLong)) case StringType => ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) case BinaryType => ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toLong)) + binarySeq.map(_.length).max.toInt)) case BooleanType => ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, booleanSeq.count(_.equals(false)).toLong)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 99dd080683d40..85228bb00123d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -378,7 +378,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) case StringType => ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) case BooleanType => ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, booleanSeq.count(_.equals(false)).toLong)) From 08eac356095c7faa2b19d52f2fb0cbc47eb7d1d1 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 13 Oct 2016 13:31:50 -0700 Subject: [PATCH 276/306] [SPARK-17834][SQL] Fetch the earliest offsets manually in KafkaSource instead of counting on KafkaConsumer ## What changes were proposed in this pull request? Because `KafkaConsumer.poll(0)` may update the partition offsets, this PR just calls `seekToBeginning` to manually set the earliest offsets for the KafkaSource initial offsets. ## How was this patch tested? Existing tests. Author: Shixiong Zhu Closes #15397 from zsxwing/SPARK-17834. 
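To make the change easier to follow, here is a rough standalone sketch of the idea using a bare `KafkaConsumer` (illustrative only: the bootstrap servers, group id and topic below are placeholders, and the real `KafkaSource` additionally wraps these calls in `withRetriesWithoutInterrupt`, as the diff below shows):

```
import java.util.Properties
import scala.collection.JavaConverters._
import org.apache.kafka.clients.consumer.KafkaConsumer

val props = new Properties()
props.put("bootstrap.servers", "localhost:9092")  // placeholder
props.put("group.id", "offset-probe")             // placeholder
props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer")
props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer")

val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](props)
consumer.subscribe(java.util.Collections.singletonList("topic"))  // placeholder topic

// poll(0) is used only to obtain the partition assignment; it may also move the
// positions, which is exactly why the earliest offsets are fetched explicitly below.
consumer.poll(0)
val partitions = consumer.assignment()
consumer.pause(partitions)            // avoid actually fetching records
consumer.seekToBeginning(partitions)  // seek explicitly instead of relying on auto.offset.reset
val earliestOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
```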
--- .../spark/sql/kafka010/KafkaSource.scala | 55 ++++++++++++------- .../sql/kafka010/KafkaSourceProvider.scala | 19 +++++-- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1be70db87497e..4b0bb0a0f725c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -82,6 +82,7 @@ private[kafka010] case class KafkaSource( executorKafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, + startFromEarliestOffset: Boolean, failOnDataLoss: Boolean) extends Source with Logging { @@ -109,7 +110,11 @@ private[kafka010] case class KafkaSource( private lazy val initialPartitionOffsets = { val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) metadataLog.get(0).getOrElse { - val offsets = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = false)) + val offsets = if (startFromEarliestOffset) { + KafkaSourceOffset(fetchEarliestOffsets()) + } else { + KafkaSourceOffset(fetchLatestOffsets()) + } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") offsets @@ -123,7 +128,7 @@ private[kafka010] case class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets - val offset = KafkaSourceOffset(fetchPartitionOffsets(seekToEnd = true)) + val offset = KafkaSourceOffset(fetchLatestOffsets()) logDebug(s"GetOffset: ${offset.partitionToOffsets.toSeq.map(_.toString).sorted}") Some(offset) } @@ -227,26 +232,34 @@ private[kafka010] case class KafkaSource( override def toString(): String = s"KafkaSource[$consumerStrategy]" /** - * Fetch the offset of a partition, either seek to the latest offsets or use the current offsets - * in the consumer. + * Fetch the earliest offsets of partitions. */ - private def fetchPartitionOffsets( - seekToEnd: Boolean): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { - // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) - assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + private def fetchEarliestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() consumer.pause(partitions) - logDebug(s"Partitioned assigned to consumer: $partitions") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning") - // Get the current or latest offset of each partition - if (seekToEnd) { - consumer.seekToEnd(partitions) - logDebug("Seeked to the end") - } + consumer.seekToBeginning(partitions) + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for partition : $partitionOffsets") + partitionOffsets + } + + /** + * Fetch the latest offset of partitions. + */ + private def fetchLatestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"Partitions assigned to consumer: $partitions. 
Seeking to the end.") + + consumer.seekToEnd(partitions) val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got offsets for partition : $partitionOffsets") + logDebug(s"Got latest offsets for partition : $partitionOffsets") partitionOffsets } @@ -256,22 +269,21 @@ private[kafka010] case class KafkaSource( */ private def fetchNewPartitionEarliestOffsets( newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { - // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) - assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() + consumer.pause(partitions) logDebug(s"\tPartitioned assigned to consumer: $partitions") // Get the earliest offset of each partition consumer.seekToBeginning(partitions) - val partitionToOffsets = newPartitions.filter { p => + val partitionOffsets = newPartitions.filter { p => // When deleting topics happen at the same time, some partitions may not be in `partitions`. // So we need to ignore them partitions.contains(p) }.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got offsets for new partitions: $partitionToOffsets") - partitionToOffsets + logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") + partitionOffsets } /** @@ -284,6 +296,9 @@ private[kafka010] case class KafkaSource( */ private def withRetriesWithoutInterrupt( body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + synchronized { var result: Option[Map[TopicPartition, Long]] = None var attempt = 1 diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 1b0a2fe955d03..23b1b60f3bcaa 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -77,10 +77,15 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val autoOffsetResetValue = caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY) match { - case Some(value) => value.trim() // same values as those supported by auto.offset.reset - case None => "latest" - } + val startFromEarliestOffset = + caseInsensitiveParams.get(STARTING_OFFSET_OPTION_KEY).map(_.trim.toLowerCase) match { + case Some("latest") => false + case Some("earliest") => true + case Some(pos) => + // This should not happen since we have already checked the options. + throw new IllegalStateException(s"Invalid $STARTING_OFFSET_OPTION_KEY: $pos") + case None => false + } val kafkaParamsForStrategy = ConfigUpdater("source", specifiedKafkaParams) @@ -90,8 +95,9 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider // So that consumers in Kafka source do not mess with any existing group id .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") - // So that consumers can start from earliest or latest - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetResetValue) + // Set to "latest" to avoid exceptions. 
However, KafkaSource will fetch the initial offsets + // by itself instead of counting on KafkaConsumer. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") // So that consumers in the driver does not commit offsets unnecessarily .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") @@ -147,6 +153,7 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider kafkaParamsForExecutors, parameters, metadataPath, + startFromEarliestOffset, failOnDataLoss) } From 7106866c220c73960c6fe2a70e4911516617e21f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 13 Oct 2016 13:36:26 -0700 Subject: [PATCH 277/306] [SPARK-17731][SQL][STREAMING] Metrics for structured streaming ## What changes were proposed in this pull request? Metrics are needed for monitoring structured streaming apps. Here is the design doc for implementing the necessary metrics. https://docs.google.com/document/d/1NIdcGuR1B3WIe8t7VxLrt58TJB4DtipWEbj5I_mzJys/edit?usp=sharing Specifically, this PR adds the following public APIs changes. ### New APIs - `StreamingQuery.status` returns a `StreamingQueryStatus` object (renamed from `StreamingQueryInfo`, see later) - `StreamingQueryStatus` has the following important fields - inputRate - Current rate (rows/sec) at which data is being generated by all the sources - processingRate - Current rate (rows/sec) at which the query is processing data from all the sources - ~~outputRate~~ - *Does not work with wholestage codegen* - latency - Current average latency between the data being available in source and the sink writing the corresponding output - sourceStatuses: Array[SourceStatus] - Current statuses of the sources - sinkStatus: SinkStatus - Current status of the sink - triggerStatus - Low-level detailed status of the last completed/currently active trigger - latencies - getOffset, getBatch, full trigger, wal writes - timestamps - trigger start, finish, after getOffset, after getBatch - numRows - input, output, state total/updated rows for aggregations - `SourceStatus` has the following important fields - inputRate - Current rate (rows/sec) at which data is being generated by the source - processingRate - Current rate (rows/sec) at which the query is processing data from the source - triggerStatus - Low-level detailed status of the last completed/currently active trigger - Python API for `StreamingQuery.status()` ### Breaking changes to existing APIs **Existing direct public facing APIs** - Deprecated direct public-facing APIs `StreamingQuery.sourceStatuses` and `StreamingQuery.sinkStatus` in favour of `StreamingQuery.status.sourceStatuses/sinkStatus`. - Branch 2.0 should have it deprecated, master should have it removed. **Existing advanced listener APIs** - `StreamingQueryInfo` renamed to `StreamingQueryStatus` for consistency with `SourceStatus`, `SinkStatus` - Earlier StreamingQueryInfo was used only in the advanced listener API, but now it is used in direct public-facing API (StreamingQuery.status) - Field `queryInfo` in listener events `QueryStarted`, `QueryProgress`, `QueryTerminated` changed have name `queryStatus` and return type `StreamingQueryStatus`. - Field `offsetDesc` in `SourceStatus` was Option[String], converted it to `String`. - For `SourceStatus` and `SinkStatus` made constructor private instead of private[sql] to make them more java-safe. Instead added `private[sql] object SourceStatus/SinkStatus.apply()` which are harder to accidentally use in Java. ## How was this patch tested? Old and new unit tests. 
- Rate calculation and other internal logic of StreamMetrics tested by StreamMetricsSuite. - New info in statuses returned through StreamingQueryListener is tested in StreamingQueryListenerSuite. - New and old info returned through StreamingQuery.status is tested in StreamingQuerySuite. - Source-specific tests for making sure input rows are counted are is source-specific test suites. - Additional tests to test minor additions in LocalTableScanExec, StateStore, etc. Metrics also manually tested using Ganglia sink Author: Tathagata Das Closes #15307 from tdas/SPARK-17731. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 27 ++ project/MimaExcludes.scala | 13 + python/pyspark/sql/streaming.py | 301 +++++++++++++++++ .../spark/sql/catalyst/trees/TreeNode.scala | 7 + .../sql/execution/LocalTableScanExec.scala | 5 +- .../streaming/StatefulAggregate.scala | 31 +- .../execution/streaming/StreamExecution.scala | 307 ++++++++++++++---- .../execution/streaming/StreamMetrics.scala | 242 ++++++++++++++ .../sql/execution/streaming/memory.scala | 7 + .../state/HDFSBackedStateStoreProvider.scala | 2 + .../streaming/state/StateStore.scala | 3 + .../apache/spark/sql/internal/SQLConf.scala | 8 + .../spark/sql/streaming/SinkStatus.scala | 28 +- .../spark/sql/streaming/SourceStatus.scala | 54 ++- .../spark/sql/streaming/StreamingQuery.scala | 13 +- .../sql/streaming/StreamingQueryInfo.scala | 37 --- .../streaming/StreamingQueryListener.scala | 8 +- .../sql/streaming/StreamingQueryStatus.scala | 139 ++++++++ .../execution/metric/SQLMetricsSuite.scala | 17 + .../streaming/StreamMetricsSuite.scala | 213 ++++++++++++ .../streaming/TextSocketStreamSuite.scala | 24 ++ .../streaming/state/StateStoreSuite.scala | 5 + .../sql/streaming/FileStreamSourceSuite.scala | 14 + .../spark/sql/streaming/StreamTest.scala | 72 ++++ .../streaming/StreamingAggregationSuite.scala | 54 +++ .../StreamingQueryListenerSuite.scala | 220 +++++-------- .../sql/streaming/StreamingQuerySuite.scala | 180 +++++++++- 27 files changed, 1758 insertions(+), 273 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index c640b93b0a2ee..8b5296ea135c7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -264,6 +264,33 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + 
testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total").toInt > 0) + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" private def testFromLatestOffsets(topic: String, options: (String, String)*): Unit = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ae72d37a0b61c..1349af4219c16 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -56,6 +56,19 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), + + // [SPARK-17731][SQL][Streaming] Metrics for structured streaming + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceStatus.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.SourceStatus.offsetDesc"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.status"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SinkStatus.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.queryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.queryInfo"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.queryInfo"), + // [SPARK-17338][SQL] add global temp view ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 4e438fd5bee22..ce47bd1640fb1 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -189,6 +189,304 @@ def resetTerminated(self): self._jsqm.resetTerminated() +class StreamingQueryStatus(object): + """A class used to report information about the progress of a StreamingQuery. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jsqs): + self._jsqs = jsqs + + def __str__(self): + """ + Pretty string of this query status. 
+ + >>> print(sqs) + StreamingQueryStatus: + Query name: query + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + triggerId: 5 + Source statuses [1 source]: + Source 1: MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status: MySink + Committed offsets: [#1, -] + """ + return self._jsqs.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def name(self): + """ + Name of the query. This name is unique across all active queries. + + >>> sqs.name + u'query' + """ + return self._jsqs.name() + + @property + @since(2.1) + def id(self): + """ + Id of the query. This id is unique across all queries that have been started in + the current process. + + >>> int(sqs.id) + 1 + """ + return self._jsqs.id() + + @property + @since(2.1) + def timestamp(self): + """ + Timestamp (ms) of when this query was generated. + + >>> int(sqs.timestamp) + 123 + """ + return self._jsqs.timestamp() + + @property + @since(2.1) + def inputRate(self): + """ + Current total rate (rows/sec) at which data is being generated by all the sources. + + >>> sqs.inputRate + 15.5 + """ + return self._jsqs.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from all the sources. + + >>> sqs.processingRate + 23.5 + """ + return self._jsqs.processingRate() + + @property + @since(2.1) + def latency(self): + """ + Current average latency between the data being available in source and the sink + writing the corresponding output. + + >>> sqs.latency + 345.0 + """ + if (self._jsqs.latency().nonEmpty()): + return self._jsqs.latency().get() + else: + return None + + @property + @ignore_unicode_prefix + @since(2.1) + def sourceStatuses(self): + """ + Current statuses of the sources as a list. + + >>> len(sqs.sourceStatuses) + 1 + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return [SourceStatus(ss) for ss in self._jsqs.sourceStatuses()] + + @property + @ignore_unicode_prefix + @since(2.1) + def sinkStatus(self): + """ + Current status of the sink. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return SinkStatus(self._jsqs.sinkStatus()) + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.triggerDetails + {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + u'isDataPresentInTrigger': u'true'} + """ + return self._jsqs.triggerDetails() + + +class SourceStatus(object): + """ + Status and metrics of a streaming Source. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. 
+ + >>> print(sqs.sourceStatuses[0]) + SourceStatus: MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offset if known. + + >>> sqs.sourceStatuses[0].offsetDesc + u'#0' + """ + return self._jss.offsetDesc() + + @property + @since(2.1) + def inputRate(self): + """ + Current rate (rows/sec) at which data is being generated by the source. + + >>> sqs.sourceStatuses[0].inputRate + 15.5 + """ + return self._jss.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from the source. + + >>> sqs.sourceStatuses[0].processingRate + 23.5 + """ + return self._jss.processingRate() + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.sourceStatuses[0].triggerDetails + {u'numRows.input.source': u'100', u'latency.getOffset.source': u'10', + u'latency.getBatch.source': u'20'} + """ + return self._jss.triggerDetails() + + +class SinkStatus(object): + """ + Status and metrics of a streaming Sink. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. + + >>> print(sqs.sinkStatus) + SinkStatus: MySink + Committed offsets: [#1, -] + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offsets up to which data has been written by the sink. + + >>> sqs.sinkStatus.offsetDesc + u'[#1, -]' + """ + return self._jss.offsetDesc() + + class Trigger(object): """Used to indicate how often results should be produced by a :class:`StreamingQuery`. 
@@ -753,11 +1051,14 @@ def _test(): globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) globs['df'] = \ globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') + globs['sqs'] = StreamingQueryStatus( + spark.sparkContext._jvm.org.apache.spark.sql.streaming.StreamingQueryStatus.testStatus()) (failure_count, test_count) = doctest.testmod( pyspark.sql.streaming, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['spark'].stop() + if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 83cb375525832..ea8d8fef7bdf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -164,6 +164,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { ret } + /** + * Returns a Seq containing the leaves in this tree. + */ + def collectLeaves(): Seq[BaseType] = { + this.collect { case p if p.children.isEmpty => p } + } + /** * Finds and returns the first [[TreeNode]] of the tree for which the given partial function * is defined (pre-order), and applies the partial function to it. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 6598fa381aa3d..e366b9af35c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -64,10 +64,13 @@ case class LocalTableScanExec( } override def executeCollect(): Array[InternalRow] = { + longMetric("numOutputRows").add(unsafeRows.size) unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - unsafeRows.take(limit) + val taken = unsafeRows.take(limit) + longMetric("numOutputRows").add(taken.size) + taken } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 4d0283fbef1d0..587ea7d02acab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan @@ -56,7 +57,12 @@ case class StateStoreRestoreExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -69,6 +75,7 @@ case class StateStoreRestoreExec( iter.flatMap { row => val key = getKey(row) val savedState = store.get(key) + numOutputRows += 1 
row +: savedState.toSeq } } @@ -86,7 +93,13 @@ case class StateStoreSaveExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver assert(returnAllStates.nonEmpty, "Incorrect planning in IncrementalExecution, returnAllStates have not been set") val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ @@ -111,6 +124,10 @@ case class StateStoreSaveExec( private def saveAndReturnUpdated( store: StateStore, iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + new Iterator[InternalRow] { private[this] val baseIterator = iter private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -118,6 +135,7 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { store.commit() + numTotalStateRows += store.numKeys() false } else { true @@ -128,6 +146,8 @@ case class StateStoreSaveExec( val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 row } } @@ -142,12 +162,21 @@ case class StateStoreSaveExec( store: StateStore, iter: Iterator[InternalRow]): Iterator[InternalRow] = { val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val key = getKey(row) store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 } store.commit() - store.iterator().map(_._2.asInstanceOf[InternalRow]) + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 333239f875bd3..9144736c940f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -57,6 +57,7 @@ class StreamExecution( extends StreamingQuery with Logging { import org.apache.spark.sql.streaming.StreamingQueryListener._ + import StreamMetrics._ private val 
pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay @@ -105,11 +106,22 @@ class StreamExecution( var lastExecution: QueryExecution = null @volatile - var streamDeathCause: StreamingQueryException = null + private var streamDeathCause: StreamingQueryException = null /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() + /** Metrics for this query */ + private val streamMetrics = + new StreamMetrics(uniqueSources.toSet, triggerClock, s"StructuredStreaming.$name") + + @volatile + private var currentStatus: StreamingQueryStatus = null + + /** Flag that signals whether any error with input metrics have already been logged */ + @volatile + private var metricWarningLogged: Boolean = false + /** * The thread that runs the micro-batches of this stream. Note that this thread must be * [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks in using @@ -136,16 +148,14 @@ class StreamExecution( /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE + /** Returns the current status of the query. */ + override def status: StreamingQueryStatus = currentStatus + /** Returns current status of all the sources. */ - override def sourceStatuses: Array[SourceStatus] = { - val localAvailableOffsets = availableOffsets - sources.map(s => - new SourceStatus(s.toString, localAvailableOffsets.get(s).map(_.toString))).toArray - } + override def sourceStatuses: Array[SourceStatus] = currentStatus.sourceStatuses.toArray /** Returns current status of the sink. */ - override def sinkStatus: SinkStatus = - new SinkStatus(sink.toString, committedOffsets.toCompositeOffset(sources).toString) + override def sinkStatus: SinkStatus = currentStatus.sinkStatus /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */ override def exception: Option[StreamingQueryException] = Option(streamDeathCause) @@ -176,7 +186,11 @@ class StreamExecution( // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, // so must mark this as ACTIVE first. state = ACTIVE - postEvent(new QueryStarted(this.toInfo)) // Assumption: Does not throw exception. + if (sparkSession.sessionState.conf.streamingMetricsEnabled) { + sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) + } + updateStatus() + postEvent(new QueryStarted(currentStatus)) // Assumption: Does not throw exception. 
// Unblock starting thread startLatch.countDown() @@ -185,25 +199,41 @@ class StreamExecution( SparkSession.setActiveSession(sparkSession) triggerExecutor.execute(() => { - if (isActive) { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets() - logDebug(s"Stream running from $committedOffsets to $availableOffsets") + streamMetrics.reportTriggerStarted(currentBatchId) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Finding new data from sources") + updateStatus() + val isTerminated = reportTimeTaken(TRIGGER_LATENCY) { + if (isActive) { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets() + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() + } + if (dataAvailable) { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, true) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Processing new data") + updateStatus() + runBatch() + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 + } else { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, false) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "No new data") + updateStatus() + Thread.sleep(pollingDelayMs) + } + true } else { - constructNextBatch() + false } - if (dataAvailable) { - runBatch() - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - Thread.sleep(pollingDelayMs) - } - true - } else { - false } + // Update metrics and notify others + streamMetrics.reportTriggerFinished() + updateStatus() + postEvent(new QueryProgress(currentStatus)) + isTerminated }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() @@ -221,8 +251,16 @@ class StreamExecution( } } finally { state = TERMINATED + + // Update metrics and status + streamMetrics.stop() + sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics) + updateStatus() + + // Notify others sparkSession.streams.notifyQueryTermination(StreamExecution.this) - postEvent(new QueryTerminated(this.toInfo, exception.map(_.cause).map(Utils.exceptionString))) + postEvent( + new QueryTerminated(currentStatus, exception.map(_.cause).map(Utils.exceptionString))) terminationLatch.countDown() } } @@ -248,7 +286,6 @@ class StreamExecution( committedOffsets = lastOffsets.toStreamProgress(sources) logDebug(s"Resuming with committed offsets: $committedOffsets") } - case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 @@ -278,8 +315,14 @@ class StreamExecution( val hasNewData = { awaitBatchLock.lock() try { - val newData = uniqueSources.flatMap(s => s.getOffset.map(o => s -> o)) - availableOffsets ++= newData + reportTimeTaken(GET_OFFSET_LATENCY) { + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => + reportTimeTaken(s, SOURCE_GET_OFFSET_LATENCY) { + (s, s.getOffset) + } + }.toMap + availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + } if (dataAvailable) { true @@ -292,16 +335,19 @@ class StreamExecution( } } if (hasNewData) { - assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), - s"Concurrent update to the log. 
Multiple streaming jobs detected for $currentBatchId") - logInfo(s"Committed offsets for batch $currentBatchId.") - - // Now that we have logged the new batch, no further processing will happen for - // the previous batch, and it is safe to discard the old metadata. - // Note that purge is exclusive, i.e. it purges everything before currentBatchId. - // NOTE: If StreamExecution implements pipeline parallelism (multiple batches in - // flight at the same time), this cleanup logic will need to change. - offsetLog.purge(currentBatchId) + reportTimeTaken(OFFSET_WAL_WRITE_LATENCY) { + assert( + offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + logInfo(s"Committed offsets for batch $currentBatchId.") + + // Now that we have logged the new batch, no further processing will happen for + // the previous batch, and it is safe to discard the old metadata. + // Note that purge is exclusive, i.e. it purges everything before currentBatchId. + // NOTE: If StreamExecution implements pipeline parallelism (multiple batches in + // flight at the same time), this cleanup logic will need to change. + offsetLog.purge(currentBatchId) + } } else { awaitBatchLock.lock() try { @@ -311,26 +357,30 @@ class StreamExecution( awaitBatchLock.unlock() } } + reportTimestamp(GET_OFFSET_TIMESTAMP) } /** * Processes any data available between `availableOffsets` and `committedOffsets`. */ private def runBatch(): Unit = { - val startTime = System.nanoTime() - // TODO: Move this to IncrementalExecution. // Request unprocessed data from all sources. - val newData = availableOffsets.flatMap { - case (source, available) + val newData = reportTimeTaken(GET_BATCH_LATENCY) { + availableOffsets.flatMap { + case (source, available) if committedOffsets.get(source).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(source) - val batch = source.getBatch(current, available) - logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) - case _ => None - }.toMap + val current = committedOffsets.get(source) + val batch = reportTimeTaken(source, SOURCE_GET_BATCH_LATENCY) { + source.getBatch(current, available) + } + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + } + } + reportTimestamp(GET_BATCH_TIMESTAMP) // A list of attributes that will need to be updated. var replacements = new ArrayBuffer[(Attribute, Attribute)] @@ -351,25 +401,24 @@ class StreamExecution( // Rewire the plan to use the new attributes that were returned by the source. 
val replacementMap = AttributeMap(replacements) - val newPlan = withNewSources transformAllExpressions { + val triggerLogicalPlan = withNewSources transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a) } - val optimizerStart = System.nanoTime() - lastExecution = new IncrementalExecution( - sparkSession, - newPlan, - outputMode, - checkpointFile("state"), - currentBatchId) - - lastExecution.executedPlan - val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 - logDebug(s"Optimized batch in ${optimizerTime}ms") + val executedPlan = reportTimeTaken(OPTIMIZER_LATENCY) { + lastExecution = new IncrementalExecution( + sparkSession, + triggerLogicalPlan, + outputMode, + checkpointFile("state"), + currentBatchId) + lastExecution.executedPlan // Force the lazy generation of execution plan + } val nextBatch = new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId, nextBatch) + reportNumRows(executedPlan, triggerLogicalPlan, newData) awaitBatchLock.lock() try { @@ -379,11 +428,8 @@ class StreamExecution( awaitBatchLock.unlock() } - val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") // Update committed offsets. committedOffsets ++= availableOffsets - postEvent(new QueryProgress(this.toInfo)) } private def postEvent(event: StreamingQueryListener.Event) { @@ -516,12 +562,131 @@ class StreamExecution( """.stripMargin } - private def toInfo: StreamingQueryInfo = { - new StreamingQueryInfo( - this.name, - this.id, - this.sourceStatuses, - this.sinkStatus) + /** + * Report row metrics of the executed trigger + * @param triggerExecutionPlan Execution plan of the trigger + * @param triggerLogicalPlan Logical plan of the trigger, generated from the query logical plan + * @param sourceToDF Source to DataFrame returned by the source.getBatch in this trigger + */ + private def reportNumRows( + triggerExecutionPlan: SparkPlan, + triggerLogicalPlan: LogicalPlan, + sourceToDF: Map[Source, DataFrame]): Unit = { + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. 
+ // + val logicalPlanLeafToSource = sourceToDF.flatMap { case (source, df) => + df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = triggerLogicalPlan.collectLeaves() // includes non-streaming sources + val allExecPlanLeaves = triggerExecutionPlan.collectLeaves() + val sourceToNumInputRows: Map[Source, Long] = + if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { + val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { + case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } + } + val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + source -> numRows + } + sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + } else { + if (!metricWarningLogged) { + def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( + "Could not report metrics as number leaves in trigger logical plan did not match that" + + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + metricWarningLogged = true + } + Map.empty + } + val numOutputRows = triggerExecutionPlan.metrics.get("numOutputRows").map(_.value) + val stateNodes = triggerExecutionPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + + streamMetrics.reportNumInputRows(sourceToNumInputRows) + stateNodes.zipWithIndex.foreach { case (s, i) => + streamMetrics.reportTriggerDetail( + NUM_TOTAL_STATE_ROWS(i + 1), + s.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L)) + streamMetrics.reportTriggerDetail( + NUM_UPDATED_STATE_ROWS(i + 1), + s.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)) + } + updateStatus() + } + + private def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + val timeTaken = math.max(endTime - startTime, 0) + streamMetrics.reportTriggerDetail(triggerDetailKey, timeTaken) + updateStatus() + if (triggerDetailKey == TRIGGER_LATENCY) { + logInfo(s"Completed up to $availableOffsets in $timeTaken ms") + } + result + } + + private def reportTimeTaken[T](source: Source, triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + streamMetrics.reportSourceTriggerDetail( + source, triggerDetailKey, math.max(endTime - startTime, 0)) + updateStatus() + result + } + + private def reportTimestamp(triggerDetailKey: String): Unit = { + streamMetrics.reportTriggerDetail(triggerDetailKey, triggerClock.getTimeMillis) + updateStatus() + } + + private def updateStatus(): Unit = { + val localAvailableOffsets = availableOffsets + val sourceStatuses = sources.map { s => + SourceStatus( + s.toString, + localAvailableOffsets.get(s).map(_.toString).getOrElse("-"), // TODO: use json if available + streamMetrics.currentSourceInputRate(s), + streamMetrics.currentSourceProcessingRate(s), + streamMetrics.currentSourceTriggerDetails(s)) + }.toArray + val sinkStatus = SinkStatus( + sink.toString, + committedOffsets.toCompositeOffset(sources).toString) + + currentStatus = + StreamingQueryStatus( + name = name, + id = id, + timestamp = triggerClock.getTimeMillis(), + inputRate = streamMetrics.currentInputRate(), + processingRate = 
streamMetrics.currentProcessingRate(), + latency = streamMetrics.currentLatency(), + sourceStatuses = sourceStatuses, + sinkStatus = sinkStatus, + triggerDetails = streamMetrics.currentTriggerDetails()) } trait State diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala new file mode 100644 index 0000000000000..e98d1883e4596 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.{util => ju} + +import scala.collection.mutable + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.util.Clock + +/** + * Class that manages all the metrics related to a StreamingQuery. It does the following. + * - Calculates metrics (rates, latencies, etc.) based on information reported by StreamExecution. + * - Allows the current metric values to be queried + * - Serves some of the metrics through Codahale/DropWizard metrics + * + * @param sources Unique set of sources in a query + * @param triggerClock Clock used for triggering in StreamExecution + * @param codahaleSourceName Root name for all the Codahale metrics + */ +class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceName: String) + extends CodahaleSource with Logging { + + import StreamMetrics._ + + // Trigger infos + private val triggerDetails = new mutable.HashMap[String, String] + private val sourceTriggerDetails = new mutable.HashMap[Source, mutable.HashMap[String, String]] + + // Rate estimators for sources and sinks + private val inputRates = new mutable.HashMap[Source, RateCalculator] + private val processingRates = new mutable.HashMap[Source, RateCalculator] + + // Number of input rows in the current trigger + private val numInputRows = new mutable.HashMap[Source, Long] + private var currentTriggerStartTimestamp: Long = -1 + private var previousTriggerStartTimestamp: Long = -1 + private var latency: Option[Double] = None + + override val sourceName: String = codahaleSourceName + override val metricRegistry: MetricRegistry = new MetricRegistry + + // =========== Initialization =========== + + // Metric names should not have . 
in them, so that all the metrics of a query are identified + // together in Ganglia as a single metric group + registerGauge("inputRate-total", currentInputRate) + registerGauge("processingRate-total", () => currentProcessingRate) + registerGauge("latency", () => currentLatency().getOrElse(-1.0)) + + sources.foreach { s => + inputRates.put(s, new RateCalculator) + processingRates.put(s, new RateCalculator) + sourceTriggerDetails.put(s, new mutable.HashMap[String, String]) + + registerGauge(s"inputRate-${s.toString}", () => currentSourceInputRate(s)) + registerGauge(s"processingRate-${s.toString}", () => currentSourceProcessingRate(s)) + } + + // =========== Setter methods =========== + + def reportTriggerStarted(triggerId: Long): Unit = synchronized { + numInputRows.clear() + triggerDetails.clear() + sourceTriggerDetails.values.foreach(_.clear()) + + reportTriggerDetail(TRIGGER_ID, triggerId) + sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(IS_TRIGGER_ACTIVE, true) + currentTriggerStartTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) + } + + def reportTriggerDetail[T](key: String, value: T): Unit = synchronized { + triggerDetails.put(key, value.toString) + } + + def reportSourceTriggerDetail[T](source: Source, key: String, value: T): Unit = synchronized { + sourceTriggerDetails(source).put(key, value.toString) + } + + def reportNumInputRows(inputRows: Map[Source, Long]): Unit = synchronized { + numInputRows ++= inputRows + } + + def reportTriggerFinished(): Unit = synchronized { + require(currentTriggerStartTimestamp >= 0) + val currentTriggerFinishTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(FINISH_TIMESTAMP, currentTriggerFinishTimestamp) + triggerDetails.remove(STATUS_MESSAGE) + reportTriggerDetail(IS_TRIGGER_ACTIVE, false) + + // Report number of rows + val totalNumInputRows = numInputRows.values.sum + reportTriggerDetail(NUM_INPUT_ROWS, totalNumInputRows) + numInputRows.foreach { case (s, r) => + reportSourceTriggerDetail(s, NUM_SOURCE_INPUT_ROWS, r) + } + + val currentTriggerDuration = currentTriggerFinishTimestamp - currentTriggerStartTimestamp + val previousInputIntervalOption = if (previousTriggerStartTimestamp >= 0) { + Some(currentTriggerStartTimestamp - previousTriggerStartTimestamp) + } else None + + // Update input rate = num rows received by each source during the previous trigger interval + // Interval is measures as interval between start times of previous and current trigger. + // + // TODO: Instead of trigger start, we should use time when getOffset was called on each source + // as this may be different for each source if there are many sources in the query plan + // and getOffset is called serially on them. 
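+    // Worked example (numbers are illustrative, mirroring StreamMetricsSuite): if a source
+    // reported 200 rows, the gap between the previous and current trigger starts was 2000 ms,
+    // and the current trigger took 500 ms, then input rate = 200 * 1000 / 2000 = 100 rows/sec
+    // and processing rate = 200 * 1000 / 500 = 400 rows/sec.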
+ if (previousInputIntervalOption.nonEmpty) { + sources.foreach { s => + inputRates(s).update(numInputRows.getOrElse(s, 0), previousInputIntervalOption.get) + } + } + + // Update processing rate = num rows processed for each source in current trigger duration + sources.foreach { s => + processingRates(s).update(numInputRows.getOrElse(s, 0), currentTriggerDuration) + } + + // Update latency = if data present, 0.5 * previous trigger interval + current trigger duration + if (previousInputIntervalOption.nonEmpty && totalNumInputRows > 0) { + latency = Some((previousInputIntervalOption.get.toDouble / 2) + currentTriggerDuration) + } else { + latency = None + } + + previousTriggerStartTimestamp = currentTriggerStartTimestamp + currentTriggerStartTimestamp = -1 + } + + // =========== Getter methods =========== + + def currentInputRate(): Double = synchronized { + // Since we are calculating source input rates using the same time interval for all sources + // it is fine to calculate total input rate as the sum of per source input rate. + inputRates.map(_._2.currentRate).sum + } + + def currentSourceInputRate(source: Source): Double = synchronized { + inputRates(source).currentRate + } + + def currentProcessingRate(): Double = synchronized { + // Since we are calculating source processing rates using the same time interval for all sources + // it is fine to calculate total processing rate as the sum of per source processing rate. + processingRates.map(_._2.currentRate).sum + } + + def currentSourceProcessingRate(source: Source): Double = synchronized { + processingRates(source).currentRate + } + + def currentLatency(): Option[Double] = synchronized { latency } + + def currentTriggerDetails(): Map[String, String] = synchronized { triggerDetails.toMap } + + def currentSourceTriggerDetails(source: Source): Map[String, String] = synchronized { + sourceTriggerDetails(source).toMap + } + + // =========== Other methods =========== + + private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + synchronized { + metricRegistry.register(name, new Gauge[T] { + override def getValue: T = f() + }) + } + } + + def stop(): Unit = synchronized { + triggerDetails.clear() + inputRates.valuesIterator.foreach { _.stop() } + processingRates.valuesIterator.foreach { _.stop() } + latency = None + } +} + +object StreamMetrics extends Logging { + /** Simple utility class to calculate rate while avoiding DivideByZero */ + class RateCalculator { + @volatile private var rate: Option[Double] = None + + def update(numRows: Long, timeGapMs: Long): Unit = { + if (timeGapMs > 0) { + rate = Some(numRows.toDouble * 1000 / timeGapMs) + } else { + rate = None + logDebug(s"Rate updates cannot with zero or negative time gap $timeGapMs") + } + } + + def currentRate: Double = rate.getOrElse(0.0) + + def stop(): Unit = { rate = None } + } + + + val TRIGGER_ID = "triggerId" + val IS_TRIGGER_ACTIVE = "isTriggerActive" + val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" + val STATUS_MESSAGE = "statusMessage" + + val START_TIMESTAMP = "timestamp.triggerStart" + val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" + val GET_BATCH_TIMESTAMP = "timestamp.afterGetBatch" + val FINISH_TIMESTAMP = "timestamp.triggerFinish" + + val GET_OFFSET_LATENCY = "latency.getOffset.total" + val GET_BATCH_LATENCY = "latency.getBatch.total" + val OFFSET_WAL_WRITE_LATENCY = "latency.offsetLogWrite" + val OPTIMIZER_LATENCY = "latency.optimizer" + val TRIGGER_LATENCY = "latency.fullTrigger" + val SOURCE_GET_OFFSET_LATENCY = 
"latency.getOffset.source" + val SOURCE_GET_BATCH_LATENCY = "latency.getBatch.source" + + val NUM_INPUT_ROWS = "numRows.input.total" + val NUM_SOURCE_INPUT_ROWS = "numRows.input.source" + def NUM_TOTAL_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.total" + def NUM_UPDATED_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.updated" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 5052c4d50c5ed..788fcd0361bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -112,6 +112,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } override def stop() {} + + def reset(): Unit = synchronized { + batches.clear() + currentOffset = new LongOffset(-1) + } } /** @@ -165,6 +170,8 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi logDebug(s"Skipping already committed batch: $batchId") } } + + override def toString(): String = "MemorySink" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index bec966b15ed0f..7d71f5242c27d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -197,6 +197,8 @@ private[state] class HDFSBackedStateStoreProvider( allUpdates.values().asScala.toIterator } + override def numKeys(): Long = mapToUpdate.size() + /** * Whether all updates have been committed */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a67fdceb3cee6..7132e284c28f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -77,6 +77,9 @@ trait StateStore { */ def updates(): Iterator[StoreUpdate] + /** Number of keys in the state store */ + def numKeys(): Long + /** * Whether all updates have been committed */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 192083e2ea5f5..e671604c39855 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -569,6 +569,12 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val STREAMING_METRICS_ENABLED = + SQLConfigBuilder("spark.sql.streaming.metricsEnabled") + .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") + .booleanConf + .createWithDefault(false) + val NDV_MAX_ERROR = SQLConfigBuilder("spark.sql.statistics.ndv.maxError") .internal() @@ -635,6 +641,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) + def filesMaxPartitionBytes: Long = 
getConf(FILES_MAX_PARTITION_BYTES) def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala index de1efe961f8bd..c9911665f7d72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala @@ -18,17 +18,33 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent /** * :: Experimental :: - * Status and metrics of a streaming [[Sink]]. + * Status and metrics of a streaming sink. * - * @param description Description of the source corresponding to this status - * @param offsetDesc Description of the current offset up to which data has been written by the sink + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offsets up to which data has been written + * by the sink. * @since 2.0.0 */ @Experimental -class SinkStatus private[sql]( +class SinkStatus private( val description: String, - val offsetDesc: String) + val offsetDesc: String) { + + override def toString: String = + "SinkStatus:" + indent(prettyString) + + private[sql] def prettyString: String = { + s"""$description + |Committed offsets: $offsetDesc + |""".stripMargin + } +} + +/** Companion object, primarily for creating SinkStatus instances internally */ +private[sql] object SinkStatus { + def apply(desc: String, offsetDesc: String): SinkStatus = new SinkStatus(desc, offsetDesc) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala index bd0c8485e4fdd..6ace4833be22f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala @@ -17,18 +17,60 @@ package org.apache.spark.sql.streaming +import java.{util => ju} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent /** * :: Experimental :: - * Status and metrics of a streaming [[Source]]. + * Status and metrics of a streaming Source. * - * @param description Description of the source corresponding to this status - * @param offsetDesc Description of the current [[Source]] offset if known + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offset if known. + * @param inputRate Current rate (rows/sec) at which data is being generated by the source. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * the source. + * @param triggerDetails Low-level details of the currently active trigger (e.g. number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. 
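+ * @example A rough usage sketch (assuming `query` is a handle to a started [[StreamingQuery]]):
+ * {{{
+ *   val src = query.status.sourceStatuses(0)
+ *   println(s"input: ${src.inputRate} rows/sec, processing: ${src.processingRate} rows/sec")
+ * }}}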
* @since 2.0.0 */ @Experimental -class SourceStatus private[sql] ( +class SourceStatus private( val description: String, - val offsetDesc: Option[String]) + val offsetDesc: String, + val inputRate: Double, + val processingRate: Double, + val triggerDetails: ju.Map[String, String]) { + + override def toString: String = + "SourceStatus:" + indent(prettyString) + + private[sql] def prettyString: String = { + val triggerDetailsLines = + triggerDetails.asScala.map { case (k, v) => s"$k: $v" } + s"""$description + |Available offset: $offsetDesc + |Input rate: $inputRate rows/sec + |Processing rate: $processingRate rows/sec + |Trigger details: + |""".stripMargin + indent(triggerDetailsLines) + + } +} + +/** Companion object, primarily for creating SourceStatus instances internally */ +private[sql] object SourceStatus { + def apply( + desc: String, + offsetDesc: String, + inputRate: Double, + processingRate: Double, + triggerDetails: Map[String, String]): SourceStatus = { + new SourceStatus(desc, offsetDesc, inputRate, processingRate, triggerDetails.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 91f0a1e3446a1..0a85414451981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -62,13 +62,24 @@ trait StreamingQuery { */ def exception: Option[StreamingQueryException] + /** + * Returns the current status of the query. + * @since 2.0.2 + */ + def status: StreamingQueryStatus + /** * Returns current status of all the sources. * @since 2.0.0 */ + @deprecated("use status.sourceStatuses", "2.0.2") def sourceStatuses: Array[SourceStatus] - /** Returns current status of the sink. */ + /** + * Returns current status of the sink. + * @since 2.0.0 + */ + @deprecated("use status.sinkStatus", "2.0.2") def sinkStatus: SinkStatus /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala deleted file mode 100644 index 1af2668817eae..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryInfo.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.annotation.Experimental - -/** - * :: Experimental :: - * A class used to report information about the progress of a [[StreamingQuery]]. - * - * @param name The [[StreamingQuery]] name. This name is unique across all active queries. - * @param id The [[StreamingQuery]] id. 
This id is unique across - * all queries that have been started in the current process. - * @param sourceStatuses The current statuses of the [[StreamingQuery]]'s sources. - * @param sinkStatus The current status of the [[StreamingQuery]]'s sink. - */ -@Experimental -class StreamingQueryInfo private[sql]( - val name: String, - val id: Long, - val sourceStatuses: Seq[SourceStatus], - val sinkStatus: SinkStatus) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 8a8855d85a4c7..69790e33b2168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -84,7 +84,7 @@ object StreamingQueryListener { * @since 2.0.0 */ @Experimental - class QueryStarted private[sql](val queryInfo: StreamingQueryInfo) extends Event + class QueryStarted private[sql](val queryStatus: StreamingQueryStatus) extends Event /** * :: Experimental :: @@ -92,19 +92,19 @@ object StreamingQueryListener { * @since 2.0.0 */ @Experimental - class QueryProgress private[sql](val queryInfo: StreamingQueryInfo) extends Event + class QueryProgress private[sql](val queryStatus: StreamingQueryStatus) extends Event /** * :: Experimental :: * Event representing that termination of a query * - * @param queryInfo Information about the status of the query. + * @param queryStatus Information about the status of the query. * @param exception The exception message of the [[StreamingQuery]] if the query was terminated * with an exception. Otherwise, it will be `None`. * @since 2.0.0 */ @Experimental class QueryTerminated private[sql]( - val queryInfo: StreamingQueryInfo, + val queryStatus: StreamingQueryStatus, val exception: Option[String]) extends Event } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala new file mode 100644 index 0000000000000..47689928730d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset} + +/** + * :: Experimental :: + * A class used to report information about the progress of a [[StreamingQuery]]. + * + * @param name Name of the query. This name is unique across all active queries. + * @param id Id of the query. 
This id is unique across + * all queries that have been started in the current process. + * @param timestamp Timestamp (ms) of when this query was generated. + * @param inputRate Current rate (rows/sec) at which data is being generated by all the sources. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * all the sources. + * @param latency Current average latency between the data being available in source and the sink + * writing the corresponding output. + * @param sourceStatuses Current statuses of the sources. + * @param sinkStatus Current status of the sink. + * @param triggerDetails Low-level details of the currently active trigger (e.g. number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. + * @since 2.0.0 + */ +@Experimental +class StreamingQueryStatus private( + val name: String, + val id: Long, + val timestamp: Long, + val inputRate: Double, + val processingRate: Double, + val latency: Option[Double], + val sourceStatuses: Array[SourceStatus], + val sinkStatus: SinkStatus, + val triggerDetails: ju.Map[String, String]) { + + import StreamingQueryStatus._ + + override def toString: String = { + val sourceStatusLines = sourceStatuses.zipWithIndex.map { case (s, i) => + s"Source ${i + 1}:" + indent(s.prettyString) + } + val sinkStatusLines = sinkStatus.prettyString + val triggerDetailsLines = triggerDetails.asScala.map { case (k, v) => s"$k: $v" }.toSeq.sorted + val numSources = sourceStatuses.length + val numSourcesString = s"$numSources source" + { if (numSources > 1) "s" else "" } + + val allLines = s""" + |Query name: $name + |Query id: $id + |Status timestamp: $timestamp + |Input rate: $inputRate rows/sec + |Processing rate $processingRate rows/sec + |Latency: ${latency.getOrElse("-")} ms + |Trigger details: + |${indent(triggerDetailsLines)} + |Source statuses [$numSourcesString]: + |${indent(sourceStatusLines)} + |Sink status: ${indent(sinkStatusLines)}""".stripMargin + + s"StreamingQueryStatus:${indent(allLines)}" + } +} + +/** Companion object, primarily for creating StreamingQueryInfo instances internally */ +private[sql] object StreamingQueryStatus { + def apply( + name: String, + id: Long, + timestamp: Long, + inputRate: Double, + processingRate: Double, + latency: Option[Double], + sourceStatuses: Array[SourceStatus], + sinkStatus: SinkStatus, + triggerDetails: Map[String, String]): StreamingQueryStatus = { + new StreamingQueryStatus(name, id, timestamp, inputRate, processingRate, + latency, sourceStatuses, sinkStatus, triggerDetails.asJava) + } + + def indent(strings: Iterable[String]): String = strings.map(indent).mkString("\n") + def indent(string: String): String = string.split("\n").map(" " + _).mkString("\n") + + /** Create an instance of status for python testing */ + def testStatus(): StreamingQueryStatus = { + import org.apache.spark.sql.execution.streaming.StreamMetrics._ + StreamingQueryStatus( + name = "query", + id = 1, + timestamp = 123, + inputRate = 15.5, + processingRate = 23.5, + latency = Some(345), + sourceStatuses = Array( + SourceStatus( + desc = "MySource1", + offsetDesc = LongOffset(0).toString, + inputRate = 15.5, + processingRate = 23.5, + triggerDetails = Map( + NUM_SOURCE_INPUT_ROWS -> "100", + SOURCE_GET_OFFSET_LATENCY -> "10", + SOURCE_GET_BATCH_LATENCY -> "20"))), + sinkStatus = SinkStatus( + desc = "MySink", + offsetDesc = CompositeOffset(Some(LongOffset(1)) :: None :: 
Nil).toString), + triggerDetails = Map( + TRIGGER_ID -> "5", + IS_TRIGGER_ACTIVE -> "true", + IS_DATA_PRESENT_IN_TRIGGER -> "true", + GET_OFFSET_LATENCY -> "10", + GET_BATCH_LATENCY -> "20", + NUM_INPUT_ROWS -> "100" + )) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index bba40c6510cfb..229d8814e0143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ @@ -85,6 +86,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("LocalTableScanExec computes metrics in collect and take") { + val df1 = spark.createDataset(Seq(1, 2, 3)) + val logical = df1.queryExecution.logical + require(logical.isInstanceOf[LocalRelation]) + df1.collect() + val metrics1 = df1.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics1.contains("numOutputRows")) + assert(metrics1("numOutputRows").value === 3) + + val df2 = spark.createDataset(Seq(1, 2, 3)).limit(2) + df2.collect() + val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics2.contains("numOutputRows")) + assert(metrics2("numOutputRows").value === 2) + } + test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala new file mode 100644 index 0000000000000..938423db64745 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.scalactic.TolerantNumerics + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.ManualClock + +class StreamMetricsSuite extends SparkFunSuite { + import StreamMetrics._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + + test("rates, latencies, trigger details - basic life cycle") { + val sm = newStreamMetrics(source) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + + // When trigger started, the rates should not change, but should return + // reported trigger details + sm.reportTriggerStarted(1) + sm.reportTriggerDetail("key", "value") + sm.reportSourceTriggerDetail(source, "key2", "value2") + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + START_TIMESTAMP -> "0", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", "key2" -> "value2")) + + // Finishing the trigger should calculate the rates, except input rate which needs + // to have another trigger interval + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows, 10 output rows + clock.advance(1000) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) // 100 input rows processed in 1 sec + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", + NUM_INPUT_ROWS -> "100", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + + // After another trigger starts, the rates and latencies should not change until + // new rows are reported + clock.advance(1000) + sm.reportTriggerStarted(2) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + + // Reporting new rows should update the rates and latencies + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + assert(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source) === 100.0) + assert(sm.currentSourceProcessingRate(source) === 400.0) + assert(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Rates should be set to 0 after stop + sm.stop() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) 
=== 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + } + + test("rates and latencies - after trigger with no data") { + val sm = newStreamMetrics(source) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + require(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + require(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + require(sm.currentSourceInputRate(source) === 100.0) + require(sm.currentSourceProcessingRate(source) === 400.0) + require(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Trigger 3 with data + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportTriggerFinished() + + // Rates are set to zero and latency is set to None + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + sm.stop() + } + + test("rates - after trigger with multiple sources, and one source having no info") { + val source1 = TestSource(1) + val source2 = TestSource(2) + val sm = newStreamMetrics(source1, source2) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source1 -> 100L, source2 -> 100L)) + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source1 -> 200L, source2 -> 200L)) + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + assert(sm.currentInputRate() === 200.0) // 200*2 input rows generated in 2 seconds b/w starts + assert(sm.currentProcessingRate() === 800.0) // 200*2 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source1) === 100.0) + assert(sm.currentSourceInputRate(source2) === 100.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 400.0) + + // Trigger 3 with data + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportNumInputRows(Map(source1 -> 200L)) + sm.reportTriggerFinished() + + // Rates are set to zero and latency is set to None + assert(sm.currentInputRate() === 200.0) + assert(sm.currentProcessingRate() === 400.0) + assert(sm.currentSourceInputRate(source1) === 200.0) + assert(sm.currentSourceInputRate(source2) === 0.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 0.0) + sm.stop() + } + + test("registered Codahale metrics") { + import scala.collection.JavaConverters._ + val sm = newStreamMetrics(source) + val gaugeNames = sm.metricRegistry.getGauges().keySet().asScala + + // so that all metrics are considered as a single metric group in Ganglia + assert(!gaugeNames.exists(_.contains("."))) + assert(gaugeNames === Set( + "inputRate-total", + "inputRate-source0", + "processingRate-total", + "processingRate-source0", + "latency")) + } + + private def newStreamMetrics(sources: Source*): StreamMetrics = { + new 
StreamMetrics(sources.toSet, clock, "test") + } + + private val clock = new ManualClock() + private val source = TestSource(0) + + case class TestSource(id: Int) extends Source { + override def schema: StructType = StructType(Array.empty[StructField]) + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { null } + override def stop() {} + override def toString(): String = s"source$id" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala index 6b0ba7acb4804..5174a0415304c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -156,6 +156,30 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) + source = provider.createSource(sqlContext, "", None, "", parameters) + + failAfter(streamingTimeout) { + serverThread.enqueue("hello") + while (source.getOffset.isEmpty) { + Thread.sleep(10) + } + val batch = source.getBatch(None, source.getOffset.get).as[String] + batch.collect() + val numRowsMetric = + batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + assert(numRowsMetric.nonEmpty) + assert(numRowsMetric.get.value === 1) + source.stop() + source = null + } + } + private class ServerThread extends Thread with Logging { private val serverSocket = new ServerSocket(0) private val messageQueue = new LinkedBlockingQueue[String]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 984b84fd13fbd..06f1bd6c3bcc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -74,6 +74,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify state after updating put(store, "a", 1) + assert(store.numKeys() === 1) intercept[IllegalStateException] { store.iterator() } @@ -85,7 +86,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Make updates, commit and then verify state put(store, "b", 2) put(store, "aa", 3) + assert(store.numKeys() === 3) remove(store, _.startsWith("a")) + assert(store.numKeys() === 1) assert(store.commit() === 1) assert(store.hasCommitted) @@ -107,7 +110,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.numKeys() === 1) put(reloadedStore, "c", 4) + assert(reloadedStore.numKeys() === 2) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 7f9c981a4e9c9..aabdccaaf319d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -998,6 +998,20 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } } + + test("input row metrics") { + withTempDirs { case (src, tmp) => + val input = spark.readStream.format("text").load(src.getCanonicalPath) + testStream(input)( + AddTextFileData("100", src, tmp), + CheckAnswer("100"), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total") === "1") + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index fa13d385cce75..3b9d3786349ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -28,6 +28,8 @@ import scala.util.control.NonFatal import org.scalatest.Assertions import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -38,6 +40,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -198,6 +201,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + case class AssertOnLastQueryStatus(condition: StreamingQueryStatus => Unit) + extends StreamAction + + /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. 
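A minimal sketch of how an application might enable the new metrics flag and poll the new `status` API introduced in this patch; the socket host/port, the query name, and the `println` reporting below are placeholders rather than part of this change, and any supported streaming source/sink pair would work the same way.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("streaming-status-demo")
  .getOrCreate()

// Opt in to Dropwizard/Codahale reporting of streaming metrics (off by default).
spark.conf.set("spark.sql.streaming.metricsEnabled", "true")

// Any streaming source works; the socket source is used here only as a placeholder.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .load()

val counts = lines.groupBy("value").count()

val query = counts.writeStream
  .outputMode("complete")
  .format("memory")
  .queryName("line_counts")
  .start()

// Poll the new status API while the query runs.
val status = query.status
println(s"input: ${status.inputRate} rows/sec, processing: ${status.processingRate} rows/sec, " +
  s"latency: ${status.latency.getOrElse("-")} ms")
status.sourceStatuses.foreach(println)   // per-source rates and trigger details
println(status.sinkStatus)               // committed offsets at the sink
```

With `spark.sql.streaming.metricsEnabled` set to true, the same values are also exposed through the gauges that `StreamMetrics` registers (`inputRate-total`, `processingRate-total`, `latency`, and their per-source variants).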
@@ -299,9 +306,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val testThread = Thread.currentThread() val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + val statusCollector = new QueryStatusCollector try { + spark.streams.addListener(statusCollector) startedTest.foreach { action => + logInfo(s"Processing test stream action: $action") action match { case StartStream(trigger, triggerClock) => verify(currentStream == null, "stream already running") @@ -399,6 +409,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val streamToAssert = Option(currentStream).getOrElse(lastStream) verify({ a.run(); true }, s"Assert failed: ${a.message}") + case a: AssertOnLastQueryStatus => + Eventually.eventually(timeout(streamingTimeout)) { + require(statusCollector.lastTriggerStatus.nonEmpty) + } + val status = statusCollector.lastTriggerStatus.get + verify({ a.condition(status); true }, "Assert on last query status failed") + case a: AddData => try { // Add data and get the source where it was added, and the expected offset of the @@ -473,6 +490,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() } + spark.streams.removeListener(statusCollector) } } @@ -606,4 +624,58 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } } + + + class QueryStatusCollector extends StreamingQueryListener { + // to catch errors in the async listener events + @volatile private var asyncTestWaiter = new Waiter + + @volatile var startStatus: StreamingQueryStatus = null + @volatile var terminationStatus: StreamingQueryStatus = null + @volatile var terminationException: Option[String] = null + + private val progressStatuses = new mutable.ArrayBuffer[StreamingQueryStatus] + + /** Get the info of the last trigger that processed data */ + def lastTriggerStatus: Option[StreamingQueryStatus] = synchronized { + progressStatuses.filter { i => + i.triggerDetails.get("isTriggerActive").toBoolean == false && + i.triggerDetails.get("isDataPresentInTrigger").toBoolean == true + }.lastOption + } + + def reset(): Unit = { + startStatus = null + terminationStatus = null + progressStatuses.clear() + asyncTestWaiter = new Waiter + } + + def checkAsyncErrors(): Unit = { + asyncTestWaiter.await(timeout(10 seconds)) + } + + + override def onQueryStarted(queryStarted: QueryStarted): Unit = { + asyncTestWaiter { + startStatus = queryStarted.queryStatus + } + } + + override def onQueryProgress(queryProgress: QueryProgress): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryProgress called before onQueryStarted") + synchronized { progressStatuses += queryProgress.queryStatus } + } + } + + override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryTerminated called before onQueryStarted") + terminationStatus = queryTerminated.queryStatus + terminationException = queryTerminated.exception + } + asyncTestWaiter.dismiss() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 8681199817fe6..e59b5491f90b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed @@ -129,6 +130,59 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { ) } + test("state metrics") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDS() + .flatMap(x => Seq(x, x + 1)) + .toDF("value") + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + implicit class RichStreamExecution(query: StreamExecution) { + def stateNodes: Seq[SparkPlan] = { + query.lastExecution.executedPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + } + } + + // Test with Update mode + testStream(aggregated, Update)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + + // Test with Complete mode + inputData.reset() + testStream(aggregated, Complete)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((1, 1), (2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 4 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + } + test("multiple keys") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 831543a47420a..6256385dfd0e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -17,92 +17,97 @@ package org.apache.spark.sql.streaming -import java.util.concurrent.ConcurrentLinkedQueue - +import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ -import org.scalatest.concurrent.AsyncAssertions.Waiter -import org.scalatest.concurrent.Eventually._ -import 
org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.util.JsonProtocol +import org.apache.spark.sql.functions._ +import org.apache.spark.util.{JsonProtocol, ManualClock} class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { import testImplicits._ - import StreamingQueryListener._ + import StreamingQueryListenerSuite._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) after { spark.streams.active.foreach(_.stop()) assert(spark.streams.active.isEmpty) assert(addedListeners.isEmpty) // Make sure we don't leak any events to the next test - spark.sparkContext.listenerBus.waitUntilEmpty(10000) } - test("single listener") { - val listener = new QueryStatusCollector - val input = MemoryStream[Int] - withListenerAdded(listener) { - testStream(input.toDS)( - StartStream(), - AssertOnQuery("Incorrect query status in onQueryStarted") { query => - val status = listener.startStatus - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses.size === 1) - assert(status.sourceStatuses(0).description.contains("Memory")) - - // The source and sink offsets must be None as this must be called before the - // batches have started - assert(status.sourceStatuses(0).offsetDesc === None) - assert(status.sinkStatus.offsetDesc === CompositeOffset(None :: Nil).toString) - - // No progress events or termination events - assert(listener.progressStatuses.isEmpty) - assert(listener.terminationStatus === null) - true - }, - AddDataMemory(input, Seq(1, 2, 3)), - CheckAnswer(1, 2, 3), - AssertOnQuery("Incorrect query status in onQueryProgress") { query => - eventually(Timeout(streamingTimeout)) { + test("single listener, check trigger statuses") { + import StreamingQueryListenerSuite._ + clock = new ManualClock() + + /** Custom MemoryStream that waits for manual clock to reach a time */ + val inputData = new MemoryStream[Int](0, sqlContext) { + // Wait for manual clock to be 100 first time there is data + override def getOffset: Option[Offset] = { + val offset = super.getOffset + if (offset.nonEmpty) { + clock.waitTillTime(100) + } + offset + } - // There should be only on progress event as batch has been processed - assert(listener.progressStatuses.size === 1) - val status = listener.progressStatuses.peek() - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)) - assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString) + // Wait for manual clock to be 300 first time there is data + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + clock.waitTillTime(300) + super.getBatch(start, end) + } + } - // No termination events - assert(listener.terminationStatus === null) - } - true - }, - StopStream, - AssertOnQuery("Incorrect query status in onQueryTerminated") { query => - eventually(Timeout(streamingTimeout)) { - val status = listener.terminationStatus - assert(status != null) - assert(status.name === query.name) - assert(status.id === query.id) - assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)) - assert(status.sinkStatus.offsetDesc === 
CompositeOffset.fill(LongOffset(0)).toString) - assert(listener.terminationException === None) - } - listener.checkAsyncErrors() - true - } - ) + // This is to make sure thatquery waits for manual clock to be 600 first time there is data + val mapped = inputData.toDS().agg(count("*")).as[Long].coalesce(1).map { x => + clock.waitTillTime(600) + x } + + testStream(mapped, OutputMode.Complete)( + StartStream(triggerClock = clock), + AddData(inputData, 1, 2), + AdvanceManualClock(100), // unblock getOffset, will block on getBatch + AdvanceManualClock(200), // unblock getBatch, will block on computation + AdvanceManualClock(300), // unblock computation + AssertOnQuery { _ => clock.getTimeMillis() === 600 }, + AssertOnLastQueryStatus { status: StreamingQueryStatus => + // Check the correctness of the trigger info of the last completed batch reported by + // onQueryProgress + assert(status.triggerDetails.get("triggerId") == "0") + assert(status.triggerDetails.get("isTriggerActive") === "false") + assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") + + assert(status.triggerDetails.get("timestamp.triggerStart") === "0") + assert(status.triggerDetails.get("timestamp.afterGetOffset") === "100") + assert(status.triggerDetails.get("timestamp.afterGetBatch") === "300") + assert(status.triggerDetails.get("timestamp.triggerFinish") === "600") + + assert(status.triggerDetails.get("latency.getOffset.total") === "100") + assert(status.triggerDetails.get("latency.getBatch.total") === "200") + assert(status.triggerDetails.get("latency.optimizer") === "0") + assert(status.triggerDetails.get("latency.offsetLogWrite") === "0") + assert(status.triggerDetails.get("latency.fullTrigger") === "600") + + assert(status.triggerDetails.get("numRows.input.total") === "2") + assert(status.triggerDetails.get("numRows.state.aggregation1.total") === "1") + assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") + + assert(status.sourceStatuses.length === 1) + assert(status.sourceStatuses(0).triggerDetails.get("triggerId") === "0") + assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") === "100") + assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200") + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2") + }, + CheckAnswer(2) + ) } test("adding and removing listener") { @@ -172,56 +177,37 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } test("QueryStarted serialization") { - val queryStartedInfo = new StreamingQueryInfo( - "name", - 1, - Seq(new SourceStatus("source1", None), new SourceStatus("source2", None)), - new SinkStatus("sink", CompositeOffset(None :: None :: Nil).toString)) - val queryStarted = new StreamingQueryListener.QueryStarted(queryStartedInfo) + val queryStarted = new StreamingQueryListener.QueryStarted(StreamingQueryStatus.testStatus) val json = JsonProtocol.sparkEventToJson(queryStarted) val newQueryStarted = JsonProtocol.sparkEventFromJson(json) .asInstanceOf[StreamingQueryListener.QueryStarted] - assertStreamingQueryInfoEquals(queryStarted.queryInfo, newQueryStarted.queryInfo) + assertStreamingQueryInfoEquals(queryStarted.queryStatus, newQueryStarted.queryStatus) } test("QueryProgress serialization") { - val queryProcessInfo = new StreamingQueryInfo( - "name", - 1, - Seq( - new SourceStatus("source1", Some(LongOffset(0).toString)), - new SourceStatus("source2", Some(LongOffset(1).toString))), - new SinkStatus("sink", new 
CompositeOffset(Array(None, Some(LongOffset(1)))).toString)) - val queryProcess = new StreamingQueryListener.QueryProgress(queryProcessInfo) + val queryProcess = new StreamingQueryListener.QueryProgress(StreamingQueryStatus.testStatus) val json = JsonProtocol.sparkEventToJson(queryProcess) val newQueryProcess = JsonProtocol.sparkEventFromJson(json) .asInstanceOf[StreamingQueryListener.QueryProgress] - assertStreamingQueryInfoEquals(queryProcess.queryInfo, newQueryProcess.queryInfo) + assertStreamingQueryInfoEquals(queryProcess.queryStatus, newQueryProcess.queryStatus) } test("QueryTerminated serialization") { - val queryTerminatedInfo = new StreamingQueryInfo( - "name", - 1, - Seq( - new SourceStatus("source1", Some(LongOffset(0).toString)), - new SourceStatus("source2", Some(LongOffset(1).toString))), - new SinkStatus("sink", new CompositeOffset(Array(None, Some(LongOffset(1)))).toString)) val exception = new RuntimeException("exception") val queryQueryTerminated = new StreamingQueryListener.QueryTerminated( - queryTerminatedInfo, + StreamingQueryStatus.testStatus, Some(exception.getMessage)) val json = JsonProtocol.sparkEventToJson(queryQueryTerminated) val newQueryTerminated = JsonProtocol.sparkEventFromJson(json) .asInstanceOf[StreamingQueryListener.QueryTerminated] - assertStreamingQueryInfoEquals(queryQueryTerminated.queryInfo, newQueryTerminated.queryInfo) + assertStreamingQueryInfoEquals(queryQueryTerminated.queryStatus, newQueryTerminated.queryStatus) assert(queryQueryTerminated.exception === newQueryTerminated.exception) } private def assertStreamingQueryInfoEquals( - expected: StreamingQueryInfo, - actual: StreamingQueryInfo): Unit = { + expected: StreamingQueryStatus, + actual: StreamingQueryStatus): Unit = { assert(expected.name === actual.name) assert(expected.sourceStatuses.size === actual.sourceStatuses.size) expected.sourceStatuses.zip(actual.sourceStatuses).foreach { @@ -243,7 +229,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = { try { - failAfter(1 minute) { + failAfter(streamingTimeout) { spark.streams.addListener(listener) body } @@ -258,49 +244,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val listenerBus = spark.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener]) } +} - class QueryStatusCollector extends StreamingQueryListener { - // to catch errors in the async listener events - @volatile private var asyncTestWaiter = new Waiter - - @volatile var startStatus: StreamingQueryInfo = null - @volatile var terminationStatus: StreamingQueryInfo = null - @volatile var terminationException: Option[String] = null - - val progressStatuses = new ConcurrentLinkedQueue[StreamingQueryInfo] - - def reset(): Unit = { - startStatus = null - terminationStatus = null - progressStatuses.clear() - asyncTestWaiter = new Waiter - } - - def checkAsyncErrors(): Unit = { - asyncTestWaiter.await(timeout(streamingTimeout)) - } - - - override def onQueryStarted(queryStarted: QueryStarted): Unit = { - asyncTestWaiter { - startStatus = queryStarted.queryInfo - } - } - - override def onQueryProgress(queryProgress: QueryProgress): Unit = { - asyncTestWaiter { - assert(startStatus != null, "onQueryProgress called before onQueryStarted") - progressStatuses.add(queryProgress.queryInfo) - } - } - - override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { - asyncTestWaiter 
{ - assert(startStatus != null, "onQueryTerminated called before onQueryStarted") - terminationStatus = queryTerminated.queryInfo - terminationException = queryTerminated.exception - } - asyncTestWaiter.dismiss() - } - } +object StreamingQueryListenerSuite { + // Singleton reference to clock that does not get serialized in task closures + @volatile var clock: ManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 88f1f188ab2af..9f8e2db966367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,18 +17,27 @@ package org.apache.spark.sql.streaming +import org.scalactic.TolerantNumerics +import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.Logging +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.types.StructType import org.apache.spark.SparkException -import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.util.Utils -class StreamingQuerySuite extends StreamTest with BeforeAndAfter { +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { import AwaitTerminationTester._ import testImplicits._ + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + after { sqlContext.streams.active.foreach(_.stop()) } @@ -100,31 +109,145 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter { ) } - testQuietly("source and sink statuses") { + testQuietly("query statuses") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) - testStream(mapped)( - AssertOnQuery(_.sourceStatuses.length === 1), + AssertOnQuery(q => q.status.name === q.name), + AssertOnQuery(q => q.status.id === q.id), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === CompositeOffset(None :: Nil).toString), AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === None), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), AssertOnQuery(_.sinkStatus.description.contains("Memory")), AssertOnQuery(_.sinkStatus.offsetDesc === new CompositeOffset(None :: Nil).toString), + AddData(inputData, 1, 2), CheckAnswer(6, 3), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate >= 0.0), + AssertOnQuery(_.status.processingRate 
>= 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate >= 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(0)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate >= 0.0), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString), + AddData(inputData, 1, 2), CheckAnswer(6, 3, 6, 3), - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(1).toString)), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + + StopStream, + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.status.triggerDetails.isEmpty), + + StartStream(), AddData(inputData, 0), ExpectFailure[SparkException], - AssertOnQuery(_.sourceStatuses(0).offsetDesc === Some(LongOffset(2).toString)), + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString) ) } + test("codahale metrics") { + val inputData = MemoryStream[Int] + + /** Whether metrics of a query is registered for reporting */ + def isMetricsRegistered(query: StreamingQuery): Boolean = { + val sourceName = s"StructuredStreaming.${query.name}" + val sources = spark.sparkContext.env.metricsSystem.getSourcesByName(sourceName) + require(sources.size <= 1) + sources.nonEmpty + } + // Disabled by default + assert(spark.conf.get("spark.sql.streaming.metricsEnabled").toBoolean === false) + + withSQLConf("spark.sql.streaming.metricsEnabled" -> 
"false") { + testStream(inputData.toDF)( + AssertOnQuery { q => !isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + + // Registered when enabled + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + testStream(inputData.toDF)( + AssertOnQuery { q => isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + } + + test("input row calculation with mixed batch and streaming sources") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + // Trigger input has 10 rows, static input has 2 rows, + // therefore after the first trigger, the calculated input rows should be 10 + val status = getFirstTriggerStatus(streamingInputDF.join(staticInputDF, "value")) + assert(status.triggerDetails.get("numRows.input.total") === "10") + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10") + } + + test("input row calculation with trigger DF having multiple leaves") { + val streamingTriggerDF = + spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) + require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF) + + // After the first trigger, the calculated input rows should be 10 + val status = getFirstTriggerStatus(streamingInputDF) + assert(status.triggerDetails.get("numRows.input.total") === "10") + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10") + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -149,6 +272,45 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter { ) } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ + private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { + require(!triggerDF.isStreaming) + // A streaming Source that generate only on trigger and returns the given Dataframe as batch + val source = new Source() { + override def schema: StructType = triggerDF.schema + override def getOffset: Option[Offset] = Some(LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF + override def stop(): Unit = {} + } + StreamingExecutionRelation(source) + } + + /** Returns the query status at the end of the first trigger of streaming DF */ + private def getFirstTriggerStatus(streamingDF: DataFrame): StreamingQueryStatus = { + // A StreamingQueryListener that gets the query status after the first completed trigger + val listener = new StreamingQueryListener { + @volatile var firstStatus: StreamingQueryStatus = null + override def onQueryStarted(queryStarted: QueryStarted): Unit = { } + override def onQueryProgress(queryProgress: QueryProgress): Unit = { + if (firstStatus == null) firstStatus = queryProgress.queryStatus + } + override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { } + } + + try { + spark.streams.addListener(listener) + val q = streamingDF.writeStream.format("memory").queryName("test").start() + q.processAllAvailable() + eventually(timeout(streamingTimeout)) { + assert(listener.firstStatus != null) + } + 
listener.firstStatus + } finally { + spark.streams.active.map(_.stop()) + spark.streams.removeListener(listener) + } + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * From adc112429d6fe671e6e8294824a0e41a2b1ec2e0 Mon Sep 17 00:00:00 2001 From: petermaxlee Date: Thu, 13 Oct 2016 14:16:39 -0700 Subject: [PATCH 278/306] [SPARK-17661][SQL] Consolidate various listLeafFiles implementations ## What changes were proposed in this pull request? There are 4 listLeafFiles-related functions in Spark: - ListingFileCatalog.listLeafFiles (which calls HadoopFsRelation.listLeafFilesInParallel if the number of paths passed in is greater than a threshold; if it is lower, then it has its own serial version implemented) - HadoopFsRelation.listLeafFiles (called only by HadoopFsRelation.listLeafFilesInParallel) - HadoopFsRelation.listLeafFilesInParallel (called only by ListingFileCatalog.listLeafFiles) It is actually very confusing and error prone because there are effectively two distinct implementations for the serial version of listing leaf files. As an example, SPARK-17599 updated only one of the code path and ignored the other one. This code can be improved by: - Move all file listing code into ListingFileCatalog, since it is the only class that needs this. - Keep only one function for listing files in serial. ## How was this patch tested? This change should be covered by existing unit and integration tests. I also moved a test case for HadoopFsRelation.shouldFilterOut from HadoopFsRelationSuite to ListingFileCatalogSuite. Author: petermaxlee Closes #15235 from petermaxlee/SPARK-17661. --- .../datasources/ListingFileCatalog.scala | 231 +++++++++++++----- .../datasources/fileSourceInterfaces.scala | 154 ------------ .../datasources/HadoopFsRelationSuite.scala | 11 - .../datasources/ListingFileCatalogSuite.scala | 34 +++ 4 files changed, 206 insertions(+), 224 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala index 32532084236cf..a68ae523e0faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala @@ -21,11 +21,14 @@ import java.io.FileNotFoundException import scala.collection.mutable -import org.apache.hadoop.fs.{FileStatus, LocatedFileStatus, Path} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** @@ -82,73 +85,183 @@ class ListingFileCatalog( * This is publicly visible for testing. */ def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession) - } else { - // Right now, the number of paths is less than the value of - // parallelPartitionDiscoveryThreshold. So, we will list file statues at the driver. 
- // If there is any child that has more files than the threshold, we will use parallel - // listing. - - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - - val statuses: Seq[FileStatus] = paths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - logTrace(s"Listing $path on driver") - - val childStatuses = { - try { - val stats = fs.listStatus(path) - if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats - } catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. Was it deleted very recently?") - Array.empty[FileStatus] - } - } + val files = + if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + ListingFileCatalog.listLeafFilesInParallel(paths, hadoopConf, sparkSession) + } else { + ListingFileCatalog.listLeafFilesInSerial(paths, hadoopConf) + } + + mutable.LinkedHashSet(files: _*) + } + + override def equals(other: Any): Boolean = other match { + case hdfs: ListingFileCatalog => paths.toSet == hdfs.paths.toSet + case _ => false + } + + override def hashCode(): Int = paths.toSet.hashCode() +} + + +object ListingFileCatalog extends Logging { + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * List a collection of path recursively. + */ + private def listLeafFilesInSerial( + paths: Seq[Path], + hadoopConf: Configuration): Seq[FileStatus] = { + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass) + val filter = FileInputFormat.getInputPathFilter(jobConf) + + paths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + listLeafFiles0(fs, path, filter) + } + } - childStatuses.map { - case f: LocatedFileStatus => f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - if (f.isDirectory ) { - // If f is a directory, we do not need to call getFileBlockLocations (SPARK-14959). - f - } else { - HadoopFsRelation.createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) + /** + * List a collection of path recursively in parallel (using Spark executors). + * Each task launched will use [[listLeafFilesInSerial]] to list. 
+ */ + private def listLeafFilesInParallel( + paths: Seq[Path], + hadoopConf: Configuration, + sparkSession: SparkSession): Seq[FileStatus] = { + assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, 10000) + + val statuses = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + val hadoopConf = serializableConfiguration.value + listLeafFilesInSerial(paths.map(new Path(_)).toSeq, hadoopConf).iterator + }.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) } + + case _ => + Array.empty[SerializableBlockLocation] } - }.filterNot { status => - val name = status.getPath.getName - HadoopFsRelation.shouldFilterOut(name) - } - val (dirs, files) = statuses.partition(_.isDirectory) + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + }.collect() - // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) - if (dirs.isEmpty) { - mutable.LinkedHashSet(files: _*) - } else { - mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) + // Turn SerializableFileStatus back to Status + statuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)), + blockLocations) } } - override def equals(other: Any): Boolean = other match { - case hdfs: ListingFileCatalog => paths.toSet == hdfs.paths.toSet - case _ => false + /** + * List a single path, provided as a FileStatus, in serial. + */ + private def listLeafFiles0( + fs: FileSystem, path: Path, filter: PathFilter): Seq[FileStatus] = { + logTrace(s"Listing $path") + val name = path.getName.toLowerCase + if (shouldFilterOut(name)) { + Seq.empty[FileStatus] + } else { + // [SPARK-17599] Prevent ListingFileCatalog from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. 
Was it deleted very recently?") + Array.empty[FileStatus] + } + + val allLeafStatuses = { + val (dirs, files) = statuses.partition(_.isDirectory) + val stats = files ++ dirs.flatMap(dir => listLeafFiles0(fs, dir.getPath, filter)) + if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } } - override def hashCode(): Int = paths.toSet.hashCode() + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. 
+ ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && + !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala index 5cc5f32e6e809..69dd622ce4a54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.execution.datasources -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.Experimental -import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.catalog.BucketSpec @@ -35,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration /** * ::Experimental:: @@ -352,152 +347,3 @@ trait FileCatalog { /** Refresh the file listing */ def refresh(): Unit } - - -/** - * Helper methods for gathering metadata from HDFS. - */ -object HadoopFsRelation extends Logging { - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && - !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") - } - - /** - * Create a LocatedFileStatus using FileStatus and block locations. - */ - def createLocatedFileStatus(f: FileStatus, locations: Array[BlockLocation]): LocatedFileStatus = { - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), which is - // very slow on some file system (RawLocalFileSystem, which is launch a subprocess and parse the - // stdout). - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - - // We don't filter files/directories whose name start with "_" except "_temporary" here, as - // specific data sources may take advantages over them (e.g. Parquet _metadata and - // _common_metadata files). "_temporary" directories are explicitly ignored since failed - // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name - // start with "." are also ignored. 
- def listLeafFiles(fs: FileSystem, status: FileStatus, filter: PathFilter): Array[FileStatus] = { - logTrace(s"Listing ${status.getPath}") - val name = status.getPath.getName.toLowerCase - if (shouldFilterOut(name)) { - Array.empty[FileStatus] - } else { - val statuses = { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) - val stats = files ++ dirs.flatMap(dir => listLeafFiles(fs, dir, filter)) - if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats - } - // statuses do not have any dirs. - statuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) - } - } - } - - // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play - // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. - // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from - // executor side and reconstruct it on driver side. - case class FakeBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - case class FakeFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[FakeBlockLocation]) - - def listLeafFilesInParallel( - paths: Seq[Path], - hadoopConf: Configuration, - sparkSession: SparkSession): mutable.LinkedHashSet[FileStatus] = { - assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. 
- val numParallelism = Math.min(paths.size, 10000) - - val fakeStatuses = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { paths => - // Dummy jobconf to get to the pathFilter defined in configuration - // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) - val jobConf = new JobConf(serializableConfiguration.value, this.getClass) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - paths.map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(serializableConfiguration.value) - listLeafFiles(fs, fs.getFileStatus(path), pathFilter) - } - }.map { status => - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - FakeBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[FakeBlockLocation] - } - - FakeFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - }.collect() - - val hadoopFakeStatuses = fakeStatuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)), - blockLocations) - } - mutable.LinkedHashSet(hadoopFakeStatuses: _*) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 3c68dc8bb98d8..89d57653adcbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -39,15 +39,4 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) } } - - test("file filtering") { - assert(!HadoopFsRelation.shouldFilterOut("abcd")) - assert(HadoopFsRelation.shouldFilterOut(".ab")) - assert(HadoopFsRelation.shouldFilterOut("_cd")) - - assert(!HadoopFsRelation.shouldFilterOut("_metadata")) - assert(!HadoopFsRelation.shouldFilterOut("_common_metadata")) - assert(HadoopFsRelation.shouldFilterOut("_ab_metadata")) - assert(HadoopFsRelation.shouldFilterOut("_cd_common_metadata")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala new file mode 100644 index 0000000000000..f15730aeb11f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.SparkFunSuite + +class ListingFileCatalogSuite extends SparkFunSuite { + + test("file filtering") { + assert(!ListingFileCatalog.shouldFilterOut("abcd")) + assert(ListingFileCatalog.shouldFilterOut(".ab")) + assert(ListingFileCatalog.shouldFilterOut("_cd")) + + assert(!ListingFileCatalog.shouldFilterOut("_metadata")) + assert(!ListingFileCatalog.shouldFilterOut("_common_metadata")) + assert(ListingFileCatalog.shouldFilterOut("_ab_metadata")) + assert(ListingFileCatalog.shouldFilterOut("_cd_common_metadata")) + } +} From 9dc0ca060d5925cd666b34021e62f7b38bb3aabb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Thu, 13 Oct 2016 17:48:09 -0700 Subject: [PATCH 279/306] [SPARK-17368][SQL] Add support for value class serialization and deserialization ## What changes were proposed in this pull request? Value classes were unsupported because catalyst data types were obtained through reflection on erased types, which would resolve to a value class' wrapped type and hence lead to unavailable methods during code generation. E.g. the following class ```scala case class Foo(x: Int) extends AnyVal ``` would be seen as an `int` in catalyst and will cause instance cast failures when generated java code tries to treat it as a `Foo`. This patch simply removes the erasure step when getting data types for catalyst. ## How was this patch tested? Additional tests in `ExpressionEncoderSuite`. Author: Jakob Odersky Closes #15284 from jodersky/value-classes. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../catalyst/encoders/ExpressionEncoderSuite.scala | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7923cfce82100..31c6e5def143b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -628,7 +628,7 @@ object ScalaReflection extends ScalaReflection { /* * Retrieves the runtime class corresponding to the provided type. 
*/ - def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass) case class Schema(dataType: DataType, nullable: Boolean) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 4df9062018995..4d896c2e38f10 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -66,8 +66,6 @@ case class RepeatedData( mapFieldNull: scala.collection.Map[Int, java.lang.Long], structField: PrimitiveData) -case class SpecificCollection(l: List[Int]) - /** For testing Kryo serialization based encoder. */ class KryoSerializable(val value: Int) { override def hashCode(): Int = value @@ -107,6 +105,12 @@ class UDTForCaseClass extends UserDefinedType[UDTCaseClass] { } } +case class PrimitiveValueClass(wrapped: Int) extends AnyVal +case class ReferenceValueClass(wrapped: ReferenceValueClass.Container) extends AnyVal +object ReferenceValueClass { + case class Container(data: Int) +} + class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -290,6 +294,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + encodeDecodeTest( + PrimitiveValueClass(42), "primitive value class") + + encodeDecodeTest( + ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") + productTest(("UDT", new ExamplePoint(0.1, 0.2))) test("nullable of encoder schema") { From 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 13 Oct 2016 19:44:24 -0700 Subject: [PATCH 280/306] [SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel. ## What changes were proposed in this pull request? Follow-up work of #13675, add Python API for ```RFormula forceIndexLabel```. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #15430 from yanboliang/spark-15957-python. 
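For context on the behavior this exposes, a short usage sketch of the new `forceIndexLabel` option (assuming an active `SparkSession` bound to `spark`; the data and expected labels simply mirror the unit test added below):

```python
# Usage sketch (assumes an active SparkSession bound to `spark`); data and
# expected labels mirror the unit test added in this patch.
from pyspark.ml.feature import RFormula

df = spark.createDataFrame(
    [(1.0, 1.0, "a"), (0.0, 2.0, "b"), (1.0, 0.0, "a")],
    ["y", "x", "s"])

# Default: a numeric label column is passed through without indexing.
rf = RFormula(formula="y ~ x + s")
rf.fit(df).transform(df).head().label        # 1.0

# forceIndexLabel=True string-indexes the label even though it is numeric.
rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
rf2.fit(df).transform(df).head().label       # 0.0
```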
--- python/pyspark/ml/feature.py | 31 +++++++++++++++++++++++++++---- python/pyspark/ml/tests.py | 16 ++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616ec..a33c3e79453e1 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM formula = Param(Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString) + forceIndexLabel = Param(Params._dummy(), "forceIndexLabel", + "Force to index label whether it is numeric or string", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, formula=None, featuresCol="features", labelCol="label"): + def __init__(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - __init__(self, formula=None, featuresCol="features", labelCol="label") + __init__(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self._setDefault(forceIndexLabel=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") - def setParams(self, formula=None, featuresCol="features", labelCol="label"): + def setParams(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - setParams(self, formula=None, featuresCol="features", labelCol="label") + setParams(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) Sets params for RFormula. """ kwargs = self.setParams._input_kwargs @@ -2528,6 +2537,20 @@ def getFormula(self): """ return self.getOrDefault(self.formula) + @since("2.1.0") + def setForceIndexLabel(self, value): + """ + Sets the value of :py:attr:`forceIndexLabel`. + """ + return self._set(forceIndexLabel=value) + + @since("2.1.0") + def getForceIndexLabel(self): + """ + Gets the value of :py:attr:`forceIndexLabel`. + """ + return self.getOrDefault(self.forceIndexLabel) + def _create_model(self, java_model): return RFormulaModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e233549850888..9d46cc3b4ae64 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -477,6 +477,22 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + class HasInducedError(Params): From 8543996c3f44098a521fc6b90ca0bb575f606e2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Oct 2016 12:35:59 +0800 Subject: [PATCH 281/306] [SPARK-17927][SQL] Remove dead code in WriterContainer. ## What changes were proposed in this pull request? speculationEnabled and DATASOURCE_OUTPUTPATH seem like just dead code. ## How was this patch tested? 
Tests should fail if they are not dead code. Author: Reynold Xin Closes #15477 from rxin/SPARK-17927. --- .../sql/execution/datasources/WriterContainer.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7880c7cfa16f8..253aa4405defa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -49,7 +49,6 @@ private[datasources] case class WriteRelation( object WriterContainer { val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID" - val DATASOURCE_OUTPUTPATH = "spark.sql.sources.output.path" } private[datasources] abstract class BaseWriterContainer( @@ -73,9 +72,6 @@ private[datasources] abstract class BaseWriterContainer( // This is only used on driver side. @transient private val jobContext: JobContext = job - private val speculationEnabled: Boolean = - relation.sparkSession.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) - // The following fields are initialized and used on both driver and executor side. @transient protected var outputCommitter: OutputCommitter = _ @transient private var jobId: JobID = _ @@ -247,8 +243,6 @@ private[datasources] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - val configuration = taskAttemptContext.getConfiguration - configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath) var writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) @@ -353,15 +347,10 @@ private[datasources] class DynamicPartitionWriterContainer( private def newOutputWriter( key: InternalRow, getPartitionString: UnsafeProjection): OutputWriter = { - val configuration = taskAttemptContext.getConfiguration val path = if (partitionColumns.nonEmpty) { val partitionPath = getPartitionString(key).getString(0) - configuration.set( - WriterContainer.DATASOURCE_OUTPUTPATH, - new Path(outputPath, partitionPath).toString) new Path(getWorkPath, partitionPath).toString } else { - configuration.set(WriterContainer.DATASOURCE_OUTPUTPATH, outputPath) getWorkPath } val bucketId = getBucketIdFromKey(key) From 6c29b3de763115d8676ed91f896e75c490e8c5b2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Oct 2016 14:14:52 +0800 Subject: [PATCH 282/306] [SPARK-17925][SQL] Break fileSourceInterfaces.scala into multiple pieces ## What changes were proposed in this pull request? This patch does a few changes to the file structure of data sources: - Break fileSourceInterfaces.scala into multiple pieces (HadoopFsRelation, FileFormat, OutputWriter) - Move ParquetOutputWriter into its own file I created this as a separate patch so it'd be easier to review my future PRs that focus on refactoring this internal logic. This patch only moves code around, and has no logic changes. ## How was this patch tested? N/A - should be covered by existing tests. Author: Reynold Xin Closes #15473 from rxin/SPARK-17925. 
--- ...ourceInterfaces.scala => FileFormat.scala} | 143 +------------- .../datasources/HadoopFsRelation.scala | 77 ++++++++ .../execution/datasources/OutputWriter.scala | 101 ++++++++++ .../parquet/ParquetFileFormat.scala | 144 -------------- .../parquet/ParquetOutputWriter.scala | 178 ++++++++++++++++++ 5 files changed, 359 insertions(+), 284 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{fileSourceInterfaces.scala => FileFormat.scala} (59%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala similarity index 59% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 69dd622ce4a54..bde2d2b89d56f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -20,152 +20,15 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} -import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.Job -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -/** - * ::Experimental:: - * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver - * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized - * to executor side to create actual [[OutputWriter]]s on the fly. - * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriterFactory extends Serializable { - /** - * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side - * to instantiate new [[OutputWriter]]s. - * - * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that - * this may not point to the final output file. For example, `FileOutputFormat` writes to - * temporary directories and then merge written files back to the final destination. In - * this case, `path` points to a temporary output file under the temporary directory. - * @param dataSchema Schema of the rows to be written. Partition columns are not included in the - * schema if the relation being written is partitioned. - * @param context The Hadoop MapReduce task context. 
- * @since 1.4.0 - */ - def newInstance( - path: String, - bucketId: Option[Int], // TODO: This doesn't belong here... - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter - - /** - * Returns a new instance of [[OutputWriter]] that will write data to the given path. - * This method gets called by each task on executor to write [[InternalRow]]s to - * format-specific files. Compared to the other `newInstance()`, this is a newer API that - * passes only the path that the writer must write to. The writer must write to the exact path - * and not modify it (do not add subdirectories, extensions, etc.). All other - * file-format-specific information needed to create the writer must be passed - * through the [[OutputWriterFactory]] implementation. - * @since 2.0.0 - */ - def newWriter(path: String): OutputWriter = { - throw new UnsupportedOperationException("newInstance with just path not supported") - } -} - -/** - * ::Experimental:: - * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the - * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. - * An [[OutputWriter]] instance is created and initialized when a new output file is opened on - * executor side. This instance is used to persist rows to this single output file. - * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriter { - /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned - * tables, dynamic partition columns are not included in rows to be written. - * - * @since 1.4.0 - */ - def write(row: Row): Unit - - /** - * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before - * the task output is committed. - * - * @since 1.4.0 - */ - def close(): Unit - - private var converter: InternalRow => Row = _ - - protected[sql] def initConverter(dataSchema: StructType) = { - converter = - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } - - protected[sql] def writeInternal(row: InternalRow): Unit = { - write(converter(row)) - } -} - -/** - * Acts as a container for all of the metadata required to read from a datasource. All discovery, - * resolution and merging logic for schemas and partitions has been removed. - * - * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise - * this relation. - * @param partitionSchema The schema of the columns (if any) that are used to partition the relation - * @param dataSchema The schema of any remaining columns. Note that if any partition columns are - * present in the actual data files as well, they are preserved. - * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). - * @param fileFormat A file format that can be used to read and write the data in files. - * @param options Configuration used when reading / writing data. 
- */ -case class HadoopFsRelation( - location: FileCatalog, - partitionSchema: StructType, - dataSchema: StructType, - bucketSpec: Option[BucketSpec], - fileFormat: FileFormat, - options: Map[String, String])(val sparkSession: SparkSession) - extends BaseRelation with FileRelation { - - override def sqlContext: SQLContext = sparkSession.sqlContext - - val schema: StructType = { - val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSchema.filterNot { column => - dataSchemaColumnNames.contains(column.name.toLowerCase) - }) - } - - def partitionSchemaOption: Option[StructType] = - if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec() - - def refresh(): Unit = location.refresh() - - override def toString: String = { - fileFormat match { - case source: DataSourceRegister => source.shortName() - case _ => "HadoopFiles" - } - } - - /** Returns the list of files that will be read when scanning this relation. */ - override def inputFiles: Array[String] = - location.allFiles().map(_.getPath.toUri.toString).toArray - - override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum -} - /** * Used to read and write data stored in files to/from the [[InternalRow]] format. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala new file mode 100644 index 0000000000000..c7ebe0b76a150 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister} +import org.apache.spark.sql.types.StructType + + +/** + * Acts as a container for all of the metadata required to read from a datasource. All discovery, + * resolution and merging logic for schemas and partitions has been removed. + * + * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise + * this relation. + * @param partitionSchema The schema of the columns (if any) that are used to partition the relation + * @param dataSchema The schema of any remaining columns. Note that if any partition columns are + * present in the actual data files as well, they are preserved. + * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). 
+ * @param fileFormat A file format that can be used to read and write the data in files. + * @param options Configuration used when reading / writing data. + */ +case class HadoopFsRelation( + location: FileCatalog, + partitionSchema: StructType, + dataSchema: StructType, + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String])(val sparkSession: SparkSession) + extends BaseRelation with FileRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionSchema.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } + + def partitionSchemaOption: Option[StructType] = + if (partitionSchema.isEmpty) None else Some(partitionSchema) + + def partitionSpec: PartitionSpec = location.partitionSpec() + + def refresh(): Unit = location.refresh() + + override def toString: String = { + fileFormat match { + case source: DataSourceRegister => source.shortName() + case _ => "HadoopFiles" + } + } + + /** Returns the list of files that will be read when scanning this relation. */ + override def inputFiles: Array[String] = + location.allFiles().map(_.getPath.toUri.toString).toArray + + override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala new file mode 100644 index 0000000000000..d2eec7b1413f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.types.StructType + + +/** + * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. + */ +abstract class OutputWriterFactory extends Serializable { + /** + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. + * + * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that + * this may not point to the final output file. 
For example, `FileOutputFormat` writes to + * temporary directories and then merge written files back to the final destination. In + * this case, `path` points to a temporary output file under the temporary directory. + * @param dataSchema Schema of the rows to be written. Partition columns are not included in the + * schema if the relation being written is partitioned. + * @param context The Hadoop MapReduce task context. + * @since 1.4.0 + */ + def newInstance( + path: String, + bucketId: Option[Int], // TODO: This doesn't belong here... + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter + + /** + * Returns a new instance of [[OutputWriter]] that will write data to the given path. + * This method gets called by each task on executor to write InternalRows to + * format-specific files. Compared to the other `newInstance()`, this is a newer API that + * passes only the path that the writer must write to. The writer must write to the exact path + * and not modify it (do not add subdirectories, extensions, etc.). All other + * file-format-specific information needed to create the writer must be passed + * through the [[OutputWriterFactory]] implementation. + * @since 2.0.0 + */ + def newWriter(path: String): OutputWriter = { + throw new UnsupportedOperationException("newInstance with just path not supported") + } +} + + +/** + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + */ +abstract class OutputWriter { + /** + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * tables, dynamic partition columns are not included in rows to be written. + * + * @since 1.4.0 + */ + def write(row: Row): Unit + + /** + * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before + * the task output is committed. + * + * @since 1.4.0 + */ + def close(): Unit + + private var converter: InternalRow => Row = _ + + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } + + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 4a308ff1a32f8..6faafed1e6290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -425,150 +425,6 @@ class ParquetFileFormat } } -/** - * A factory for generating OutputWriters for writing parquet files. This implemented is different - * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply - * writes the data to the path used to generate the output writer. Callers of this factory - * has to ensure which files are to be considered as committed. 
- */ -private[parquet] class ParquetOutputWriterFactory( - sqlConf: SQLConf, - dataSchema: StructType, - hadoopConf: Configuration, - options: Map[String, String]) extends OutputWriterFactory { - - private val serializableConf: SerializableConfiguration = { - val job = Job.getInstance(hadoopConf) - val conf = ContextUtil.getConfiguration(job) - val parquetOptions = new ParquetOptions(options, sqlConf) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushing down filters. - val dataSchemaToWrite = StructType.removeMetadata( - StructType.metadataKeyForOptionalField, - dataSchema).asInstanceOf[StructType] - ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) - - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). - conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlConf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp.toString) - - conf.set( - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - sqlConf.writeLegacyParquetFormat.toString) - - // Sets compression scheme - conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) - new SerializableConfiguration(conf) - } - - /** - * Returns a [[OutputWriter]] that writes data to the give path without using - * [[OutputCommitter]]. - */ - override def newWriter(path: String): OutputWriter = new OutputWriter { - - // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter - private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) - private val hadoopAttemptContext = new TaskAttemptContextImpl( - serializableConf.value, hadoopTaskAttemptId) - - // Instance of ParquetRecordWriter that does not use OutputCommitter - private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - - override def write(row: Row): Unit = { - throw new UnsupportedOperationException("call writeInternal") - } - - protected[sql] override def writeInternal(row: InternalRow): Unit = { - recordWriter.write(null, row) - } - - override def close(): Unit = recordWriter.close(hadoopAttemptContext) - } - - /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ - private def createNoCommitterRecordWriter( - path: String, - hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { - // Custom ParquetOutputFormat that disable use of committer and writes to the given path - val outputFormat = new ParquetOutputFormat[InternalRow]() { - override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } - override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } - } - outputFormat.getRecordWriter(hadoopAttemptContext) - } - - /** Disable the use of the older API. 
*/ - def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - throw new UnsupportedOperationException( - "this version of newInstance not supported for " + - "ParquetOutputWriterFactory") - } -} - - -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[parquet] class ParquetOutputWriter( - path: String, - bucketId: Option[Int], - context: TaskAttemptContext) - extends OutputWriter { - - private val recordWriter: RecordWriter[Void, InternalRow] = { - val outputFormat = { - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate unique output file names to avoid - // overwriting existing files (either exist before the write job, or are just written - // by other tasks within the same write job). - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - // It has the `.parquet` extension at the end because (de)compression tools - // such as gunzip would not be able to decompress this as the compression - // is not applied on this whole file but on each "page" in Parquet format. - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) - - override def close(): Unit = recordWriter.close(context) -} - object ParquetFileFormat extends Logging { private[parquet] def readSchema( footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala new file mode 100644 index 0000000000000..f89ce05d82d90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetRecordWriter} +import org.apache.parquet.hadoop.util.ContextUtil + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{BucketingUtils, OutputWriter, OutputWriterFactory, WriterContainer} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + + +/** + * A factory for generating OutputWriters for writing parquet files. This implemented is different + * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply + * writes the data to the path used to generate the output writer. Callers of this factory + * has to ensure which files are to be considered as committed. + */ +private[parquet] class ParquetOutputWriterFactory( + sqlConf: SQLConf, + dataSchema: StructType, + hadoopConf: Configuration, + options: Map[String, String]) + extends OutputWriterFactory { + + private val serializableConf: SerializableConfiguration = { + val job = Job.getInstance(hadoopConf) + val conf = ContextUtil.getConfiguration(job) + val parquetOptions = new ParquetOptions(options, sqlConf) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushing down filters. + val dataSchemaToWrite = StructType.removeMetadata( + StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) + + // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) + // and `CatalystWriteSupport` (writing actual rows to Parquet files). + conf.set( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlConf.isParquetBinaryAsString.toString) + + conf.set( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp.toString) + + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlConf.writeLegacyParquetFormat.toString) + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) + new SerializableConfiguration(conf) + } + + /** + * Returns a [[OutputWriter]] that writes data to the give path without using + * [[OutputCommitter]]. 
+ */ + override def newWriter(path: String): OutputWriter = new OutputWriter { + + // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter + private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl( + serializableConf.value, hadoopTaskAttemptId) + + // Instance of ParquetRecordWriter that does not use OutputCommitter + private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) + + override def write(row: Row): Unit = { + throw new UnsupportedOperationException("call writeInternal") + } + + protected[sql] override def writeInternal(row: InternalRow): Unit = { + recordWriter.write(null, row) + } + + override def close(): Unit = recordWriter.close(hadoopAttemptContext) + } + + /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ + private def createNoCommitterRecordWriter( + path: String, + hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { + // Custom ParquetOutputFormat that disable use of committer and writes to the given path + val outputFormat = new ParquetOutputFormat[InternalRow]() { + override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } + override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } + } + outputFormat.getRecordWriter(hadoopAttemptContext) + } + + /** Disable the use of the older API. */ + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + throw new UnsupportedOperationException( + "this version of newInstance not supported for " + + "ParquetOutputWriterFactory") + } +} + + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[parquet] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) + extends OutputWriter { + + private val recordWriter: RecordWriter[Void, InternalRow] = { + val outputFormat = { + new ParquetOutputFormat[InternalRow]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID) + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + // It has the `.parquet` extension at the end because (de)compression tools + // such as gunzip would not be able to decompress this as the compression + // is not applied on this whole file but on each "page" in Parquet format. 
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") + } + } + } + + outputFormat.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) +} From 2fb12b0a33deeeadfac451095f64dea6c967caac Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Oct 2016 15:53:50 +0800 Subject: [PATCH 283/306] [SPARK-17903][SQL] MetastoreRelation should talk to external catalog instead of hive client ## What changes were proposed in this pull request? `HiveExternalCatalog` should be the only interface to talk to the hive metastore. In `MetastoreRelation` we can just use `ExternalCatalog` instead of `HiveClient` to interact with hive metastore, and add missing API in `ExternalCatalog`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #15460 from cloud-fan/relation. --- .../catalyst/catalog/ExternalCatalog.scala | 13 +++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 8 ++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 8 ++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 +++---- .../spark/sql/hive/MetastoreRelation.scala | 19 ++++++++++++------- .../apache/spark/sql/hive/TableReader.scala | 3 +-- .../spark/sql/hive/client/HiveClient.scala | 15 +++------------ .../sql/hive/client/HiveClientImpl.scala | 10 ++++++---- .../sql/hive/HiveExternalCatalogSuite.scala | 9 +++++++++ .../sql/hive/MetastoreRelationSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 ++-- 11 files changed, 66 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index dd93b467eeeb2..348d3d0be2152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} +import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -196,6 +197,18 @@ abstract class ExternalCatalog { table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] + /** + * List the metadata of selected partitions according to the given partition predicates. 
+ * + * @param db database name + * @param table table name + * @param predicates partition predicated + */ + def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression]): Seq[CatalogTablePartition] + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 3e31127118b44..49280f82e20be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils /** @@ -477,6 +478,13 @@ class InMemoryCatalog( catalog(db).tables(table).partitions.values.toSeq } + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + throw new UnsupportedOperationException("listPartitionsByFilter is not implemented.") + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 237b829da882f..b5d93c3d7c804 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap @@ -646,6 +647,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.getPartitions(db, table, partialSpec) } + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + client.getPartitionsByFilter(db, table, predicates) + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8410a2e4a47ca..c44f0adda44c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -44,8 +44,6 @@ import org.apache.spark.sql.types._ */ private[hive] class 
HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - private val client = - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) @@ -104,7 +102,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - new Path(new Path(client.getDatabase(dbName).locationUri), tblName).toString + val dbLocation = sparkSession.sharedState.externalCatalog.getDatabase(dbName).locationUri + new Path(new Path(dbLocation), tblName).toString } def lookupRelation( @@ -129,7 +128,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } else { val qualifiedTable = MetastoreRelation( - qualifiedTableName.database, qualifiedTableName.name)(table, client, sparkSession) + qualifiedTableName.database, qualifiedTableName.name)(table, sparkSession) alias.map(a => SubqueryAlias(a, qualifiedTable, None)).getOrElse(qualifiedTable) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala index 33f0ecff63529..da809cf991de2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala @@ -43,7 +43,6 @@ private[hive] case class MetastoreRelation( databaseName: String, tableName: String) (val catalogTable: CatalogTable, - @transient private val client: HiveClient, @transient private val sparkSession: SparkSession) extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation { @@ -59,7 +58,7 @@ private[hive] case class MetastoreRelation( Objects.hashCode(databaseName, tableName, output) } - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: client :: sparkSession :: Nil + override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil private def toHiveColumn(c: StructField): FieldSchema = { new FieldSchema(c.name, c.dataType.catalogString, c.getComment.orNull) @@ -146,11 +145,18 @@ private[hive] case class MetastoreRelation( // When metastore partition pruning is turned off, we cache the list of all partitions to // mimic the behavior of Spark < 1.5 - private lazy val allPartitions: Seq[CatalogTablePartition] = client.getPartitions(catalogTable) + private lazy val allPartitions: Seq[CatalogTablePartition] = { + sparkSession.sharedState.externalCatalog.listPartitions( + catalogTable.database, + catalogTable.identifier.table) + } def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - client.getPartitionsByFilter(catalogTable, predicates) + sparkSession.sharedState.externalCatalog.listPartitionsByFilter( + catalogTable.database, + catalogTable.identifier.table, + predicates) } else { allPartitions } @@ -234,8 +240,7 @@ private[hive] case class MetastoreRelation( val columnOrdinals = AttributeMap(attributes.zipWithIndex) override def inputFiles: Array[String] = { - val partLocations = client - 
.getPartitionsByFilter(catalogTable, Nil) + val partLocations = allPartitions .flatMap(_.storage.locationUri) .toArray if (partLocations.nonEmpty) { @@ -248,6 +253,6 @@ private[hive] case class MetastoreRelation( } override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName)(catalogTable, client, sparkSession) + MetastoreRelation(databaseName, tableName)(catalogTable, sparkSession) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 2a54163a04e9b..aaf30f41f29c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -149,8 +149,7 @@ class HadoopTableReader( * subdirectory of each partition being read. If None, then all files are accepted. */ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, - Class[_ <: Deserializer]], + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[InternalRow] = { // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 984d23bb09dbd..9ee3d629c9977 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -172,24 +172,15 @@ private[hive] trait HiveClient { * Returns the partitions for the given table that match the supplied partition spec. * If no partition spec is specified, all partitions are returned. */ - final def getPartitions( + def getPartitions( db: String, table: String, - partialSpec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = { - getPartitions(getTable(db, table), partialSpec) - } - - /** - * Returns the partitions for the given table that match the supplied partition spec. - * If no partition spec is specified, all partitions are returned. - */ - def getPartitions( - table: CatalogTable, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] /** Returns partitions filtered by predicates for the given table. */ def getPartitionsByFilter( - table: CatalogTable, + db: String, + table: String, predicates: Seq[Expression]): Seq[CatalogTablePartition] /** Loads a static partition into an existing table. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index dd33d750a4d45..5c8f7ff1af9fa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -525,9 +525,10 @@ private[hive] class HiveClientImpl( * If no partition spec is specified, all partitions are returned. 
*/ override def getPartitions( - table: CatalogTable, + db: String, + table: String, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table) + val hiveTable = toHiveTable(getTable(db, table)) spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) @@ -535,9 +536,10 @@ private[hive] class HiveClientImpl( } override def getPartitionsByFilter( - table: CatalogTable, + db: String, + table: String, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table) + val hiveTable = toHiveTable(getTable(db, table)) shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 26c2549820de6..efa0beb85030b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.dsl.expressions._ /** * Test suite for the [[HiveExternalCatalog]]. @@ -43,4 +44,12 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { externalCatalog.client.reset() } + import utils._ + + test("list partitions by filter") { + val catalog = newBasicCatalog() + val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1)) + assert(selectedPartitions.length == 1) + assert(selectedPartitions.head.spec == part1.spec) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala index 2f3055dcac4c5..c28e41a85c39d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -29,7 +29,7 @@ class MetastoreRelationSuite extends SparkFunSuite { tableType = CatalogTableType.VIEW, storage = CatalogStorageFormat.empty, schema = StructType(StructField("a", IntegerType, true) :: Nil)) - val relation = MetastoreRelation("db", "test")(table, null, null) + val relation = MetastoreRelation("db", "test")(table, null) // No exception should be thrown relation.makeCopy(Array("db", "test")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a10957c8efa5..c158bf1ab09cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -295,12 +295,12 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: getPartitions(catalogTable)") { - assert(2 == client.getPartitions(client.getTable("default", "src_part")).size) + assert(2 == client.getPartitions("default", "src_part").size) } test(s"$version: getPartitionsByFilter") { // Only one partition [1, 1] for key2 == 1 - val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), + val result = client.getPartitionsByFilter("default", "src_part", 
Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. From 1db8feab8c564053c05e8bdc1a7f5026fd637d4f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 14 Oct 2016 04:17:03 -0700 Subject: [PATCH 284/306] [SPARK-15402][ML][PYSPARK] PySpark ml.evaluation should support save/load ## What changes were proposed in this pull request? Since ```ml.evaluation``` has supported save/load at Scala side, supporting it at Python side is very straightforward and easy. ## How was this patch tested? Add python doctest. Author: Yanbo Liang Closes #13194 from yanboliang/spark-15402. --- python/pyspark/ml/evaluation.py | 45 ++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 1fe8772da772a..7aa16fa5b90f2 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -22,6 +22,7 @@ from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.common import inherit_doc +from pyspark.ml.util import JavaMLReadable, JavaMLWritable __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', 'MulticlassClassificationEvaluator'] @@ -103,7 +104,8 @@ def isLargerBetter(self): @inherit_doc -class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): +class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) 0.83... + >>> bce_path = temp_path + "/bce" + >>> evaluator.save(bce_path) + >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path) + >>> str(evaluator2.getRawPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -172,7 +179,8 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", @inherit_doc -class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) 2.649... + >>> re_path = temp_path + "/re" + >>> evaluator.save(re_path) + >>> evaluator2 = RegressionEvaluator.load(re_path) + >>> str(evaluator2.getPredictionCol()) + 'raw' .. versionadded:: 1.4.0 """ @@ -244,7 +257,8 @@ def setParams(self, predictionCol="prediction", labelCol="label", @inherit_doc -class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) 0.66... + >>> mce_path = temp_path + "/mce" + >>> evaluator.save(mce_path) + >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path) + >>> str(evaluator2.getPredictionCol()) + 'prediction' .. 
versionadded:: 1.5.0 """ @@ -311,19 +330,27 @@ def setParams(self, predictionCol="prediction", labelCol="label", if __name__ == "__main__": import doctest + import tempfile + import pyspark.ml.evaluation from pyspark.sql import SparkSession - globs = globals().copy() + globs = pyspark.ml.evaluation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: spark = SparkSession.builder\ .master("local[2]")\ .appName("ml.evaluation tests")\ .getOrCreate() - sc = spark.sparkContext - globs['sc'] = sc globs['spark'] = spark - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) - spark.stop() + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) From a1b136d05c6c458ae8211b0844bfc98d7693fa42 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 14 Oct 2016 04:25:14 -0700 Subject: [PATCH 285/306] [SPARK-14634][ML] Add BisectingKMeansSummary ## What changes were proposed in this pull request? Add BisectingKMeansSummary ## How was this patch tested? unit test Author: Zheng RuiFeng Closes #12394 from zhengruifeng/biKMSummary. --- .../spark/ml/clustering/BisectingKMeans.scala | 74 ++++++++++++++++++- .../ml/clustering/BisectingKMeansSuite.scala | 18 ++++- .../ml/clustering/GaussianMixtureSuite.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 2 +- 4 files changed, 91 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index a97bd0fb16fd7..add8ee2a4ff8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) + + private var trainingSummary: Option[BisectingKMeansSummary] = None + + private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.1.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. 
+ */ + @Since("2.1.0") + def summary: BisectingKMeansSummary = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { @@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") ( case Row(point: Vector) => OldVectors.fromML(point) } + val instr = Instrumentation.create(this, rdd) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val bkm = new MLlibBisectingKMeans() .setK($(k)) .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) val parentModel = bkm.run(rdd) - val model = new BisectingKMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) + val summary = new BisectingKMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(summary) + val m = model.setSummary(summary) + instr.logSuccess(m) + m } @Since("2.0.0") @@ -251,3 +283,41 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { @Since("2.0.0") override def load(path: String): BisectingKMeans = super.load(path) } + + +/** + * :: Experimental :: + * Summary of BisectingKMeans. + * + * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.1.0") +@Experimental +class BisectingKMeansSummary private[clustering] ( + @Since("2.1.0") @transient val predictions: DataFrame, + @Since("2.1.0") val predictionCol: String, + @Since("2.1.0") val featuresCol: String, + @Since("2.1.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.1.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of (number of data points in) each cluster. 
+ */ + @Since("2.1.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 4f7d4418a8d09..f2368a9f8dad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -68,7 +68,7 @@ class BisectingKMeansSuite } } - test("fit & transform") { + test("fit, transform and summary") { val predictionColName = "bisecting_kmeans_prediction" val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) @@ -85,6 +85,22 @@ class BisectingKMeansSuite assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: BisectingKMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 04366f5250287..003fa6abf6597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "gm_prediction" val probabilityColName = "gm_probability" val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c9ba5a288aadf..ca392653557c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) From c8b612decba28e51789891f7881b6d4ebc50e2bb Mon Sep 17 00:00:00 2001 From: Peng Date: Fri, 14 Oct 2016 12:48:57 +0100 Subject: [PATCH 286/306] [SPARK-17870][MLLIB][ML] Change statistic to pValue for SelectKBest and SelectPercentile because of DoF difference ## What changes were proposed in this pull request? 
For feature selection method ChiSquareSelector, it is based on the ChiSquareTestResult.statistic (ChiSqure value) to select the features. It select the features with the largest ChiSqure value. But the Degree of Freedom (df) of ChiSqure value is different in Statistics.chiSqTest(RDD), and for different df, you cannot base on ChiSqure value to select features. So we change statistic to pValue for SelectKBest and SelectPercentile ## How was this patch tested? change existing test Author: Peng Closes #15444 from mpjlu/chisqure-bug. --- .../org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 6 +++--- .../apache/spark/mllib/feature/ChiSqSelectorSuite.scala | 8 ++++---- python/pyspark/ml/feature.py | 4 ++-- python/pyspark/mllib/feature.py | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c305b36278e87..f8276de4f23d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -234,11 +234,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val features = selectorType match { case ChiSqSelector.KBest => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) case ChiSqSelector.Percentile => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index dfebfc87ea1d3..6af06d82d671a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -38,10 +38,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext ) val preFilteredData = Seq( - Vectors.dense(0.0), - Vectors.dense(6.0), Vectors.dense(8.0), - Vectors.dense(5.0) + Vectors.dense(0.0), + Vectors.dense(0.0), + Vectors.dense(8.0) ) val df = sc.parallelize(data.zip(preFilteredData)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ec23a4aa7364d..ac702b4b7c69e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,10 +54,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), - LabeledPoint(1.0, Vectors.dense(Array(6.0))), - LabeledPoint(1.0, Vectors.dense(Array(8.0))), - LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => 
LabeledPoint(lp.label, model.transform(lp.features)) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a33c3e79453e1..7683360664ebd 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2592,9 +2592,9 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") >>> model = selector.fit(df) >>> model.transform(df).head().selectedFeatures - DenseVector([1.0]) + DenseVector([18.0]) >>> model.selectedFeatures - [3] + [2] >>> chiSqSelectorPath = temp_path + "/chi-sq-selector" >>> selector.save(chiSqSelectorPath) >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4aea81840a162..50ef7c7901c2c 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -288,15 +288,15 @@ class ChiSqSelector(object): ... ] >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> data = [ ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})), From 28b645b1e643ae0f6c56cbe5a92356623306717f Mon Sep 17 00:00:00 2001 From: invkrh Date: Fri, 14 Oct 2016 12:52:08 +0100 Subject: [PATCH 287/306] [SPARK-17855][CORE] Remove query string from jar url ## What changes were proposed in this pull request? Spark-submit support jar url with http protocol. However, if the url contains any query strings, `worker.DriverRunner.downloadUserJar()` method will throw "Did not see expected jar" exception. This is because this method checks the existance of a downloaded jar whose name contains query strings. This is a problem when your jar is located on some web service which requires some additional information to retrieve the file. This pr just removes query strings before checking jar existance on worker. ## How was this patch tested? For now, you can only test this patch by manual test. * Deploy a spark cluster locally * Make sure apache httpd service is on * Save an uber jar, e.g spark-job.jar under `/var/www/html/` * Use http://localhost/spark-job.jar?param=1 as jar url when running `spark-submit` * Job should be launched Author: invkrh Closes #15420 from invkrh/spark-17855. 
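The change, in short, derives the local jar file name from the URI path rather than from the raw URL string, so a trailing query string never becomes part of the expected file name. A minimal, self-contained sketch of that extraction (the object and method names below are illustrative, not part of the patch):

```
import java.net.URI

object JarNameFromUrl {
  // Keep only the path component of the URL, then take its last segment:
  // "http://localhost/spark-job.jar?param=1" -> "spark-job.jar"
  def jarFileName(jarUrl: String): String =
    new URI(jarUrl).getPath.split("/").last

  def main(args: Array[String]): Unit = {
    println(jarFileName("http://localhost/spark-job.jar?param=1"))  // spark-job.jar
    println(jarFileName("hdfs://nn:8020/user/jobs/spark-job.jar"))  // spark-job.jar
  }
}
```
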
--- .../spark/deploy/worker/DriverRunner.scala | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 289b0b93b0e84..e878c10183f61 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -18,12 +18,12 @@ package org.apache.spark.deploy.worker import java.io._ +import java.net.URI import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import com.google.common.io.Files -import org.apache.hadoop.fs.Path import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} @@ -147,30 +147,24 @@ private[deploy] class DriverRunner( * Will throw an exception if there are errors downloading the jar. */ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) - val jarFileName = jarPath.getName + val jarFileName = new URI(driverDesc.jarUrl).getPath.split("/").last val localJarFile = new File(driverDir, jarFileName) - val localJarFilename = localJarFile.getAbsolutePath - if (!localJarFile.exists()) { // May already exist if running multiple workers on one node - logInfo(s"Copying user jar $jarPath to $destPath") + logInfo(s"Copying user jar ${driverDesc.jarUrl} to $localJarFile") Utils.fetchFile( driverDesc.jarUrl, driverDir, conf, securityManager, - hadoopConf, + SparkHadoopUtil.get.newConfiguration(conf), System.currentTimeMillis(), useCache = false) + if (!localJarFile.exists()) { // Verify copy succeeded + throw new IOException( + s"Can not find expected jar $jarFileName which should have been loaded in $driverDir") + } } - - if (!localJarFile.exists()) { // Verify copy succeeded - throw new Exception(s"Did not see expected jar $jarFileName in $driverDir") - } - - localJarFilename + localJarFile.getAbsolutePath } private[worker] def prepareAndRunDriver(): Int = { From 7486442fe0b70f2aea21d569604e71d7ddf19a77 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 14 Oct 2016 21:18:49 +0800 Subject: [PATCH 288/306] [SPARK-17073][SQL][FOLLOWUP] generate column-level statistics ## What changes were proposed in this pull request? This pr adds some test cases for statistics: case sensitive column names, non ascii column names, refresh table, and also improves some documentation. ## How was this patch tested? add test cases Author: wangzhenhua Closes #15360 from wzhfy/colStats2. 
--- .../command/AnalyzeColumnCommand.scala | 53 ++--- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 198 +++++++++++++++--- 3 files changed, 197 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7066378279971..488138709a12b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -59,10 +59,12 @@ case class AnalyzeColumnCommand( def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { val (rowCount, columnStats) = computeColStats(sparkSession, relation) + // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( sizeInBytes = newTotalSize, rowCount = Some(rowCount), - colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) @@ -90,8 +92,9 @@ case class AnalyzeColumnCommand( } } if (duplicatedColumns.nonEmpty) { - logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + - s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + + s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + + s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") } // Collect statistics per column. 
@@ -116,22 +119,24 @@ case class AnalyzeColumnCommand( } object ColumnStatStruct { - val zero = Literal(0, LongType) - val one = Literal(1, LongType) + private val zero = Literal(0, LongType) + private val one = Literal(1, LongType) - def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - def max(e: Expression): Expression = Max(e) - def min(e: Expression): Expression = Min(e) - def ndv(e: Expression, relativeSD: Double): Expression = { + private def numNulls(e: Expression): Expression = { + if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + } + private def max(e: Expression): Expression = Max(e) + private def min(e: Expression): Expression = Min(e) + private def ndv(e: Expression, relativeSD: Double): Expression = { // the approximate ndv should never be larger than the number of rows Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) } - def avgLength(e: Expression): Expression = Average(Length(e)) - def maxLength(e: Expression): Expression = Max(Length(e)) - def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + private def avgLength(e: Expression): Expression = Average(Length(e)) + private def maxLength(e: Expression): Expression = Max(Length(e)) + private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -139,19 +144,19 @@ object ColumnStatStruct { }) } - def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) } - def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) } - def binaryColumnStat(e: Expression): Seq[Expression] = { + private def binaryColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), avgLength(e), maxLength(e)) } - def booleanColumnStat(e: Expression): Seq[Expression] = { + private def booleanColumnStat(e: Expression): Seq[Expression] = { Seq(numNulls(e), numTrues(e), numFalses(e)) } @@ -162,14 +167,14 @@ object ColumnStatStruct { } } - def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. 
- case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) - case StringType => getStruct(stringColumnStat(e, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(e)) - case BooleanType => getStruct(booleanColumnStat(e)) + case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) + case StringType => getStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(attr)) + case BooleanType => getStruct(booleanColumnStat(attr)) case otherType => throw new AnalysisException("Analyzing columns is not supported for column " + - s"${e.name} of data type: ${e.dataType}.") + s"${attr.name} of data type: ${attr.dataType}.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e671604c39855..c8447651dd672 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -578,7 +578,8 @@ object SQLConf { val NDV_MAX_ERROR = SQLConfigBuilder("spark.sql.statistics.ndv.maxError") .internal() - .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " + + "column level statistics.") .doubleConf .createWithDefault(0.05) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 85228bb00123d..c351063a63ff8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} @@ -358,53 +358,187 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - test("generate column-level statistics and load them from hive metastore") { + private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = { + val tableName = "tbl" + var statsBeforeUpdate: Statistics = null + var statsAfterUpdate: Statistics = null + withTable(tableName) { + val tableIndent = TableIdentifier(tableName, Some("default")) + val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + sql(s"CREATE TABLE $tableName (key int) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + // Table lookup will make the table cached. 
+ catalog.lookupRelation(tableIndent) + statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + + sql(s"INSERT INTO $tableName SELECT 2") + if (isAnalyzeColumns) { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key") + } else { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + } + catalog.lookupRelation(tableIndent) + statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + .asInstanceOf[LogicalRelation].catalogTable.get.stats.get + } + (statsBeforeUpdate, statsAfterUpdate) + } + + test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + } + + test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") { + val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true) + + assert(statsBeforeUpdate.sizeInBytes > 0) + assert(statsBeforeUpdate.rowCount == Some(1)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsBeforeUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) + + assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) + assert(statsAfterUpdate.rowCount == Some(2)) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = statsAfterUpdate.colStats("key"), + expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)), + rsd = spark.sessionState.conf.ndvMaxError) + } + + private lazy val (testDataFrame, expectedColStatsSeq) = { import testImplicits._ val intSeq = Seq(1, 2) val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) val booleanSeq = Seq(true, false) - val data = intSeq.indices.map { i => - (intSeq(i), stringSeq(i), booleanSeq(i)) + (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i)) } - val tableName = "table" - withTable(tableName) { - val df = data.toDF("c1", "c2", "c3") - df.write.format("parquet").saveAsTable(tableName) - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - } - (f, colStat) + val df: DataFrame = data.toDF("c1", "c2", "c3", "c4") + val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + binarySeq.map(_.length).max.toInt)) + case BooleanType => + ColumnStat(InternalRow(0L, 
booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) } + (f, colStat) + } + (df, expectedColStatsSeq) + } + + private def checkColStats( + tableName: String, + isDataSourceTable: Boolean, + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { + val readback = spark.table(tableName) + val stats = readback.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") + rel.catalogTable.stats.get + case rel: LogicalRelation => + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") + rel.catalogTable.get.stats.get + } + assert(stats.length == 1) + val columnStats = stats.head.colStats + assert(columnStats.size == expectedColStatsSeq.length) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + StatisticsTest.checkColStat( + dataType = field.dataType, + colStat = columnStats(field.name), + expectedColStat = expectedColStat, + rsd = spark.sessionState.conf.ndvMaxError) + } + } - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3") - val readback = spark.table(tableName) - val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => - val columnStats = rel.catalogTable.get.stats.get.colStats - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) + test("generate and load column-level stats for data source table") { + val dsTable = "dsTable" + withTable(dsTable) { + testDataFrame.write.format("parquet").saveAsTable(dsTable) + sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq) + } + } + + test("generate and load column-level stats for hive serde table") { + val hTable = "hTable" + val tmp = "tmp" + withTable(hTable, tmp) { + testDataFrame.write.format("parquet").saveAsTable(tmp) + sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE") + sql(s"INSERT INTO $hTable SELECT * FROM $tmp") + sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") + checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq) + } + } + + // When caseSensitive is on, for columns with only case difference, they are different columns + // and we should generate column stats for all of them. 
+ private def checkCaseSensitiveColStats(columnName: String): Unit = { + val tableName = "tbl" + withTable(tableName) { + val column1 = columnName.toLowerCase + val column2 = columnName.toUpperCase + withSQLConf("spark.sql.caseSensitive" -> "true") { + sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET") + sql(s"INSERT INTO $tableName SELECT 1, 3.0") + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + assert(columnStats.size == 2) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = columnStats(column1), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), + rsd = spark.sessionState.conf.ndvMaxError) StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, + dataType = DoubleType, + colStat = columnStats(column2), + expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)), rsd = spark.sessionState.conf.ndvMaxError) + rel } - rel + assert(relations.size == 1) } - assert(relations.size == 1) } } + test("check column statistics for case sensitive column names") { + checkCaseSensitiveColStats(columnName = "c1") + } + + test("check column statistics for case sensitive non-ascii column names") { + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkCaseSensitiveColStats(columnName = "列c") + // scalastyle:on + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => From a0ebcb3a30ec64e01608ed6fa7b7ffb7acbd3af2 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Fri, 14 Oct 2016 17:45:27 +0100 Subject: [PATCH 289/306] [DOC] Fix typo in sql hive doc Change is too trivial to file a JIRA. Author: Dhruve Ashar Closes #15485 from dhruve/master. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d0f43ab0a9cc9..dcc828cc69fed 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -998,7 +998,7 @@ The following options can be used to configure the version of Hive that is used
      A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure
-     they are packaged with you application.
+     they are packaged with your application.

From fa37877af02a956203e8a00811b20f34af0278f7 Mon Sep 17 00:00:00 2001
From: Andrew Ash
Date: Fri, 14 Oct 2016 18:13:19 +0100
Subject: [PATCH 290/306] Typo: form -> from

## What changes were proposed in this pull request?

Minor typo fix

## How was this patch tested?

Existing unit tests on Jenkins

Author: Andrew Ash

Closes #15486 from ash211/patch-8.
---
 .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a716a916b7f7f..ac3358592202f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -363,7 +363,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * type.
    *
   * `quote` (default `"`): sets the single character used for escaping quoted values where
   * the separator can be part of the value. If you would like to turn off quotations, you need to
-   * set not `null` but an empty string. This behaviour is different form
+   * set not `null` but an empty string. This behaviour is different from
   * `com.databricks.spark.csv`.
   *
   * `escape` (default `\`): sets the single character used for escaping quotes inside
   * an already quoted value.
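Editor's note: the hunk above documents the CSV reader's `quote` and `escape` options; a minimal sketch of those options at a call site may help. This is not part of any patch; the input path and header setting are invented, and only the standard `option`/`csv` reader API is used:

```scala
import org.apache.spark.sql.SparkSession

object CsvQuoteEscapeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("csv-quote-escape-sketch")
      .master("local[*]")
      .getOrCreate()

    // Values containing the separator must be wrapped in the `quote` character;
    // quotes inside an already quoted value are prefixed with the `escape` character.
    val df = spark.read
      .option("header", "true")
      .option("quote", "\"")    // the documented default
      .option("escape", "\\")   // the documented default
      // Per the scaladoc above, pass an empty string (not null) to turn quoting off:
      // .option("quote", "")
      .csv("/tmp/people.csv")   // hypothetical input path

    df.show()
    spark.stop()
  }
}
```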
From 05800b4b4e7873ebc445dfcd020b76d7539686e1 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Fri, 14 Oct 2016 12:39:25 -0700
Subject: [PATCH 291/306] [TEST] Ignore flaky test in StreamingQueryListenerSuite

## What changes were proposed in this pull request?

Ignoring the flaky test introduced in #15307

https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/1736/testReport/junit/org.apache.spark.sql.streaming/StreamingQueryListenerSuite/single_listener__check_trigger_statuses/

Author: Tathagata Das

Closes #15491 from tdas/metrics-flaky-test.
---
 .../spark/sql/streaming/StreamingQueryListenerSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 6256385dfd0e4..9e0eefbc58aa5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -43,7 +43,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     // Make sure we don't leak any events to the next test
   }
-  test("single listener, check trigger statuses") {
+  ignore("single listener, check trigger statuses") {
    import StreamingQueryListenerSuite._
    clock = new ManualClock()

From de1c1ca5c9d6064d3b7b3711e3bfb08fa018abe8 Mon Sep 17 00:00:00 2001
From: sethah
Date: Fri, 14 Oct 2016 20:21:03 +0000
Subject: [PATCH 292/306] [SPARK-17941][ML][TEST] Logistic regression tests should use sample weights.

## What changes were proposed in this pull request?

The sample weight testing for logistic regression is not robust. The logistic regression suite already has many test cases comparing results to R glmnet. Since both libraries support sample weights, we should use sample weights in these tests to increase coverage for sample weighting. This patch doesn't really add any code and makes the testing more complete. Also fixed some errors in the R code referenced in the test suite. Changed `standardization=T` to `standardize=T` since the former is invalid.

## How was this patch tested?

Existing unit tests are modified. No non-test code is touched.

Author: sethah

Closes #15488 from sethah/logreg_weight_tests.
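Editor's note: since the commit message leans on the equivalence between a Spark ML weight column and glmnet's `weights=` argument, a minimal sketch of the Spark side may help before reading the diff. This is not part of the patch; the toy data and column names are invented, and `setWeightCol` is the same setter the modified tests call:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object WeightedLogisticRegressionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("weighted-logreg-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Toy data: (label, weight, features). The weight column is made up for illustration.
    val df = Seq(
      (1.0, 0.3, Vectors.dense(0.5, 1.2)),
      (0.0, 1.0, Vectors.dense(-0.3, 0.9)),
      (1.0, 2.0, Vectors.dense(1.1, -0.4)),
      (0.0, 0.7, Vectors.dense(-0.8, -1.5))
    ).toDF("label", "weight", "features")

    // Weighted fit: each row contributes to the loss in proportion to its weight,
    // analogous to glmnet(features, label, weights = w, family = "binomial", ...)
    // in the R snippets embedded in the test comments below.
    val model = new LogisticRegression()
      .setWeightCol("weight")
      .setFitIntercept(true)
      .setRegParam(0.01)
      .fit(df)

    println(s"coefficients = ${model.coefficients}, intercept = ${model.intercept}")
    spark.stop()
  }
}
```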
--- .../LogisticRegressionSuite.scala | 1493 +++++++++-------- 1 file changed, 748 insertions(+), 745 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 42b56754e0835..bc631dc6d3149 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -25,14 +25,14 @@ import scala.util.control.Breaks._ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType class LogisticRegressionSuite @@ -40,6 +40,7 @@ class LogisticRegressionSuite import testImplicits._ + private val seed = 42 @transient var smallBinaryDataset: Dataset[_] = _ @transient var smallMultinomialDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ @@ -49,7 +50,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42).toDF() + smallBinaryDataset = generateLogisticInput(1.0, 1.0, nPoints = 100, seed = seed).toDF() smallMultinomialDataset = { val nPoints = 100 @@ -61,7 +62,7 @@ class LogisticRegressionSuite val xVariance = Array(0.6856, 0.1899) val testData = generateMultinomialLogisticInput( - coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) val df = sc.parallelize(testData, 4).toDF() df.cache() @@ -76,9 +77,9 @@ class LogisticRegressionSuite val testData = generateMultinomialLogisticInput(coefficients, xMean, xVariance, - addIntercept = true, nPoints, 42) + addIntercept = true, nPoints, seed) - sc.parallelize(testData, 4).toDF() + sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) } multinomialDataset = { @@ -91,9 +92,9 @@ class LogisticRegressionSuite val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val testData = generateMultinomialLogisticInput( - coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - val df = sc.parallelize(testData, 4).toDF() + val df = sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) df.cache() df } @@ -104,11 +105,11 @@ class LogisticRegressionSuite * so we can validate the training accuracy compared with R's glmnet package. 
*/ ignore("export test data into CSV format") { - binaryDataset.rdd.map { case Row(label: Double, features: Vector) => - label + "," + features.toArray.mkString(",") + binaryDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) => + label + "," + weight + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/binaryDataset") - multinomialDataset.rdd.map { case Row(label: Double, features: Vector) => - label + "," + features.toArray.mkString(",") + multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) => + label + "," + weight + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset") } @@ -519,31 +520,35 @@ class LogisticRegressionSuite test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) + .setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) + .setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - coefficients + Use the following R code to load the data and train the model using glmnet package. + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 0)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.7355261 + data.V3 -0.5734389 + data.V4 0.8911736 + data.V5 -0.3878645 + data.V6 -0.8060570 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 2.8366423 - data.V2 -0.5895848 - data.V3 0.8931147 - data.V4 -0.3925051 - data.V5 -0.7996864 */ - val interceptR = 2.8366423 - val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val coefficientsR = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptR = 2.7355261 assert(model1.intercept ~== interceptR relTol 1E-3) assert(model1.coefficients ~= coefficientsR relTol 1E-3) @@ -555,413 +560,374 @@ class LogisticRegressionSuite test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) + .setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false) + .setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. 
- library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = - coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 0, intercept=FALSE)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 -0.3448461 + data.V4 1.2776453 + data.V5 -0.3539178 + data.V6 -0.7469384 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.3534996 - data.V3 1.2964482 - data.V4 -0.3571741 - data.V5 -0.7407946 */ - val interceptR = 0.0 - val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val coefficientsR = Vectors.dense(-0.3448461, 1.2776453, -0.3539178, -0.7469384) - assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.intercept ~== 0.0 relTol 1E-3) assert(model1.coefficients ~= coefficientsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. - assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, standardize=T)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.06775980 + data.V3 . + data.V4 . + data.V5 -0.03933146 + data.V6 -0.03047580 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) -0.05627428 - data.V2 . - data.V3 . 
- data.V4 -0.04325749 - data.V5 -0.02481551 */ - val interceptR1 = -0.05627428 - val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + val coefficientsRStd = Vectors.dense(0.0, 0.0, -0.03933146, -0.03047580) + val interceptRStd = -0.06775980 - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 absTol 2E-2) + assert(model1.intercept ~== interceptRStd relTol 1E-2) + assert(model1.coefficients ~= coefficientsRStd absTol 2E-2) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - standardize=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, standardize=F)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.3544768 + data.V3 . + data.V4 . + data.V5 -0.1626191 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.3722152 - data.V2 . - data.V3 . - data.V4 -0.1665453 - data.V5 . */ - val interceptR2 = 0.3722152 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + val coefficientsR = Vectors.dense(0.0, 0.0, -0.1626191, 0.0) + val interceptR = 0.3544768 - assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.coefficients ~== coefficientsR2 absTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-2) + assert(model2.coefficients ~== coefficientsR absTol 1E-3) // TODO: move this to a standalone test of compression after SPARK-17471 assert(model2.coefficients.isInstanceOf[SparseVector]) } test("binary logistic regression without intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - intercept=FALSE)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 . - data.V4 -0.05189203 - data.V5 -0.03891782 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 absTol 1E-3) + Use the following R code to load the data and train the model using glmnet package. 
- /* - Using the following R code to load the data and train the model using glmnet package. + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1, + lambda = 0.12, intercept=F, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 . + data.V5 -0.04967635 + data.V6 -0.04757757 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - intercept=FALSE, standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 . + data.V5 -0.08433195 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 . - data.V4 -0.08420782 - data.V5 . */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + val coefficientsRStd = Vectors.dense(0.0, 0.0, -0.04967635, -0.04757757) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + val coefficientsR = Vectors.dense(0.0, 0.0, -0.08433195, 0.0) + + assert(model1.intercept ~== 0.0 absTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd absTol 1E-3) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.15021751 - data.V2 -0.07251837 - data.V3 0.10724191 - data.V4 -0.04865309 - data.V5 -0.10062872 - */ - val interceptR1 = 0.15021751 - val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.12707703 + data.V3 -0.06980967 + data.V4 0.10803933 + data.V5 -0.04800404 + data.V6 -0.10165096 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.46613016 + data.V3 -0.04944529 + data.V4 0.02326772 + data.V5 -0.11362772 + data.V6 -0.06312848 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.48657516 - data.V2 -0.05155371 - data.V3 0.02301057 - data.V4 -0.11482896 - data.V5 -0.06266838 */ - val interceptR2 = 0.48657516 - val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val coefficientsRStd = Vectors.dense(-0.06980967, 0.10803933, -0.04800404, -0.10165096) + val interceptRStd = 0.12707703 + val coefficientsR = Vectors.dense(-0.04944529, 0.02326772, -0.11362772, -0.06312848) + val interceptR = 0.46613016 - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model1.intercept ~== interceptRStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd relTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - intercept=FALSE)) - coefficients + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0, + lambda = 1.37, intercept=F, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . 
+ data.V3 -0.06000152 + data.V4 0.12598737 + data.V5 -0.04669009 + data.V6 -0.09941025 - 5 x 1 sparse Matrix of class "dgCMatrix" + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - data.V2 -0.06099165 - data.V3 0.12857058 - data.V4 -0.04708770 - data.V5 -0.09799775 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) - - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - intercept=FALSE, standardize=FALSE)) - coefficients + (Intercept) . + data.V3 -0.005482255 + data.V4 0.048106338 + data.V5 -0.093411640 + data.V6 -0.054149798 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.005679651 - data.V3 0.048967094 - data.V4 -0.093714016 - data.V5 -0.053314311 */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val coefficientsRStd = Vectors.dense(-0.06000152, 0.12598737, -0.04669009, -0.09941025) + val coefficientsR = Vectors.dense(-0.005482255, 0.048106338, -0.093411640, -0.054149798) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + assert(model1.intercept ~== 0.0 absTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd relTol 1E-2) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.57734851 - data.V2 -0.05310287 - data.V3 . - data.V4 -0.08849250 - data.V5 -0.15458796 - */ - val interceptR1 = 0.57734851 - val coefficientsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) - - assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.coefficients ~== coefficientsR1 absTol 5E-3) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.49991996 + data.V3 -0.04131110 + data.V4 . + data.V5 -0.08585233 + data.V6 -0.15875400 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.5024256 + data.V3 . + data.V4 . + data.V5 -0.1846038 + data.V6 -0.0559614 - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 0.51555993 - data.V2 . - data.V3 . - data.V4 -0.18807395 - data.V5 -0.05350074 */ - val interceptR2 = 0.51555993 - val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) - - assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + val coefficientsRStd = Vectors.dense(-0.04131110, 0.0, -0.08585233, -0.15875400) + val interceptRStd = 0.49991996 + val coefficientsR = Vectors.dense(0.0, 0.0, -0.1846038, -0.0559614) + val interceptR = 0.5024256 + + assert(model1.intercept ~== interceptRStd relTol 6E-3) + assert(model1.coefficients ~== coefficientsRStd absTol 5E-3) + assert(model2.intercept ~== interceptR relTol 6E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) /* - Using the following R code to load the data and train the model using glmnet package. - - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - intercept=FALSE)) - coefficients - - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 -0.001005743 - data.V3 0.072577857 - data.V4 -0.081203769 - data.V5 -0.142534158 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) - - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 absTol 1E-2) + Use the following R code to load the data and train the model using glmnet package. - /* - Using the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, intercept=FALSE, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 0.38, + lambda = 0.21, intercept=FALSE, standardize=F)) + coefficientsStd + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 0.06859390 + data.V5 -0.07900058 + data.V6 -0.14684320 - library("glmnet") - data <- read.csv("path", header=FALSE) - label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - intercept=FALSE, standardize=FALSE)) - coefficients + coefficients + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V3 . + data.V4 0.03060637 + data.V5 -0.11126742 + data.V6 . - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - data.V2 . - data.V3 0.03345223 - data.V4 -0.11304532 - data.V5 . */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) + val coefficientsRStd = Vectors.dense(0.0, 0.06859390, -0.07900058, -0.14684320) + val coefficientsR = Vectors.dense(0.0, 0.03060637, -0.11126742, 0.0) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsRStd absTol 1E-2) + assert(model2.intercept ~== 0.0 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false) val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.as[Instance].rdd.map { i => (i.label, i.weight)} .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { - case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + case (classSummarizer: MultiClassSummarizer, (label: Double, weight: Double)) => + classSummarizer.add(label, weight) }, combOp = (c1, c2) => (c1, c2) match { case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => @@ -989,25 +955,26 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsTheory absTol 1E-6) /* - TODO: why is this needed? The correctness of L1 regularization is already checked elsewhere Using the following R code to load the data and train the model using glmnet package. 
library("glmnet") data <- read.csv("path", header=FALSE) label = factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="binomial", alpha = 1.0, + lambda = 6.0)) coefficients 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) -0.2480643 - data.V2 0.0000000 - data.V3 . - data.V4 . - data.V5 . + s0 + (Intercept) -0.2516986 + data.V3 0.0000000 + data.V4 . + data.V5 . + data.V6 . */ - val interceptR = -0.248065 + val interceptR = -0.2516986 val coefficientsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) @@ -1015,9 +982,9 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept with strong L1 regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false) val sqlContext = multinomialDataset.sqlContext @@ -1025,16 +992,17 @@ class LogisticRegressionSuite val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) - val histogram = multinomialDataset.as[LabeledPoint].rdd.map(_.label) + val histogram = multinomialDataset.as[Instance].rdd.map(i => (i.label, i.weight)) .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { - case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + case (classSummarizer: MultiClassSummarizer, (label: Double, weight: Double)) => + classSummarizer.add(label, weight) }, combOp = (c1, c2) => (c1, c2) match { case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => classSummarizer1.merge(classSummarizer2) }).histogram - val numFeatures = multinomialDataset.as[LabeledPoint].first().features.size + val numFeatures = multinomialDataset.as[Instance].first().features.size val numClasses = histogram.length /* @@ -1068,52 +1036,58 @@ class LogisticRegressionSuite test("multinomial logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setMaxIter(100) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Using the following R code to load the data and train the model using glmnet package. 
- > library("glmnet") - > data <- read.csv("path", header=FALSE) - > label = as.factor(data$V1) - > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - > coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0)) - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -2.24493379 - V2 0.25096771 - V3 -0.03915938 - V4 0.14766639 - V5 0.36810817 - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.3778931 - V2 -0.3327489 - V3 0.8893666 - V4 -0.2306948 - V5 -0.4442330 - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.86704066 - V2 0.08178121 - V3 -0.85020722 - V4 0.08302840 - V5 0.07612480 - */ + Use the following R code to load the data and train the model using glmnet package. + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 0, lambda = 0)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -2.10320093 + data.V3 0.24337896 + data.V4 -0.05916156 + data.V5 0.14446790 + data.V6 0.35976165 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.3394473 + data.V3 -0.3443375 + data.V4 0.9181331 + data.V5 -0.2283959 + data.V6 -0.4388066 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 1.76375361 + data.V3 0.10095851 + data.V4 -0.85897154 + data.V5 0.08392798 + data.V6 0.07904499 + + + */ val coefficientsR = new DenseMatrix(3, 4, Array( - 0.2509677, -0.0391594, 0.1476664, 0.3681082, - -0.3327489, 0.8893666, -0.2306948, -0.4442330, - 0.0817812, -0.8502072, 0.0830284, 0.0761248), isTransposed = true) - val interceptsR = Vectors.dense(-2.2449338, 0.3778931, 1.8670407) + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) @@ -1128,52 +1102,57 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Using the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0, - intercept=F)) - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.06992464 - V3 -0.36562784 - V4 0.12142680 - V5 0.32052211 - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.3036269 - V3 0.9449630 - V4 -0.2271038 - V5 -0.4364839 - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . 
- V2 0.2337022 - V3 -0.5793351 - V4 0.1056770 - V5 0.1159618 - */ + Use the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0, intercept=F)) + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 0.07276291 + data.V4 -0.36325496 + data.V5 0.12015088 + data.V6 0.31397340 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 -0.3180040 + data.V4 0.9679074 + data.V5 -0.2252219 + data.V6 -0.4319914 + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + . + data.V3 0.2452411 + data.V4 -0.6046524 + data.V5 0.1050710 + data.V6 0.1180180 + + + */ val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0699246, -0.3656278, 0.1214268, 0.3205221, - -0.3036269, 0.9449630, -0.2271038, -0.4364839, - 0.2337022, -0.5793351, 0.1056770, 0.1159618), isTransposed = true) + 0.07276291, -0.36325496, 0.12015088, 0.31397340, + -0.3180040, 0.9679074, -0.2252219, -0.4319914, + 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) @@ -1190,92 +1169,95 @@ class LogisticRegressionSuite // use tighter constraints because OWL-QN solver takes longer to converge val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(300).setTol(1e-10).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(300).setTol(1e-10).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* - Use the following R code to load the data and train the model using glmnet package. - library("glmnet") - data <- read.csv("path", header=FALSE) - label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1, - lambda = 0.05, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05, - standardization=F)) - > coefficientsStd - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.68988825 - V2 . - V3 . - V4 . - V5 0.09404023 - - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2303499 - V2 -0.1232443 - V3 0.3258380 - V4 -0.1564688 - V5 -0.2053965 - - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.9202381 - V2 . - V3 -0.4803856 - V4 . - V5 . - - > coefficients - $`0` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.44893320 - V2 . - V3 . - V4 0.01933812 - V5 0.03666044 - - $`1` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.7376760 - V2 -0.0577182 - V3 . - V4 -0.2081718 - V5 -0.1304592 - - $`2` - 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2887428 - V2 . - V3 . - V4 . - V5 . - */ + Use the following R code to load the data and train the model using glmnet package. 
- val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.09404023, - -0.1232443, 0.3258380, -0.1564688, -0.2053965, - 0.0, -0.4803856, 0.0, 0.0), isTransposed = true) - val interceptsRStd = Vectors.dense(-0.68988825, -0.2303499, 0.9202381) + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 1, lambda = 0.05, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, standardize=F)) + coefficientsStd + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.62244703 + data.V3 . + data.V4 . + data.V5 . + data.V6 0.08419825 + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.2804845 + data.V3 -0.1336960 + data.V4 0.3717091 + data.V5 -0.1530363 + data.V6 -0.2035286 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.9029315 + data.V3 . + data.V4 -0.4629737 + data.V5 . + data.V6 . + + + coefficients + $`0` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.44215290 + data.V3 . + data.V4 . + data.V5 0.01767089 + data.V6 0.02542866 + + $`1` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.76308326 + data.V3 -0.06818576 + data.V4 . + data.V5 -0.20446351 + data.V6 -0.13017924 + + $`2` + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.3209304 + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.0, 0.0, 0.0, 0.08419825, + -0.1336960, 0.3717091, -0.1530363, -0.2035286, + 0.0, -0.4629737, 0.0, 0.0), isTransposed = true) + val interceptsRStd = Vectors.dense(-0.62244703, -0.2804845, 0.9029315) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.01933812, 0.03666044, - -0.0577182, 0.0, -0.2081718, -0.1304592, + 0.0, 0.0, 0.01767089, 0.02542866, + -0.06818576, 0.0, -0.20446351, -0.13017924, 0.0, 0.0, 0.0, 0.0), isTransposed = true) - val interceptsR = Vectors.dense(-0.44893320, 0.7376760, -0.2887428) + val interceptsR = Vectors.dense(-0.44215290, 0.76308326, -0.3209304) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.02) assert(model1.interceptVector ~== interceptsRStd relTol 0.1) @@ -1287,87 +1269,91 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true) + .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false) + .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1, - lambda = 0.05, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05, - intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 1, + lambda = 0.05, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 0.01525105 + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 0.01144225 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.1502410 - V3 0.5134658 - V4 -0.1601146 - V5 -0.2500232 + s0 + . + data.V3 -0.1678787 + data.V4 0.5385351 + data.V5 -0.1573039 + data.V6 -0.2471624 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.003301875 - V3 . - V4 . - V5 . - - > coefficients + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 0.1943624 - V4 -0.1902577 - V5 -0.1028789 + s0 + . + data.V3 . + data.V4 0.1929409 + data.V5 -0.1889121 + data.V6 -0.1010413 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . - */ + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.01525105, - -0.1502410, 0.5134658, -0.1601146, -0.2500232, - 0.003301875, 0.0, 0.0, 0.0), isTransposed = true) + 0.0, 0.0, 0.0, 0.01144225, + -0.1678787, 0.5385351, -0.1573039, -0.2471624, + 0.0, 0.0, 0.0, 0.0), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( 0.0, 0.0, 0.0, 0.0, - 0.0, 0.1943624, -0.1902577, -0.1028789, + 0.0, 0.1929409, -0.1889121, -0.1010413, 0.0, 0.0, 0.0, 0.0), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) @@ -1380,92 +1366,95 @@ class LogisticRegressionSuite test("multinomial logistic regression with intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. 
+ library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=T, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=T, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame( data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 0, lambda = 0.1, intercept=T, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=T, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -1.70040424 - V2 0.17576070 - V3 0.01527894 - V4 0.10216108 - V5 0.26099531 + s0 + -1.5898288335 + data.V3 0.1691226336 + data.V4 0.0002983651 + data.V5 0.1001732896 + data.V6 0.2554575585 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.2438590 - V2 -0.2238875 - V3 0.5967610 - V4 -0.1555496 - V5 -0.3010479 + s0 + 0.2125746 + data.V3 -0.2304586 + data.V4 0.6153492 + data.V5 -0.1537017 + data.V6 -0.2975443 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.45654525 - V2 0.04812679 - V3 -0.61203992 - V4 0.05338850 - V5 0.04005258 - - > coefficients + s0 + 1.37725427 + data.V3 0.06133600 + data.V4 -0.61564761 + data.V5 0.05352840 + data.V6 0.04208671 + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -1.65488543 - V2 0.15715048 - V3 0.01992903 - V4 0.12428858 - V5 0.22130317 + s0 + -1.5681088 + data.V3 0.1508182 + data.V4 0.0121955 + data.V5 0.1217930 + data.V6 0.2162850 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 1.1297533 - V2 -0.1974768 - V3 0.2776373 - V4 -0.1869445 - V5 -0.2510320 + s0 + 1.1217130 + data.V3 -0.2028984 + data.V4 0.2862431 + data.V5 -0.1843559 + data.V6 -0.2481218 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.52513212 - V2 0.04032627 - V3 -0.29756637 - V4 0.06265594 - V5 0.02972883 - */ + s0 + 0.44639579 + data.V3 0.05208012 + data.V4 -0.29843864 + data.V5 0.06256289 + data.V6 0.03183676 - val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.17576070, 0.01527894, 0.10216108, 0.26099531, - -0.2238875, 0.5967610, -0.1555496, -0.3010479, - 0.04812679, -0.61203992, 0.05338850, 0.04005258), isTransposed = true) - val interceptsRStd = Vectors.dense(-1.70040424, 0.2438590, 1.45654525) + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.1691226336, 0.0002983651, 0.1001732896, 0.2554575585, + -0.2304586, 0.6153492, -0.1537017, -0.2975443, + 0.06133600, -0.61564761, 0.05352840, 0.04208671), isTransposed = true) + val interceptsRStd = Vectors.dense(-1.5898288335, 0.2125746, 1.37725427) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.15715048, 0.01992903, 0.12428858, 0.22130317, - -0.1974768, 0.2776373, -0.1869445, -0.2510320, - 0.04032627, -0.29756637, 0.06265594, 0.02972883), isTransposed = true) - val interceptsR = Vectors.dense(-1.65488543, 1.1297533, 0.52513212) + 0.1508182, 0.0121955, 0.1217930, 0.2162850, + -0.2028984, 0.2862431, -0.1843559, -0.2481218, + 0.05208012, -0.29843864, 0.06256289, 0.03183676), isTransposed = true) + val interceptsR = Vectors.dense(-1.5681088, 1.1217130, 0.44639579) - assert(model1.coefficientMatrix ~== coefficientsRStd relTol 0.05) + assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.001) assert(model1.interceptVector ~== 
interceptsRStd relTol 0.05) assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps) assert(model2.coefficientMatrix ~== coefficientsR relTol 0.05) @@ -1475,86 +1464,92 @@ class LogisticRegressionSuite test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false) + .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, - lambda = 0.1, intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0, + lambda = 0.1, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.03904171 - V3 -0.23354322 - V4 0.08288096 - V5 0.22706393 + s0 + . + data.V3 0.04048126 + data.V4 -0.23075758 + data.V5 0.08228864 + data.V6 0.22277648 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.2061848 - V3 0.6341398 - V4 -0.1530059 - V5 -0.2958455 + s0 + . + data.V3 -0.2149745 + data.V4 0.6478666 + data.V5 -0.1515158 + data.V6 -0.2930498 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.16714312 - V3 -0.40059658 - V4 0.07012496 - V5 0.06878158 - > coefficients + s0 + . + data.V3 0.17449321 + data.V4 -0.41710901 + data.V5 0.06922716 + data.V6 0.07027332 + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.005704542 - V3 -0.144466409 - V4 0.092080736 - V5 0.182927657 + s0 + . + data.V3 -0.003949652 + data.V4 -0.142982415 + data.V5 0.091439598 + data.V6 0.179286241 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.08469036 - V3 0.38996748 - V4 -0.16468436 - V5 -0.22522976 + s0 + . + data.V3 -0.09071124 + data.V4 0.39752531 + data.V5 -0.16233832 + data.V6 -0.22206059 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.09039490 - V3 -0.24550107 - V4 0.07260362 - V5 0.04230210 + s0 + . 
+ data.V3 0.09466090 + data.V4 -0.25454290 + data.V5 0.07089872 + data.V6 0.04277435 + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.03904171, -0.23354322, 0.08288096, 0.2270639, - -0.2061848, 0.6341398, -0.1530059, -0.2958455, - 0.16714312, -0.40059658, 0.07012496, 0.06878158), isTransposed = true) + 0.04048126, -0.23075758, 0.08228864, 0.22277648, + -0.2149745, 0.6478666, -0.1515158, -0.2930498, + 0.17449321, -0.41710901, 0.06922716, 0.07027332), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( - -0.005704542, -0.144466409, 0.092080736, 0.182927657, - -0.08469036, 0.38996748, -0.16468436, -0.22522976, - 0.0903949, -0.24550107, 0.07260362, 0.0423021), isTransposed = true) + -0.003949652, -0.142982415, 0.091439598, 0.179286241, + -0.09071124, 0.39752531, -0.16233832, -0.22206059, + 0.09466090, -0.25454290, 0.07089872, 0.04277435), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) @@ -1565,10 +1560,10 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept with elasticnet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) .setMaxIter(300).setTol(1e-10) - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) .setMaxIter(300).setTol(1e-10) @@ -1576,82 +1571,85 @@ class LogisticRegressionSuite val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=T, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=T, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=T, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=T, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.5521819483 - V2 0.0003092611 - V3 . - V4 . - V5 0.0913818490 + s0 + -0.50133383 + data.V3 . + data.V4 . + data.V5 . + data.V6 0.08351653 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.27531989 - V2 -0.09790029 - V3 0.28502034 - V4 -0.12416487 - V5 -0.16513373 + s0 + -0.3151913 + data.V3 -0.1058702 + data.V4 0.3183251 + data.V5 -0.1212969 + data.V6 -0.1629778 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.8275018 - V2 . - V3 -0.4044859 - V4 . - V5 . - - > coefficients + s0 + 0.8165252 + data.V3 . + data.V4 -0.3943069 + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.39876213 - V2 . - V3 . - V4 0.02547520 - V5 0.03893991 + s0 + -0.38857157 + data.V3 . + data.V4 . 
+ data.V5 0.02384198 + data.V6 0.03127749 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - 0.61089869 - V2 -0.04224269 - V3 . - V4 -0.18923970 - V5 -0.09104249 + s0 + 0.62492165 + data.V3 -0.04949061 + data.V4 . + data.V5 -0.18584462 + data.V6 -0.08952455 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - -0.2121366 - V2 . - V3 . - V4 . - V5 . - */ + s0 + -0.2363501 + data.V3 . + data.V4 . + data.V5 . + data.V6 . - val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0003092611, 0.0, 0.0, 0.091381849, - -0.09790029, 0.28502034, -0.12416487, -0.16513373, - 0.0, -0.4044859, 0.0, 0.0), isTransposed = true) - val interceptsRStd = Vectors.dense(-0.5521819483, -0.27531989, 0.8275018) + */ + val coefficientsRStd = new DenseMatrix(3, 4, Array( + 0.0, 0.0, 0.0, 0.08351653, + -0.1058702, 0.3183251, -0.1212969, -0.1629778, + 0.0, -0.3943069, 0.0, 0.0), isTransposed = true) + val interceptsRStd = Vectors.dense(-0.50133383, -0.3151913, 0.8165252) val coefficientsR = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0254752, 0.03893991, - -0.04224269, 0.0, -0.1892397, -0.09104249, + 0.0, 0.0, 0.02384198, 0.03127749, + -0.04949061, 0.0, -0.18584462, -0.08952455, 0.0, 0.0, 0.0, 0.0), isTransposed = true) - val interceptsR = Vectors.dense(-0.39876213, 0.61089869, -0.2121366) + val interceptsR = Vectors.dense(-0.38857157, 0.62492165, -0.2363501) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) assert(model1.interceptVector ~== interceptsRStd absTol 0.01) @@ -1662,10 +1660,10 @@ class LogisticRegressionSuite } test("multinomial logistic regression without intercept with elasticnet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(false) + val trainer1 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) .setMaxIter(300).setTol(1e-10) - val trainer2 = (new LogisticRegression).setFitIntercept(false) + val trainer2 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) .setMaxIter(300).setTol(1e-10) @@ -1673,78 +1671,83 @@ class LogisticRegressionSuite val model2 = trainer2.fit(multinomialDataset) /* Use the following R code to load the data and train the model using glmnet package. + library("glmnet") data <- read.csv("path", header=FALSE) label = as.factor(data$V1) - features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=F, standardization=T)) - coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5, - lambda = 0.1, intercept=F, standardization=F)) - > coefficientsStd + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4, data$V5, data$V6)) + coefficientsStd = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=F, standardize=T)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", alpha = 0.5, + lambda = 0.1, intercept=F, standardize=F)) + coefficientsStd $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 0.03543706 + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 0.03238285 $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 -0.1187387 - V3 0.4025482 - V4 -0.1270969 - V5 -0.1918386 + s0 + . 
+ data.V3 -0.1328284 + data.V4 0.4219321 + data.V5 -0.1247544 + data.V6 -0.1893318 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 0.00774365 - V3 . - V4 . - V5 . - - > coefficients + s0 + . + data.V3 0.004572312 + data.V4 . + data.V5 . + data.V6 . + + + coefficients $`0` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . $`1` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 0.14666497 - V4 -0.16570638 - V5 -0.05982875 + s0 + . + data.V3 . + data.V4 0.14571623 + data.V5 -0.16456351 + data.V6 -0.05866264 $`2` 5 x 1 sparse Matrix of class "dgCMatrix" - s0 - . - V2 . - V3 . - V4 . - V5 . + s0 + . + data.V3 . + data.V4 . + data.V5 . + data.V6 . + + */ val coefficientsRStd = new DenseMatrix(3, 4, Array( - 0.0, 0.0, 0.0, 0.03543706, - -0.1187387, 0.4025482, -0.1270969, -0.1918386, - 0.0, 0.0, 0.0, 0.00774365), isTransposed = true) + 0.0, 0.0, 0.0, 0.03238285, + -0.1328284, 0.4219321, -0.1247544, -0.1893318, + 0.004572312, 0.0, 0.0, 0.0), isTransposed = true) val coefficientsR = new DenseMatrix(3, 4, Array( 0.0, 0.0, 0.0, 0.0, - 0.0, 0.14666497, -0.16570638, -0.05982875, + 0.0, 0.14571623, -0.16456351, -0.05866264, 0.0, 0.0, 0.0, 0.0), isTransposed = true) assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) From 7ab86244e30ca81eb4fa40ea77b4c2b8881cbab2 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 14 Oct 2016 13:22:59 -0700 Subject: [PATCH 293/306] [SPARK-17620][SQL] Determine Serde by hive.default.fileformat when Creating Hive Serde Tables ## What changes were proposed in this pull request? Make sure the hive.default.fileformat is used to when creating the storage format metadata. Output ``` SQL scala> spark.sql("SET hive.default.fileformat=orc") res1: org.apache.spark.sql.DataFrame = [key: string, value: string] scala> spark.sql("CREATE TABLE tmp_default(id INT)") res2: org.apache.spark.sql.DataFrame = [] ``` Before ```SQL scala> spark.sql("DESC FORMATTED tmp_default").collect.foreach(println) .. [# Storage Information,,] [SerDe Library:,org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe,] [InputFormat:,org.apache.hadoop.hive.ql.io.orc.OrcInputFormat,] [OutputFormat:,org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat,] [Compressed:,No,] [Storage Desc Parameters:,,] [ serialization.format,1,] ``` After ```SQL scala> spark.sql("DESC FORMATTED tmp_default").collect.foreach(println) .. [# Storage Information,,] [SerDe Library:,org.apache.hadoop.hive.ql.io.orc.OrcSerde,] [InputFormat:,org.apache.hadoop.hive.ql.io.orc.OrcInputFormat,] [OutputFormat:,org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat,] [Compressed:,No,] [Storage Desc Parameters:,,] [ serialization.format,1,] ``` ## How was this patch tested? Added new tests to HiveDDLCommandSuite Author: Dilip Biswal Closes #15190 from dilipbiswal/orc. 
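As a quick illustration of the behavior described above, the following spark-shell sketch (assuming a Hive-enabled `SparkSession` available as `spark`, and an illustrative table name `tmp_default`) shows how the "SerDe Library" now follows `hive.default.fileformat` instead of always reporting LazySimpleSerDe:

```scala
// Sketch only: `spark` is assumed to be a Hive-enabled SparkSession (e.g. a spark-shell
// build with Hive support); `tmp_default` is an illustrative table name.
spark.sql("SET hive.default.fileformat=orc")
spark.sql("CREATE TABLE IF NOT EXISTS tmp_default(id INT)")

// With this patch the serde is derived from the chosen file format, so the
// "SerDe Library" row should report org.apache.hadoop.hive.ql.io.orc.OrcSerde.
spark.sql("DESC FORMATTED tmp_default").collect().foreach(println)
```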
--- .../spark/sql/execution/SparkSqlParser.scala | 4 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 26 ++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 39 +++++++++++++++++-- 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index be2eddbb0e423..8c68d1e3a2379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1010,9 +1010,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), outputFormat = defaultHiveSerde.flatMap(_.outputFormat) .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - // Note: Keep this unspecified because we use the presence of the serde to decide - // whether to convert a table created by CTAS to a datasource table. - serde = None, + serde = defaultHiveSerde.flatMap(_.serde), compressed = false, properties = Map()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 9ce3338647398..81337493c7f28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,10 +30,12 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -class HiveDDLCommandSuite extends PlanTest { +class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton { val parser = TestHive.sessionState.sqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -556,4 +558,24 @@ class HiveDDLCommandSuite extends PlanTest { assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2") } + test("Test the default fileformat for Hive-serde tables") { + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + withSQLConf("hive.default.fileformat" -> "parquet") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + val input = desc.storage.inputFormat + val output = desc.storage.outputFormat + val serde = desc.storage.serde + assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6f2a16662bf10..5798f47228216 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -492,7 +492,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { def checkRelation( tableName: String, - isDataSourceParquet: Boolean, + isDataSourceTable: Boolean, format: String, userSpecifiedLocation: Option[String] = None): Unit = { val relation = EliminateSubqueryAliases( @@ -501,7 +501,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) relation match { case LogicalRelation(r: HadoopFsRelation, _, _) => - if (!isDataSourceParquet) { + if (!isDataSourceTable) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") @@ -514,7 +514,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(catalogTable.provider.get === format) case r: MetastoreRelation => - if (isDataSourceParquet) { + if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") @@ -524,8 +524,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(r.catalogTable.storage.locationUri.get === location) case None => // OK. } - // Also make sure that the format is the desired format. + // Also make sure that the format and serde are as desired. assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + val serde = catalogTable.storage.serde.get + format match { + case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) + case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) + case _ => assert(serde.toLowerCase.contains(format)) + } } // When a user-specified location is defined, the table type needs to be EXTERNAL. @@ -587,6 +594,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("CTAS with default fileformat") { + val table = "ctas1" + val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + withSQLConf("hive.default.fileformat" -> "textfile") { + withTable(table) { + sql(ctas) + // We should use parquet here as that is the default datasource fileformat. The default + // datasource file format is controlled by `spark.sql.sources.default` configuration. + // This testcase verifies that setting `hive.default.fileformat` has no impact on + // the target table's fileformat in case of CTAS. 
+ assert(sessionState.conf.defaultDataSourceName === "parquet") + checkRelation(table, isDataSourceTable = true, "parquet") + } + } + withSQLConf("spark.sql.sources.default" -> "orc") { + withTable(table) { + sql(ctas) + checkRelation(table, isDataSourceTable = true, "orc") + } + } + } + } + test("CTAS without serde with location") { withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { withTempDir { dir => From 522dd0d0e5af83e45a3c3526c191aa4b8bcaeeeb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 14 Oct 2016 14:09:35 -0700 Subject: [PATCH 294/306] Revert "[SPARK-17620][SQL] Determine Serde by hive.default.fileformat when Creating Hive Serde Tables" This reverts commit 7ab86244e30ca81eb4fa40ea77b4c2b8881cbab2. --- .../spark/sql/execution/SparkSqlParser.scala | 4 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 26 +------------ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++----------------- 3 files changed, 9 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8c68d1e3a2379..be2eddbb0e423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1010,7 +1010,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), outputFormat = defaultHiveSerde.flatMap(_.outputFormat) .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - serde = defaultHiveSerde.flatMap(_.serde), + // Note: Keep this unspecified because we use the presence of the serde to decide + // whether to convert a table created by CTAS to a datasource table. 
+ serde = None, compressed = false, properties = Map()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 81337493c7f28..9ce3338647398 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,12 +30,10 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types.StructType -class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton { +class HiveDDLCommandSuite extends PlanTest { val parser = TestHive.sessionState.sqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -558,24 +556,4 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2") } - test("Test the default fileformat for Hive-serde tables") { - withSQLConf("hive.default.fileformat" -> "orc") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - withSQLConf("hive.default.fileformat" -> "parquet") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - val input = desc.storage.inputFormat - val output = desc.storage.outputFormat - val serde = desc.storage.serde - assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5798f47228216..6f2a16662bf10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -492,7 +492,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { def checkRelation( tableName: String, - isDataSourceTable: Boolean, + isDataSourceParquet: Boolean, format: String, userSpecifiedLocation: Option[String] = None): Unit = { val relation = EliminateSubqueryAliases( @@ -501,7 +501,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) relation match { case LogicalRelation(r: HadoopFsRelation, _, _) => - if (!isDataSourceTable) { + if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + 
s"${HadoopFsRelation.getClass.getCanonicalName}.") @@ -514,7 +514,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(catalogTable.provider.get === format) case r: MetastoreRelation => - if (isDataSourceTable) { + if (isDataSourceParquet) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") @@ -524,15 +524,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(r.catalogTable.storage.locationUri.get === location) case None => // OK. } - // Also make sure that the format and serde are as desired. + // Also make sure that the format is the desired format. assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) - assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) - val serde = catalogTable.storage.serde.get - format match { - case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) - case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) - case _ => assert(serde.toLowerCase.contains(format)) - } } // When a user-specified location is defined, the table type needs to be EXTERNAL. @@ -594,30 +587,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("CTAS with default fileformat") { - val table = "ctas1" - val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { - withTable(table) { - sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. - // This testcase verifies that setting `hive.default.fileformat` has no impact on - // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(table, isDataSourceTable = true, "parquet") - } - } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(table, isDataSourceTable = true, "orc") - } - } - } - } - test("CTAS without serde with location") { withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { withTempDir { dir => From da9aeb0fde589f7c21c2f4a32036a68c0353965d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Oct 2016 14:45:20 -0700 Subject: [PATCH 295/306] [SPARK-17863][SQL] should not add column into Distinct ## What changes were proposed in this pull request? We are trying to resolve the attribute in sort by pulling up some column for grandchild into child, but that's wrong when the child is Distinct, because the added column will change the behavior of Distinct, we should not do that. ## How was this patch tested? Added regression test. Author: Davies Liu Closes #15489 from davies/order_distinct. 
--- .../sql/catalyst/analysis/Analyzer.scala | 2 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 536d38777f89d..f8f4799322b3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -838,6 +838,8 @@ class Analyzer( // attributes that its child might have or could have. val missing = missingAttrs -- g.child.outputSet g.copy(join = true, child = addMissingAttr(g.child, missing)) + case d: Distinct => + throw new AnalysisException(s"Can't add $missingAttrs to $d") case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) case other => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0ee8c959eeb4d..60978efddd7f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1106,6 +1106,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17863: SELECT distinct does not work correctly if order by missing attribute") { + checkAnswer( + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by a, b + |""".stripMargin), + Row(1, 2) :: Nil) + + val error = intercept[AnalysisException] { + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by struct.a, struct.b + |""".stripMargin) + } + assert(error.message contains "cannot resolve '`struct.a`' given input columns: [a, b]") + + } + test("cast boolean to string") { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( From 5aeb7384c7aa5f487f031f9ae07d3f1653399d14 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 14 Oct 2016 15:07:32 -0700 Subject: [PATCH 296/306] [SPARK-16063][SQL] Add storageLevel to Dataset [SPARK-11905](https://issues.apache.org/jira/browse/SPARK-11905) added support for `persist`/`cache` for `Dataset`. However, there is no user-facing API to check if a `Dataset` is cached and if so what the storage level is. This PR adds `getStorageLevel` to `Dataset`, analogous to `RDD.getStorageLevel`. Updated `DatasetCacheSuite`. Author: Nick Pentreath Closes #13780 from MLnick/ds-storagelevel. Signed-off-by: Michael Armbrust --- python/pyspark/sql/dataframe.py | 36 +++++++++++++++---- .../scala/org/apache/spark/sql/Dataset.scala | 12 +++++++ .../apache/spark/sql/DatasetCacheSuite.scala | 36 +++++++++++++------ 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ce277eb204d13..7606ac08bae67 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -407,24 +407,48 @@ def foreachPartition(self, f): @since(1.3) def cache(self): - """ Persists with the default storage level (C{MEMORY_ONLY}). + """Persists the :class:`DataFrame` with the default storage level (C{MEMORY_AND_DISK}). + + .. 
note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True self._jdf.cache() return self @since(1.3) - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): - """Sets the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY}). + def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK): + """Sets the storage level to persist the contents of the :class:`DataFrame` across + operations after the first time it is computed. This can only be used to assign + a new storage level if the :class:`DataFrame` does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_AND_DISK}). + + .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdf.persist(javaStorageLevel) return self + @property + @since(2.1) + def storageLevel(self): + """Get the :class:`DataFrame`'s current storage level. + + >>> df.storageLevel + StorageLevel(False, False, False, False, 1) + >>> df.cache().storageLevel + StorageLevel(True, True, False, True, 1) + >>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel + StorageLevel(True, False, False, False, 2) + """ + java_storage_level = self._jdf.storageLevel() + storage_level = StorageLevel(java_storage_level.useDisk(), + java_storage_level.useMemory(), + java_storage_level.useOffHeap(), + java_storage_level.deserialized(), + java_storage_level.replication()) + return storage_level + @since(1.3) def unpersist(self, blocking=False): """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e59a483075c94..70c9cf5ae2440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2401,6 +2401,18 @@ class Dataset[T] private[sql]( this } + /** + * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted. + * + * @group basic + * @since 2.1.0 + */ + def storageLevel: StorageLevel = { + sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => + cachedData.cachedRepresentation.storageLevel + }.getOrElse(StorageLevel.NONE) + } + /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 8d5e9645df894..e0561ee2797a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -19,11 +19,32 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.storage.StorageLevel class DatasetCacheSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("get storage level") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + // default storage level + ds1.persist() + ds2.cache() + assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK) + assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK) + // unpersist + ds1.unpersist() + assert(ds1.storageLevel == StorageLevel.NONE) + // non-default storage level + ds1.persist(StorageLevel.MEMORY_ONLY_2) + assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2) + // joined Dataset should not be persisted + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + assert(joined.storageLevel == StorageLevel.NONE) + } + test("persist and unpersist") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) val cached = ds.cache() @@ -37,8 +58,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. cached.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(cached).isEmpty, - "The Dataset should not be cached.") + assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -55,11 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds1).isEmpty, - "The Dataset ds1 should not be cached.") + assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds2).isEmpty, - "The Dataset ds2 should not be cached.") + assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -74,10 +92,8 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(ds).isEmpty, - "The Dataset ds should not be cached.") + assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.") agged.unpersist() - assert(spark.sharedState.cacheManager.lookupCachedData(agged).isEmpty, - "The Dataset agged should not be cached.") + assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } } From f00df40cfefef0f3fc73f16ada1006e4dcfa5a39 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 14 Oct 2016 15:50:35 -0700 Subject: [PATCH 297/306] [SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF Currently pyspark can only call the builtin java UDF, but can not call custom java UDF. It would be better to allow that. 2 benefits: * Leverage the power of rich third party java library * Improve the performance. Because if we use python UDF, python daemons will be started on worker which will affect the performance. 
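For context, the Python entry point added here reaches the existing JVM-side registration path; below is a rough Scala sketch of the equivalent registration (the UDF class and names are illustrative, and `spark` is assumed to be a SparkSession):

```scala
// Illustrative only: registering a java.api UDF1 by hand from Scala, which is the same
// UDFRegistration overload that the new PySpark registerJavaFunction ends up calling.
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes

val javaStringLength = new UDF1[String, Integer] {
  override def call(s: String): Integer = Integer.valueOf(s.length)
}
spark.udf.register("javaStringLength", javaStringLength, DataTypes.IntegerType)
spark.sql("SELECT javaStringLength('test')").show()  // 4
```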
Author: Jeff Zhang Closes #9766 from zjffdu/SPARK-11775. --- python/pyspark/sql/context.py | 28 ++++++- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 75 ++++++++++++++++++- .../apache/spark/sql/JavaStringLength.java | 30 ++++++++ .../org/apache/spark/sql/JavaUDFSuite.java | 21 ++++++ 5 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8264dcf8a97d2..de4c335ad2752 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,7 +28,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, StringType +from pyspark.sql.types import IntegerType, Row, StringType from pyspark.sql.utils import install_exception_handler __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] @@ -202,6 +202,32 @@ def registerFunction(self, name, f, returnType=StringType()): """ self.sparkSession.catalog.registerFunction(name, f, returnType) + @ignore_unicode_prefix + @since(2.1) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a java UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + :param name: name of the UDF + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> sqlContext.registerJavaFunction("javaStringLength", + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> sqlContext.sql("SELECT javaStringLength('test')").collect() + [Row(UDF(test)=4)] + >>> sqlContext.registerJavaFunction("javaStringLength2", + ... 
"test.org.apache.spark.sql.JavaStringLength") + >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF(test)=4)] + + """ + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e6f61b00ebd70..04f0cfce883f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -59,7 +59,7 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 617a14793697b..0444ad10d34fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,19 +17,25 @@ package org.apache.spark.sql +import java.io.IOException +import java.lang.reflect.{ParameterizedType, Type} + import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. @@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Register a Java UDF class using reflection, for use from pyspark + * + * @param name udf name + * @param className fully qualified class name of udf + * @param returnDataType return type of udf. If it is null, spark would try to infer + * via reflection. 
+ */ + private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + + try { + val clazz = Utils.classForName(className) + val udfInterfaces = clazz.getGenericInterfaces + .filter(_.isInstanceOf[ParameterizedType]) + .map(_.asInstanceOf[ParameterizedType]) + .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) + if (udfInterfaces.length == 0) { + throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + } else if (udfInterfaces.length > 1) { + throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + } else { + try { + val udf = clazz.newInstance() + val udfReturnType = udfInterfaces(0).getActualTypeArguments.last + var returnType = returnDataType + if (returnType == null) { + returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1 + } + + udfInterfaces(0).getActualTypeArguments.length match { + case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) + case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) + case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) + case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) + case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) + case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) + case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) + case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) + case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) + case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) + case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case n => logError(s"UDF class with ${n} type arguments is not supported ") + } + } catch { + case e @ (_: InstantiationException | _: IllegalArgumentException) => + logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") 
+ } + } + } catch { + case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + } + + } + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java new file mode 100644 index 0000000000000..b90224f2ae397 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.api.java.UDF1; + +/** + * It is used for register Java UDF from PySpark + */ +public class JavaStringLength implements UDF1 { + @Override + public Integer call(String str) throws Exception { + return new Integer(str.length()); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 2274912521a56..8bf3278c43880 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -87,4 +87,25 @@ public Integer call(String str1, String str2) { Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } + + public static class StringLengthTest implements UDF2 { + @Override + public Integer call(String str1, String str2) throws Exception { + return new Integer(str1.length() + str2.length()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void udf3Test() { + spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), + DataTypes.IntegerType); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + + // returnType is not provided + spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null); + result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + } } From 72adfbf94ab6a6ce2a5f3111140274476150f201 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Oct 2016 16:13:42 -0700 Subject: [PATCH 298/306] [SPARK-17900][SQL] Graduate a list of Spark SQL APIs to stable ## What changes were proposed in this pull request? This patch graduates a list of Spark SQL APIs and mark them stable. 
The following are marked stable: Dataset/DataFrame - functions, since 1.3 - ColumnName, since 1.3 - DataFrameNaFunctions, since 1.3.1 - DataFrameStatFunctions, since 1.4 - UserDefinedFunction, since 1.3 - UserDefinedAggregateFunction, since 1.5 - Window and WindowSpec, since 1.4 Data sources: - DataSourceRegister, since 1.5 - RelationProvider, since 1.3 - SchemaRelationProvider, since 1.3 - CreatableRelationProvider, since 1.3 - BaseRelation, since 1.3 - TableScan, since 1.3 - PrunedScan, since 1.3 - PrunedFilteredScan, since 1.3 - InsertableRelation, since 1.3 The following are kept experimental / evolving: Data sources: - CatalystScan (tied to internal logical plans so it is not stable by definition) Structured streaming: - all classes (introduced new in 2.0 and will likely change) Dataset typed operations (introduced in 1.6 and 2.0 and might change, although probability is low) - all typed methods on Dataset - KeyValueGroupedDataset - o.a.s.sql.expressions.javalang.typed - o.a.s.sql.expressions.scalalang.typed - methods that return typed Dataset in SparkSession We should discuss more whether we want to mark Dataset typed operations stable in 2.1. ## How was this patch tested? N/A - just annotation changes. Author: Reynold Xin Closes #15469 from rxin/SPARK-17900. --- .../scala/org/apache/spark/sql/Column.scala | 6 ++-- .../spark/sql/DataFrameNaFunctions.scala | 6 ++-- .../spark/sql/DataFrameStatFunctions.scala | 6 ++-- .../sql/expressions/UserDefinedFunction.scala | 10 ++++-- .../apache/spark/sql/expressions/Window.scala | 10 ++---- .../spark/sql/expressions/WindowSpec.scala | 6 ++-- .../apache/spark/sql/expressions/udaf.scala | 30 ++++++++++++---- .../org/apache/spark/sql/functions.scala | 4 +-- .../apache/spark/sql/sources/interfaces.scala | 35 +++++-------------- 9 files changed, 51 insertions(+), 62 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d22bb17934ce7..05e867bf5be96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.language.implicitConversions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -1181,13 +1181,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** - * :: Experimental :: * A convenient class used for constructing schema. 
* * @since 1.3.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 65a9c008f9650..0d43f09bc54cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,20 +21,18 @@ import java.{lang => jl} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: Experimental :: * Functionality for working with missing data in [[DataFrame]]s. * * @since 1.3.1 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a212bb6205328..b5bbcee37150f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,20 +21,18 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** - * :: Experimental :: * Statistic functions for [[DataFrame]]s. * * @since 1.4.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 2e0e937e4aff7..28598af781653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.Column import org.apache.spark.sql.functions @@ -39,13 +39,17 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, inputTypes: Option[Seq[DataType]]) { + /** + * Returns an expression that invokes the UDF, using the given arguments. 
+ * + * @since 1.3.0 + */ def apply(exprs: Column*): Column = { Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 07ef60183f6fb..0b26d863cac5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ /** - * :: Experimental :: * Utility functions for defining window in DataFrames. * * {{{ @@ -36,8 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable object Window { /** @@ -164,7 +162,6 @@ object Window { } /** - * :: Experimental :: * Utility functions for defining window in DataFrames. * * {{{ @@ -177,6 +174,5 @@ object Window { * * @since 1.4.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 18778c8d1c294..1e85b6e7881ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ /** - * :: Experimental :: * A window specification that defines the partitioning, ordering, and frame boundaries. * * Use the static methods in [[Window]] to create a [[WindowSpec]]. * * @since 1.4.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index ef7c09c72b82d..bc9788d81fe6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ /** - * :: Experimental :: * The base class for implementing user-defined aggregate functions (UDAF). 
* * @since 1.5.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -46,6 +44,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * The name of a field of this [[StructType]] is only used to identify the corresponding * input argument. Users can choose names to identify the input arguments. + * + * @since 1.5.0 */ def inputSchema: StructType @@ -63,17 +63,23 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * The name of a field of this [[StructType]] is only used to identify the corresponding * buffer value. Users can choose names to identify the input arguments. + * + * @since 1.5.0 */ def bufferSchema: StructType /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + * + * @since 1.5.0 */ def dataType: DataType /** * Returns true iff this function is deterministic, i.e. given the same input, * always return the same output. + * + * @since 1.5.0 */ def deterministic: Boolean @@ -83,6 +89,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * The contract should be that applying the merge function on two initial buffers should just * return the initial buffer itself, i.e. * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. + * + * @since 1.5.0 */ def initialize(buffer: MutableAggregationBuffer): Unit @@ -90,6 +98,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Updates the given aggregation buffer `buffer` with new input data from `input`. * * This is called once per input row. + * + * @since 1.5.0 */ def update(buffer: MutableAggregationBuffer, input: Row): Unit @@ -97,17 +107,23 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. * * This is called when we merge two partially aggregated data together. + * + * @since 1.5.0 */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given * aggregation buffer. + * + * @since 1.5.0 */ def evaluate(buffer: Row): Any /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. + * + * @since 1.5.0 */ @scala.annotation.varargs def apply(exprs: Column*): Column = { @@ -122,6 +138,8 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. + * + * @since 1.5.0 */ @scala.annotation.varargs def distinct(exprs: Column*): Column = { @@ -135,15 +153,13 @@ abstract class UserDefinedAggregateFunction extends Serializable { } /** - * :: Experimental :: * A [[Row]] representing a mutable aggregation buffer. * * This is not meant to be extended outside of Spark. * * @since 1.5.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index de4943152720c..5f1efd22d8204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,7 +37,6 @@ import org.apache.spark.util.Utils /** - * :: Experimental :: * Functions available for DataFrame operations. * * @groupname udf_funcs UDF functions @@ -53,8 +52,7 @@ import org.apache.spark.util.Utils * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ -@Experimental -@InterfaceStability.Evolving +@InterfaceStability.Stable // scalastyle:off object functions { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 3172d5ded9504..15a48072525b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * ::DeveloperApi:: * Data sources should implement this trait so that they can register an alias to their data source. * This allows users to give the data source alias as the format type over the fully qualified * class name. @@ -36,8 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.5.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait DataSourceRegister { /** @@ -54,7 +52,6 @@ trait DataSourceRegister { } /** - * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When * Spark SQL is given a DDL operation with a USING clause specified (to specify the implemented * RelationProvider), this interface is used to pass in the parameters specified by a user. @@ -68,8 +65,7 @@ trait DataSourceRegister { * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -80,7 +76,6 @@ trait RelationProvider { } /** - * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema. When Spark SQL is given a DDL operation with a USING clause specified ( * to specify the implemented SchemaRelationProvider) and a user defined schema, this interface @@ -100,8 +95,7 @@ trait RelationProvider { * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. @@ -164,8 +158,7 @@ trait StreamSinkProvider { /** * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait CreatableRelationProvider { /** * Save the DataFrame to the destination and return a relation with the given parameters based on @@ -189,7 +182,6 @@ trait CreatableRelationProvider { } /** - * ::DeveloperApi:: * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must * be able to produce the schema of their data in the form of a [[StructType]]. 
Concrete * implementations should inherit from one of the descendant `Scan` classes, which define various @@ -201,8 +193,7 @@ trait CreatableRelationProvider { * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -248,32 +239,27 @@ abstract class BaseRelation { } /** - * ::DeveloperApi:: * A BaseRelation that can produce all of its tuples as an RDD of Row objects. * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait TableScan { def buildScan(): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * * @@ -286,14 +272,12 @@ trait PrunedScan { * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } /** - * ::DeveloperApi:: * A BaseRelation that can be used to insert data into it through the insert method. * If overwrite in insert method is true, the old data in the relation should be overwritten with * the new data. If overwrite in insert method is false, the new data should be appended. * * @@ -310,8 +294,7 @@ trait PrunedFilteredScan { * * @since 1.3.0 */ -@DeveloperApi -@InterfaceStability.Evolving +@InterfaceStability.Stable trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } From 2d96d35dc0fed6df249606d9ce9272c0f0109fa2 Mon Sep 17 00:00:00 2001 From: Srinath Shankar Date: Fri, 14 Oct 2016 18:24:47 -0700 Subject: [PATCH 299/306] [SPARK-17946][PYSPARK] Python crossJoin API similar to Scala ## What changes were proposed in this pull request? Add a crossJoin function to the DataFrame API similar to that in Scala. Joins with no condition (cartesian products) must be specified with the crossJoin API. ## How was this patch tested? Added Python tests to ensure that an AnalysisException is raised if a cartesian product is specified without crossJoin(), and that cartesian products can execute if specified via crossJoin(). Author: Srinath Shankar Closes #15493 from srinathshankar/crosspython.
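For reference, a minimal PySpark sketch of the behavior this change enforces, mirroring the `test_require_cross` test added below; the SparkSession setup is assumed here and is not part of the patch:

```
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

spark = SparkSession.builder.getOrCreate()
df1 = spark.createDataFrame([(1, "1")], ("key", "value"))
df2 = spark.createDataFrame([(1, "1")], ("key", "value"))

# A join with no condition is a cartesian product and is rejected with an AnalysisException.
try:
    df1.join(df2).collect()
except AnalysisException as e:
    print("join without a condition failed: %s" % e)

# The cartesian product must be requested explicitly.
print(df1.crossJoin(df2).count())  # 1
```

This matches the Scala side, where an unconditioned join requires an explicit crossJoin (or enabling `spark.sql.crossJoin.enabled`) before it is allowed to execute.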
--- python/pyspark/sql/dataframe.py | 26 +++++++++++++++---- python/pyspark/sql/tests.py | 15 ++++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7606ac08bae67..29710acf54c4f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -650,6 +650,25 @@ def alias(self, alias): assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + @ignore_unicode_prefix + @since(2.1) + def crossJoin(self, other): + """Returns the cartesian product with another :class:`DataFrame`. + + :param other: Right side of the cartesian product. + + >>> df.select("age", "name").collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df2.select("name", "height").collect() + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)] + >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect() + [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85), + Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)] + """ + + jdf = self._jdf.crossJoin(other._jdf) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def join(self, other, on=None, how=None): @@ -690,14 +709,11 @@ def join(self, other, on=None, how=None): on = self._jseq(on) else: assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] + on = reduce(lambda x, y: x.__and__(y), on) on = on._jc if on is None and how is None: - jdf = self._jdf.crossJoin(other._jdf) + jdf = self._jdf.join(other._jdf) else: if how is None: how = "inner" diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 51d5e7ab0568e..3d46b852c52e1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1466,7 +1466,7 @@ def test_functions_broadcast(self): self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) # no join key -- should not be a broadcast join - plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) # planner should not crash without a join @@ -1514,6 +1514,19 @@ def test_invalid_join_method(self): df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + # Cartesian products require cross join syntax + def test_require_cross(self): + from pyspark.sql.functions import broadcast + + df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + + # joins without conditions require cross join syntax + self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) + + # works with crossJoin + self.assertEqual(1, df1.crossJoin(df2).count()) + def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 70c9cf5ae2440..7ae3275245c5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -774,7 +774,7 @@ class Dataset[T] 
private[sql]( * @param right Right side of the join operation. * * @group untypedrel - * @since 2.0.0 + * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Cross, None) From 6ce1b675ee9fc9a6034439c3ca00441f9f172f84 Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Fri, 14 Oct 2016 18:26:18 -0700 Subject: [PATCH 300/306] [SPARK-16980][SQL] Load only catalog table partition metadata required to answer a query (This PR addresses https://issues.apache.org/jira/browse/SPARK-16980.) ## What changes were proposed in this pull request? In a new Spark session, when a partitioned Hive table is converted to use Spark's `HadoopFsRelation` in `HiveMetastoreCatalog`, metadata for every partition of that table are retrieved from the metastore and loaded into driver memory. In addition, every partition's metadata files are read from the filesystem to perform schema inference. If a user queries such a table with predicates which prune that table's partitions, we would like to be able to answer that query without consulting partition metadata which are not involved in the query. When querying a table with a large number of partitions for some data from a small number of partitions (maybe even a single partition), the current conversion strategy is highly inefficient. I suspect this scenario is not uncommon in the wild. In addition to being inefficient in running time, the current strategy is inefficient in its use of driver memory. When the sum of the number of partitions of all tables loaded in a driver reaches a certain level (somewhere in the tens of thousands), their cached data exhaust all driver heap memory in the default configuration. I suspect this scenario is less common (in that not too many deployments work with tables with tens of thousands of partitions), however this does illustrate how large the memory footprint of this metadata can be. With tables with hundreds or thousands of partitions, I would expect the `HiveMetastoreCatalog` table cache to represent a significant portion of the driver's heap space. This PR proposes an alternative approach. Basically, it makes four changes: 1. It adds a new method, `listPartitionsByFilter` to the Catalyst `ExternalCatalog` trait which returns the partition metadata for a given sequence of partition pruning predicates. 1. It refactors the `FileCatalog` type hierarchy to include a new `TableFileCatalog` to efficiently return files only for partitions matching a sequence of partition pruning predicates. 1. It removes partition loading and caching from `HiveMetastoreCatalog`. 1. It adds a new Catalyst optimizer rule, `PruneFileSourcePartitions`, which applies a plan's partition-pruning predicates to prune out unnecessary partition files from a `HadoopFsRelation`'s underlying file catalog. The net effect is that when a query over a partitioned Hive table is planned, the analyzer retrieves the table metadata from `HiveMetastoreCatalog`. As part of this operation, the `HiveMetastoreCatalog` builds a `HadoopFsRelation` with a `TableFileCatalog`. It does not load any partition metadata or scan any files. The optimizer prunes-away unnecessary table partitions by sending the partition-pruning predicates to the relation's `TableFileCatalog `. The `TableFileCatalog` in turn calls the `listPartitionsByFilter` method on its external catalog. This queries the Hive metastore, passing along those filters. 
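To make the intended effect concrete, here is a hedged sketch from the user's side; the table, column, and values are hypothetical, a Hive-enabled session is assumed, and this snippet is not part of the patch:

```
from pyspark.sql import SparkSession

# Hypothetical partitioned table; assumes a Spark build with Hive support.
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
spark.sql("CREATE TABLE IF NOT EXISTS events (id BIGINT) PARTITIONED BY (ds STRING) STORED AS PARQUET")
spark.sql("INSERT INTO events PARTITION (ds = '2016-10-14') SELECT 1L")
spark.sql("INSERT INTO events PARTITION (ds = '2016-10-15') SELECT 2L")

# Planning this query should only ask the metastore for partitions matching ds = '2016-10-14',
# rather than loading metadata for every partition of the table up front; the scan's
# PartitionFilters/PartitionCount metadata should reflect the pruned set.
spark.table("events").filter("ds = '2016-10-14'").explain()
```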
As a bonus, performing partition pruning during optimization leads to a more accurate relation size estimate. This, along with c481bdf, can lead to automatic, safe application of the broadcast optimization in a join where it might previously have been omitted. ## Open Issues 1. This PR omits partition metadata caching. I can add this once the overall strategy for the cold path is established, perhaps in a future PR. 1. This PR removes and omits partitioned Hive table schema reconciliation. As a result, it fails to find Parquet schema columns with upper case letters because of the Hive metastore's case-insensitivity. This issue may be fixed by #14750, but that PR appears to have stalled. ericl has contributed to this PR a workaround for Parquet wherein schema reconciliation occurs at query execution time instead of planning. Whether ORC requires a similar patch is an open issue. 1. This PR omits an implementation of `listPartitionsByFilter` for the `InMemoryCatalog`. 1. This PR breaks parquet log output redirection during query execution. I can work around this by running `Class.forName("org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$")` first thing in a Spark shell session, but I haven't figured out how to fix this properly. ## How was this patch tested? The current Spark unit tests were run, and some ad-hoc tests were performed to validate that only the necessary partition metadata is loaded. Author: Michael Allman Author: Eric Liang Author: Eric Liang Closes #14690 from mallman/spark-16980-lazy_partition_fetching. --- .../spark/metrics/source/StaticSources.scala | 34 ++- .../catalyst/catalog/ExternalCatalog.scala | 5 +- .../catalyst/catalog/InMemoryCatalog.scala | 4 +- .../sql/catalyst/catalog/interface.scala | 15 +- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/CacheManager.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 28 ++- .../spark/sql/execution/SparkOptimizer.scala | 2 + .../command/createDataSourceTables.scala | 2 +- .../execution/datasources/DataSource.scala | 4 +- .../datasources/DataSourceStrategy.scala | 8 +- .../execution/datasources/FileFormat.scala | 46 +++- .../datasources/HadoopFsRelation.scala | 16 +- .../datasources/ListingFileCatalog.scala | 197 +-------------- .../datasources/LogicalRelation.scala | 2 +- .../PartitioningAwareFileCatalog.scala | 24 +- .../PruneFileSourcePartitions.scala | 72 ++++++ .../datasources/SessionFileCatalog.scala | 225 ++++++++++++++++++ .../datasources/TableFileCatalog.scala | 113 +++++++++ .../parquet/ParquetReadSupport.scala | 6 +- .../streaming/MetadataLogFileCatalog.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../datasources/FileCatalogSuite.scala | 5 +- .../datasources/FileSourceStrategySuite.scala | 2 +- ...te.scala => SessionFileCatalogSuite.scala} | 16 +- .../ParquetPartitionDiscoverySuite.scala | 6 +- .../parquet/ParquetSchemaSuite.scala | 28 +++ .../spark/sql/hive/HiveExternalCatalog.scala | 37 ++- .../spark/sql/hive/HiveMetastoreCatalog.scala | 126 ++++------ .../spark/sql/hive/client/HiveClient.scala | 15 +- .../sql/hive/client/HiveClientImpl.scala | 19 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 12 +- .../spark/sql/hive/HiveDataFrameSuite.scala | 109 ++++++++- .../sql/hive/HiveMetadataCacheSuite.scala | 41 ++++ .../spark/sql/hive/client/VersionsSuite.scala | 4 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 22 ++ .../apache/spark/sql/hive/parquetSuites.scala | 20 +- 37 files changed, 914 insertions(+), 368 deletions(-) create mode 100644 
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalog.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/{ListingFileCatalogSuite.scala => SessionFileCatalogSuite.scala} (66%) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 6bba259acc391..cf92a10deabd5 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -26,7 +26,7 @@ private[spark] object StaticSources { * The set of all static sources. These sources may be reported to from any class, including * static classes, without requiring reference to a SparkEnv. */ - val allSources = Seq(CodegenMetrics) + val allSources = Seq(CodegenMetrics, HiveCatalogMetrics) } /** @@ -60,3 +60,35 @@ object CodegenMetrics extends Source { val METRIC_GENERATED_METHOD_BYTECODE_SIZE = metricRegistry.histogram(MetricRegistry.name("generatedMethodSize")) } + +/** + * :: Experimental :: + * Metrics for access to the hive external catalog. + */ +@Experimental +object HiveCatalogMetrics extends Source { + override val sourceName: String = "HiveExternalCatalog" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + /** + * Tracks the total number of partition metadata entries fetched via the client api. + */ + val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched")) + + /** + * Tracks the total number of files discovered off of the filesystem by ListingFileCatalog. + */ + val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered")) + + /** + * Resets the values of all metrics to zero. This is useful in tests. + */ + def reset(): Unit = { + METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) + METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) + } + + // clients can use these to avoid classloader issues with the codahale classes + def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) + def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 348d3d0be2152..a5e02523d2889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -198,11 +198,12 @@ abstract class ExternalCatalog { partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] /** - * List the metadata of selected partitions according to the given partition predicates. + * List the metadata of partitions that belong to the specified table, assuming it exists, that + * satisfy the given partition-pruning predicate expressions. 
* * @param db database name * @param table table name - * @param predicates partition predicated + * @param predicates partition-pruning predicates */ def listPartitionsByFilter( db: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 49280f82e20be..f95c9f8cfa2d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -482,7 +482,9 @@ class InMemoryCatalog( db: String, table: String, predicates: Seq[Expression]): Seq[CatalogTablePartition] = { - throw new UnsupportedOperationException("listPartitionsByFilter is not implemented.") + // TODO: Provide an implementation + throw new UnsupportedOperationException( + "listPartitionsByFilter is not implemented for InMemoryCatalog") } // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 51326ca25e9cc..1a57a7707caa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.catalog import java.util.Date import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** @@ -97,6 +97,15 @@ case class CatalogTablePartition( output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") } + + /** + * Given the partition schema, returns a row with that schema holding the partition values. 
+ */ + def toRow(partitionSchema: StructType): InternalRow = { + InternalRow.fromSeq(partitionSchema.map { case StructField(name, dataType, _, _) => + Cast(Literal(spec(name)), dataType).eval() + }) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7ae3275245c5d..7dccbbd3f0a5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{FileCatalog, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery} @@ -2614,7 +2614,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def inputFiles: Array[String] = { - val files: Seq[String] = logicalPlan.collect { + val files: Seq[String] = queryExecution.optimizedPlan.collect { case LogicalRelation(fsBasedRelation: FileRelation, _, _) => fsBasedRelation.inputFiles case fr: FileRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 83b7c779ab818..92fd366e101fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -185,7 +185,7 @@ class CacheManager extends Logging { plan match { case lr: LogicalRelation => lr.relation match { case hr: HadoopFsRelation => - val invalidate = hr.location.paths + val invalidate = hr.location.rootPaths .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory)) .contains(qualifiedPath) if (invalidate) hr.location.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 6cdba406937de..623d2be55dcec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -225,13 +225,27 @@ case class FileSourceScanExec( } // These metadata values make scan plans uniquely identifiable for equality checking. 
- override val metadata: Map[String, String] = Map( - "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, - "Batched" -> supportsBatch.toString, - "PartitionFilters" -> partitionFilters.mkString("[", ", ", "]"), - "PushedFilters" -> dataFilters.mkString("[", ", ", "]"), - "InputPaths" -> relation.location.paths.mkString(", ")) + override val metadata: Map[String, String] = { + def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") + val location = relation.location + val locationDesc = + location.getClass.getSimpleName + seqToString(location.rootPaths) + val metadata = + Map( + "Format" -> relation.fileFormat.toString, + "ReadSchema" -> outputSchema.catalogString, + "Batched" -> supportsBatch.toString, + "PartitionFilters" -> seqToString(partitionFilters), + "PushedFilters" -> seqToString(dataFilters), + "Location" -> locationDesc) + val withOptPartitionCount = + relation.partitionSchemaOption.map { _ => + metadata + ("PartitionCount" -> selectedPartitions.size.toString) + } getOrElse { + metadata + } + withOptPartitionCount + } private lazy val inputRDD: RDD[InternalRow] = { val readFile: (PartitionedFile) => Iterator[InternalRow] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8b762b5d6c5f2..981728331d361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate import org.apache.spark.sql.internal.SQLConf @@ -32,5 +33,6 @@ class SparkOptimizer( override def batches: Seq[Batch] = super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index a04a13e698c43..a8c75a7f29cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -67,7 +67,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo dataSource match { case fs: HadoopFsRelation => - if (table.tableType == CatalogTableType.EXTERNAL && fs.location.paths.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL && fs.location.rootPaths.isEmpty) { throw new AnalysisException( "Cannot create a file-based external data source table without path") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index e75e7d2770b4e..92b1fff7d8127 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -471,9 +471,7 @@ case class DataSource( val existingPartitionColumns = Try { resolveRelation() .asInstanceOf[HadoopFsRelation] - .location - .partitionSpec() - .partitionColumns + .partitionSchema .fieldNames .toSeq }.getOrElse(Seq.empty[String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6f9ed50a02b09..7d0abe86a44df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -163,14 +163,14 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { if query.resolved && t.schema.asNullable == query.schema.asNullable => // Sanity checks - if (t.location.paths.size != 1) { + if (t.location.rootPaths.size != 1) { throw new AnalysisException( "Can only write data to relations with a single path.") } - val outputPath = t.location.paths.head + val outputPath = t.location.rootPaths.head val inputPaths = query.collect { - case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.paths + case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append @@ -184,7 +184,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, t.fileFormat, - () => t.refresh(), + () => t.location.refresh(), t.options, query, mode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index bde2d2b89d56f..e7239ef91b326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType + /** * Used to read and write data stored in files to/from the [[InternalRow]] format. */ @@ -182,16 +183,17 @@ abstract class TextBasedFileFormat extends FileFormat { case class Partition(values: InternalRow, files: Seq[FileStatus]) /** - * An interface for objects capable of enumerating the files that comprise a relation as well - * as the partitioning characteristics of those files. + * An interface for objects capable of enumerating the root paths of a relation as well as the + * partitions of a relation subject to some pruning expressions. */ -trait FileCatalog { - - /** Returns the list of input paths from which the catalog will get files. */ - def paths: Seq[Path] +trait BasicFileCatalog { - /** Returns the specification of the partitions inferred from the data. */ - def partitionSpec(): PartitionSpec + /** + * Returns the list of root input paths from which the catalog will get files. There may be a + * single root path from which partitions are discovered, or individual partitions may be + * specified by each path. 
+ */ + def rootPaths: Seq[Path] /** * Returns all valid files grouped into partitions when the data is partitioned. If the data is @@ -204,9 +206,33 @@ trait FileCatalog { */ def listFiles(filters: Seq[Expression]): Seq[Partition] + /** Returns the list of files that will be read when scanning this relation. */ + def inputFiles: Array[String] + + /** Refresh any cached file listings */ + def refresh(): Unit + + /** Sum of table file sizes, in bytes */ + def sizeInBytes: Long +} + +/** + * A [[BasicFileCatalog]] which can enumerate all of the files comprising a relation and, from + * those, infer the relation's partition specification. + */ +// TODO: Consider a more descriptive, appropriate name which suggests this is a file catalog for +// which it is safe to list all of its files? +trait FileCatalog extends BasicFileCatalog { + + /** Returns the specification of the partitions inferred from the data. */ + def partitionSpec(): PartitionSpec + /** Returns all the valid files. */ def allFiles(): Seq[FileStatus] - /** Refresh the file listing */ - def refresh(): Unit + /** Returns the list of files that will be read when scanning this relation. */ + override def inputFiles: Array[String] = + allFiles().map(_.getPath.toUri.toString).toArray + + override def sizeInBytes: Long = allFiles().map(_.getLen).sum } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index c7ebe0b76a150..db889edf032d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.StructType * Acts as a container for all of the metadata required to read from a datasource. All discovery, * resolution and merging logic for schemas and partitions has been removed. * - * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise - * this relation. + * @param location A [[BasicFileCatalog]] that can enumerate the locations of all the files that + * comprise this relation. * @param partitionSchema The schema of the columns (if any) that are used to partition the relation * @param dataSchema The schema of any remaining columns. Note that if any partition columns are * present in the actual data files as well, they are preserved. @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType * @param options Configuration used when reading / writing data. */ case class HadoopFsRelation( - location: FileCatalog, + location: BasicFileCatalog, partitionSchema: StructType, dataSchema: StructType, bucketSpec: Option[BucketSpec], @@ -58,10 +58,6 @@ case class HadoopFsRelation( def partitionSchemaOption: Option[StructType] = if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec() - - def refresh(): Unit = location.refresh() - override def toString: String = { fileFormat match { case source: DataSourceRegister => source.shortName() @@ -69,9 +65,7 @@ case class HadoopFsRelation( } } - /** Returns the list of files that will be read when scanning this relation. 
*/ - override def inputFiles: Array[String] = - location.allFiles().map(_.getPath.toUri.toString).toArray + override def sizeInBytes: Long = location.sizeInBytes - override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala index a68ae523e0faa..6d10501b7265d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala @@ -17,32 +17,26 @@ package org.apache.spark.sql.execution.datasources -import java.io.FileNotFoundException - import scala.collection.mutable -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} -import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration /** * A [[FileCatalog]] that generates the list of files to process by recursively listing all the * files present in `paths`. * + * @param rootPaths the list of root table paths to scan * @param parameters as set of options to control discovery - * @param paths a list of paths to scan * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class ListingFileCatalog( sparkSession: SparkSession, - override val paths: Seq[Path], + override val rootPaths: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType]) extends PartitioningAwareFileCatalog(sparkSession, parameters, partitionSchema) { @@ -70,198 +64,17 @@ class ListingFileCatalog( } override def refresh(): Unit = { - val files = listLeafFiles(paths) + val files = listLeafFiles(rootPaths) cachedLeafFiles = new mutable.LinkedHashMap[Path, FileStatus]() ++= files.map(f => f.getPath -> f) cachedLeafDirToChildrenFiles = files.toArray.groupBy(_.getPath.getParent) cachedPartitionSpec = null } - /** - * List leaf files of given paths. This method will submit a Spark job to do parallel - * listing whenever there is a path having more files than the parallel partition discovery - * discovery threshold. - * - * This is publicly visible for testing. - */ - def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - val files = - if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - ListingFileCatalog.listLeafFilesInParallel(paths, hadoopConf, sparkSession) - } else { - ListingFileCatalog.listLeafFilesInSerial(paths, hadoopConf) - } - - mutable.LinkedHashSet(files: _*) - } - override def equals(other: Any): Boolean = other match { - case hdfs: ListingFileCatalog => paths.toSet == hdfs.paths.toSet + case hdfs: ListingFileCatalog => rootPaths.toSet == hdfs.rootPaths.toSet case _ => false } - override def hashCode(): Int = paths.toSet.hashCode() -} - - -object ListingFileCatalog extends Logging { - - /** A serializable variant of HDFS's BlockLocation. */ - private case class SerializableBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - /** A serializable variant of HDFS's FileStatus. 
*/ - private case class SerializableFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[SerializableBlockLocation]) - - /** - * List a collection of path recursively. - */ - private def listLeafFilesInSerial( - paths: Seq[Path], - hadoopConf: Configuration): Seq[FileStatus] = { - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass) - val filter = FileInputFormat.getInputPathFilter(jobConf) - - paths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - listLeafFiles0(fs, path, filter) - } - } - - /** - * List a collection of path recursively in parallel (using Spark executors). - * Each task launched will use [[listLeafFilesInSerial]] to list. - */ - private def listLeafFilesInParallel( - paths: Seq[Path], - hadoopConf: Configuration, - sparkSession: SparkSession): Seq[FileStatus] = { - assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. - val numParallelism = Math.min(paths.size, 10000) - - val statuses = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { paths => - val hadoopConf = serializableConfiguration.value - listLeafFilesInSerial(paths.map(new Path(_)).toSeq, hadoopConf).iterator - }.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - }.collect() - - // Turn SerializableFileStatus back to Status - statuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)), - blockLocations) - } - } - - /** - * List a single path, provided as a FileStatus, in serial. - */ - private def listLeafFiles0( - fs: FileSystem, path: Path, filter: PathFilter): Seq[FileStatus] = { - logTrace(s"Listing $path") - val name = path.getName.toLowerCase - if (shouldFilterOut(name)) { - Seq.empty[FileStatus] - } else { - // [SPARK-17599] Prevent ListingFileCatalog from failing if path doesn't exist - // Note that statuses only include FileStatus for the files and dirs directly under path, - // and does not include anything else recursively. - val statuses = try fs.listStatus(path) catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. 
Was it deleted very recently?") - Array.empty[FileStatus] - } - - val allLeafStatuses = { - val (dirs, files) = statuses.partition(_.isDirectory) - val stats = files ++ dirs.flatMap(dir => listLeafFiles0(fs, dir.getPath, filter)) - if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats - } - - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => - f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), - // which is very slow on some file system (RawLocalFileSystem, which is launch a - // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - } - } - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && - !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") - } + override def hashCode(): Int = rootPaths.toSet.hashCode() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index d9562fd32e87d..7c28d48f26416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -94,7 +94,7 @@ case class LogicalRelation( } override def refresh(): Unit = relation match { - case fs: HadoopFsRelation => fs.refresh() + case fs: HadoopFsRelation => fs.location.refresh() case _ => // Do nothing. 
} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala index 702ba97222e34..b2508115c282f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ @@ -40,9 +39,10 @@ abstract class PartitioningAwareFileCatalog( sparkSession: SparkSession, parameters: Map[String, String], partitionSchema: Option[StructType]) - extends FileCatalog with Logging { + extends SessionFileCatalog(sparkSession) with FileCatalog { + import PartitioningAwareFileCatalog.BASE_PATH_PARAM - protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + override protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] @@ -72,8 +72,8 @@ abstract class PartitioningAwareFileCatalog( override def allFiles(): Seq[FileStatus] = { if (partitionSpec().partitionColumns.isEmpty) { - // For each of the input paths, get the list of files inside them - paths.flatMap { path => + // For each of the root input paths, get the list of files inside them + rootPaths.flatMap { path => // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). val fs = path.getFileSystem(hadoopConf) val qualifiedPathPre = fs.makeQualified(path) @@ -105,8 +105,6 @@ abstract class PartitioningAwareFileCatalog( protected def inferPartitioning(): PartitionSpec = { // We use leaf dirs containing data files to discover the schema. val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => - // SPARK-15895: Metadata files (e.g. Parquet summary files) and temporary files should not be - // counted as data files, so that they shouldn't participate partition discovery. files.exists(f => isDataPath(f.getPath)) }.keys.toSeq partitionSchema match { @@ -194,24 +192,30 @@ abstract class PartitioningAwareFileCatalog( * and the returned DataFrame will have the column of `something`. */ private def basePaths: Set[Path] = { - parameters.get("basePath").map(new Path(_)) match { + parameters.get(BASE_PATH_PARAM).map(new Path(_)) match { case Some(userDefinedBasePath) => val fs = userDefinedBasePath.getFileSystem(hadoopConf) if (!fs.isDirectory(userDefinedBasePath)) { - throw new IllegalArgumentException("Option 'basePath' must be a directory") + throw new IllegalArgumentException(s"Option '$BASE_PATH_PARAM' must be a directory") } Set(fs.makeQualified(userDefinedBasePath)) case None => - paths.map { path => + rootPaths.map { path => // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). val qualifiedPath = path.getFileSystem(hadoopConf).makeQualified(path) if (leafFiles.contains(qualifiedPath)) qualifiedPath.getParent else qualifiedPath }.toSet } } + // SPARK-15895: Metadata files (e.g. Parquet summary files) and temporary files should not be + // counted as data files, so that they shouldn't participate partition discovery. 
private def isDataPath(path: Path): Boolean = { val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } } + +object PartitioningAwareFileCatalog { + val BASE_PATH_PARAM = "basePath" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala new file mode 100644 index 0000000000000..29121a47d92d1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + +private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case op @ PhysicalOperation(projects, filters, + logicalRelation @ + LogicalRelation(fsRelation @ + HadoopFsRelation( + tableFileCatalog: TableFileCatalog, + partitionSchema, + _, + _, + _, + _), + _, + _)) + if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => + // The attribute name of predicate could be different than the one in schema in case of + // case insensitive, we should change them to match the one in schema, so we donot need to + // worry about case sensitivity anymore. 
+ val normalizedFilters = filters.map { e => + e transform { + case a: AttributeReference => + a.withName(logicalRelation.output.find(_.semanticEquals(a)).get.name) + } + } + + val sparkSession = fsRelation.sparkSession + val partitionColumns = + logicalRelation.resolve( + partitionSchema, sparkSession.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + val partitionKeyFilters = + ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + + if (partitionKeyFilters.nonEmpty) { + val prunedFileCatalog = tableFileCatalog.filterPartitions(partitionKeyFilters.toSeq) + val prunedFsRelation = + fsRelation.copy(location = prunedFileCatalog)(sparkSession) + val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) + + // Keep partition-pruning predicates so that they are visible in physical planning + val filterExpression = filters.reduceLeft(And) + val filter = Filter(filterExpression, prunedLogicalRelation) + Project(projects, filter) + } else { + op + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalog.scala new file mode 100644 index 0000000000000..4807a92c2e6b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalog.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.FileNotFoundException + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.SerializableConfiguration + + +/** + * A base class for [[BasicFileCatalog]]s that need a [[SparkSession]] and the ability to find leaf + * files in a list of HDFS paths. + * + * @param sparkSession a [[SparkSession]] + * @param ignoreFileNotFound (see [[ListingFileCatalog]]) + */ +abstract class SessionFileCatalog(sparkSession: SparkSession) + extends BasicFileCatalog with Logging { + protected val hadoopConf: Configuration + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. 
+ */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + val files = + if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + SessionFileCatalog.listLeafFilesInParallel(paths, hadoopConf, sparkSession) + } else { + SessionFileCatalog.listLeafFilesInSerial(paths, hadoopConf) + } + + HiveCatalogMetrics.incrementFilesDiscovered(files.size) + mutable.LinkedHashSet(files: _*) + } +} + +object SessionFileCatalog extends Logging { + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * List a collection of path recursively. + */ + private def listLeafFilesInSerial( + paths: Seq[Path], + hadoopConf: Configuration): Seq[FileStatus] = { + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass) + val filter = FileInputFormat.getInputPathFilter(jobConf) + + paths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + listLeafFiles0(fs, path, filter) + } + } + + /** + * List a collection of path recursively in parallel (using Spark executors). + * Each task launched will use [[listLeafFilesInSerial]] to list. + */ + private def listLeafFilesInParallel( + paths: Seq[Path], + hadoopConf: Configuration, + sparkSession: SparkSession): Seq[FileStatus] = { + assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, 10000) + + val statuses = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + val hadoopConf = serializableConfiguration.value + listLeafFilesInSerial(paths.map(new Path(_)).toSeq, hadoopConf).iterator + }.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + }.collect() + + // Turn SerializableFileStatus back to Status + statuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)), + blockLocations) + } + } + + /** + * List a single path, provided as a FileStatus, in serial. 
+ */ + private def listLeafFiles0( + fs: FileSystem, path: Path, filter: PathFilter): Seq[FileStatus] = { + logTrace(s"Listing $path") + val name = path.getName.toLowerCase + if (shouldFilterOut(name)) { + Seq.empty[FileStatus] + } else { + // [SPARK-17599] Prevent ListingFileCatalog from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. Was it deleted very recently?") + Array.empty[FileStatus] + } + + val allLeafStatuses = { + val (dirs, files) = statuses.partition(_.isDirectory) + val stats = files ++ dirs.flatMap(dir => listLeafFiles0(fs, dir.getPath, filter)) + if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } + } + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && + !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala new file mode 100644 index 0000000000000..a5c41b244589b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType + + +/** + * A [[BasicFileCatalog]] for a metastore catalog table. + * + * @param sparkSession a [[SparkSession]] + * @param db the table's database name + * @param table the table's (unqualified) name + * @param partitionSchema the schema of a partitioned table's partition columns + * @param sizeInBytes the table's data size in bytes + */ +class TableFileCatalog( + sparkSession: SparkSession, + db: String, + table: String, + partitionSchema: Option[StructType], + override val sizeInBytes: Long) + extends SessionFileCatalog(sparkSession) { + + override protected val hadoopConf = sparkSession.sessionState.newHadoopConf + + private val externalCatalog = sparkSession.sharedState.externalCatalog + + private val catalogTable = externalCatalog.getTable(db, table) + + private val baseLocation = catalogTable.storage.locationUri + + override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq + + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + filterPartitions(filters).listFiles(Nil) + } + + override def refresh(): Unit = {} + + /** + * Returns a [[ListingFileCatalog]] for this table restricted to the subset of partitions + * specified by the given partition-pruning filters. + * + * @param filters partition-pruning filters + */ + def filterPartitions(filters: Seq[Expression]): ListingFileCatalog = { + if (filters.isEmpty) { + cachedAllPartitions + } else { + filterPartitions0(filters) + } + } + + private def filterPartitions0(filters: Seq[Expression]): ListingFileCatalog = { + val parameters = baseLocation + .map(loc => Map(PartitioningAwareFileCatalog.BASE_PATH_PARAM -> loc)) + .getOrElse(Map.empty) + partitionSchema match { + case Some(schema) => + val selectedPartitions = externalCatalog.listPartitionsByFilter(db, table, filters) + val partitions = selectedPartitions.map { p => + PartitionDirectory(p.toRow(schema), p.storage.locationUri.get) + } + val partitionSpec = PartitionSpec(schema, partitions) + new PrunedTableFileCatalog( + sparkSession, new Path(baseLocation.get), partitionSpec) + case None => + new ListingFileCatalog(sparkSession, rootPaths, parameters, None) + } + } + + // Not used in the hot path of queries when metastore partition pruning is enabled + lazy val cachedAllPartitions: ListingFileCatalog = filterPartitions0(Nil) + + override def inputFiles: Array[String] = cachedAllPartitions.inputFiles +} + +/** + * An override of the standard HDFS listing based catalog, that overrides the partition spec with + * the information from the metastore. 
+ * + * @param tableBasePath The default base path of the Hive metastore table + * @param partitionSpec The partition specifications from Hive metastore + */ +private class PrunedTableFileCatalog( + sparkSession: SparkSession, + tableBasePath: Path, + override val partitionSpec: PartitionSpec) + extends ListingFileCatalog( + sparkSession, + partitionSpec.partitions.map(_.path), + Map.empty, + Some(partitionSpec.partitionColumns)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..4dea8cf29ec58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -269,11 +269,15 @@ private[parquet] object ParquetReadSupport { */ private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + val parquetFieldMap = parquetRecord.getFields.asScala + .map(f => f.getName -> f).toMap + val caseInsensitiveParquetFieldMap = parquetRecord.getFields.asScala + .map(f => f.getName.toLowerCase -> f).toMap val toParquet = new ParquetSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) + .orElse(caseInsensitiveParquetFieldMap.get(f.name.toLowerCase)) .map(clipParquetType(_, f.dataType)) .getOrElse(toParquet.convertField(f)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala index a32c4671e3475..82b67cb1ca6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala @@ -47,7 +47,7 @@ class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) allFilesFromLog.toArray.groupBy(_.getPath.getParent) } - override def paths: Seq[Path] = path :: Nil + override def rootPaths: Seq[Path] = path :: Nil override def refresh(): Unit = { } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c8447651dd672..e73d0187b584b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -269,6 +269,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val HIVE_FILESOURCE_PARTITION_PRUNING = + SQLConfigBuilder("spark.sql.hive.filesourcePartitionPruning") + .doc("When true, enable metastore partition pruning for file source tables as well. " + + "This is currently implemented for converted Hive tables only.") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = SQLConfigBuilder("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. 
It applies when all the columns " + @@ -676,6 +683,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + def filesourcePartitionPruning: Boolean = getConf(HIVE_FILESOURCE_PARTITION_PRUNING) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala index fa3abd0098f5b..2695974b84b00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -77,13 +77,14 @@ class FileCatalogSuite extends SharedSQLContext { val catalog1 = new ListingFileCatalog( spark, Seq(new Path(deletedFolder.getCanonicalPath)), Map.empty, None) // doesn't throw an exception - assert(catalog1.listLeafFiles(catalog1.paths).isEmpty) + assert(catalog1.listLeafFiles(catalog1.rootPaths).isEmpty) } } test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { class MockCatalog( - override val paths: Seq[Path]) extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + override val rootPaths: Seq[Path]) + extends PartitioningAwareFileCatalog(spark, Map.empty, None) { override def refresh(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c5deb31fec183..c32254d9dfde2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new ListingFileCatalog( sparkSession = spark, - paths = Seq(new Path(tempDir)), + rootPaths = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. 
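
As a usage note for the new `spark.sql.hive.filesourcePartitionPruning` flag added to SQLConf above (default `true`), the setting can be flipped per session. The sketch below is illustrative only and is not part of this patch; the `spark` session variable, the table name, and the partition column are assumed:

```scala
// Illustrative sketch, not part of the patch. Assumes an existing SparkSession
// `spark` with Hive support and a partitioned, Parquet-backed Hive table `sales`
// with a partition column `part` (hypothetical names).
spark.conf.set("spark.sql.hive.filesourcePartitionPruning", "true") // default
// With lazy pruning on, only the partitions matching the predicate are fetched
// from the metastore and listed.
spark.sql("SELECT count(*) FROM sales WHERE part = 3").show()

// Revert to the legacy behavior of eagerly listing every partition when the
// table is resolved.
spark.conf.set("spark.sql.hive.filesourcePartitionPruning", "false")
```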
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalogSuite.scala similarity index 66% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalogSuite.scala index f15730aeb11f2..df509583377ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SessionFileCatalogSuite.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.SparkFunSuite -class ListingFileCatalogSuite extends SparkFunSuite { +class SessionFileCatalogSuite extends SparkFunSuite { test("file filtering") { - assert(!ListingFileCatalog.shouldFilterOut("abcd")) - assert(ListingFileCatalog.shouldFilterOut(".ab")) - assert(ListingFileCatalog.shouldFilterOut("_cd")) + assert(!SessionFileCatalog.shouldFilterOut("abcd")) + assert(SessionFileCatalog.shouldFilterOut(".ab")) + assert(SessionFileCatalog.shouldFilterOut("_cd")) - assert(!ListingFileCatalog.shouldFilterOut("_metadata")) - assert(!ListingFileCatalog.shouldFilterOut("_common_metadata")) - assert(ListingFileCatalog.shouldFilterOut("_ab_metadata")) - assert(ListingFileCatalog.shouldFilterOut("_cd_common_metadata")) + assert(!SessionFileCatalog.shouldFilterOut("_metadata")) + assert(!SessionFileCatalog.shouldFilterOut("_common_metadata")) + assert(SessionFileCatalog.shouldFilterOut("_ab_metadata")) + assert(SessionFileCatalog.shouldFilterOut("_cd_common_metadata")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 8d18be9300f7e..43357c97c395a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -30,7 +30,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{FileCatalog, HadoopFsRelation, LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -626,8 +626,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation, _, _) => - assert(relation.partitionSpec === PartitionSpec.emptySpec) + case LogicalRelation(HadoopFsRelation(location: FileCatalog, _, _, _, _, _), _, _) => + assert(location.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a 
ParquetRelation2, but got:\n$queryExecution") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538f..c3d202ced24c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -1080,6 +1080,34 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + testSchemaClipping( + "falls back to case insensitive resolution", + + parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin) + testSchemaClipping( "simple nested struct", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b5d93c3d7c804..ff59b54f53909 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -29,17 +29,17 @@ import org.apache.thrift.TException import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, StructField, StructType} /** @@ -650,8 +650,35 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { - client.getPartitionsByFilter(db, table, predicates) + predicates: Seq[Expression]): Seq[CatalogTablePartition] = withClient { + val catalogTable = client.getTable(db, table) + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + val nonPartitionPruningPredicates = predicates.filterNot { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (nonPartitionPruningPredicates.nonEmpty) { + sys.error("Expected only partition pruning predicates: " + + predicates.reduceLeft(And)) + } + + val partitionSchema = catalogTable.partitionSchema + + if (predicates.nonEmpty) { + val clientPrunedPartitions = + client.getPartitionsByFilter(catalogTable, predicates) + val boundPredicate = + 
InterpretedPredicate.create(predicates.reduce(And).transform { + case att: AttributeReference => + val index = partitionSchema.indexWhere(_.name == att.name) + BoundReference(index, partitionSchema(index).dataType, nullable = true) + }) + clientPrunedPartitions.filter { case p: CatalogTablePartition => + boundPredicate(p.toRow(partitionSchema)) + } + } else { + client.getPartitions(catalogTable) + } } // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c44f0adda44c0..4a2aaa7d4f6ca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -135,12 +135,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def getCached( tableIdentifier: QualifiedTableName, - pathsInMetastore: Seq[String], + pathsInMetastore: Seq[Path], metastoreRelation: MetastoreRelation, schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], expectedBucketSpec: Option[BucketSpec], - partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + partitionSchema: Option[StructType]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss @@ -152,12 +152,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // If we have the same paths, same schema, and same partition spec, // we will use the cached relation. val useCached = - relation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && + relation.location.rootPaths.toSet == pathsInMetastore.toSet && logical.schema.sameType(schemaInMetastore) && relation.bucketSpec == expectedBucketSpec && - relation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[PartitionDirectory]) - } + relation.partitionSchema == partitionSchema.getOrElse(StructType(Nil)) if (useCached) { Some(logical) @@ -196,61 +194,59 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. + val lazyPruningEnabled = sparkSession.sqlContext.conf.filesourcePartitionPruning val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) - val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into HadoopFsRelation, so predicates to Hive metastore - // are empty. - val partitions = metastoreRelation.getHiveQlPartitions().map { p => - val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { - case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) - }) - PartitionDirectory(values, location) - } - val partitionSpec = PartitionSpec(partitionSchema, partitions) - val partitionPaths = partitions.map(_.path.toString) - - // By convention (for example, see MetaStorePartitionedTableFileCatalog), the definition of a - // partitioned table's paths depends on whether that table has any actual partitions. 
- // Partitioned tables without partitions use the location of the table's base path. - // Partitioned tables with partitions use the locations of those partitions' data locations, - // _omitting_ the table's base path. - val paths = if (partitionPaths.isEmpty) { - Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + + val rootPaths: Seq[Path] = if (lazyPruningEnabled) { + Seq(metastoreRelation.hiveQlTable.getDataLocation) } else { - partitionPaths + // By convention (for example, see TableFileCatalog), the definition of a + // partitioned table's paths depends on whether that table has any actual partitions. + // Partitioned tables without partitions use the location of the table's base path. + // Partitioned tables with partitions use the locations of those partitions' data + // locations,_omitting_ the table's base path. + val paths = metastoreRelation.getHiveQlPartitions().map { p => + new Path(p.getLocation) + } + if (paths.isEmpty) { + Seq(metastoreRelation.hiveQlTable.getDataLocation) + } else { + paths + } } val cached = getCached( tableIdentifier, - paths, + rootPaths, metastoreRelation, metastoreSchema, fileFormatClass, bucketSpec, - Some(partitionSpec)) - - val hadoopFsRelation = cached.getOrElse { - val fileCatalog = new MetaStorePartitionedTableFileCatalog( - sparkSession, - new Path(metastoreRelation.catalogTable.storage.locationUri.get), - partitionSpec) - - val inferredSchema = if (fileType.equals("parquet")) { - val inferredSchema = - defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()) - inferredSchema.map { inferred => - ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, inferred) - }.getOrElse(metastoreSchema) - } else { - defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get + Some(partitionSchema)) + + val logicalRelation = cached.getOrElse { + val db = metastoreRelation.databaseName + val table = metastoreRelation.tableName + val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong + val fileCatalog = { + val catalog = new TableFileCatalog( + sparkSession, db, table, Some(partitionSchema), sizeInBytes) + if (lazyPruningEnabled) { + catalog + } else { + catalog.cachedAllPartitions + } } + val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet + val dataSchema = + StructType(metastoreSchema + .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase))) val relation = HadoopFsRelation( location = fileCatalog, partitionSchema = partitionSchema, - dataSchema = inferredSchema, + dataSchema = dataSchema, bucketSpec = bucketSpec, fileFormat = defaultSource, options = options)(sparkSession = sparkSession) @@ -260,12 +256,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log created } - hadoopFsRelation + logicalRelation } else { - val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + val rootPath = metastoreRelation.hiveQlTable.getDataLocation val cached = getCached(tableIdentifier, - paths, + Seq(rootPath), metastoreRelation, metastoreSchema, fileFormatClass, @@ -276,14 +272,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log LogicalRelation( DataSource( sparkSession = sparkSession, - paths = paths, + paths = rootPath.toString :: Nil, userSpecifiedSchema = Some(metastoreRelation.schema), bucketSpec = bucketSpec, options = options, className = fileType).resolveRelation(), catalogTable = Some(metastoreRelation.catalogTable)) - cachedDataSourceTables.put(tableIdentifier, 
created) created } @@ -371,34 +366,3 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } } - -/** - * An override of the standard HDFS listing based catalog, that overrides the partition spec with - * the information from the metastore. - * - * @param tableBasePath The default base path of the Hive metastore table - * @param partitionSpec The partition specifications from Hive metastore - */ -private[hive] class MetaStorePartitionedTableFileCatalog( - sparkSession: SparkSession, - tableBasePath: Path, - override val partitionSpec: PartitionSpec) - extends ListingFileCatalog( - sparkSession, - MetaStorePartitionedTableFileCatalog.getPaths(tableBasePath, partitionSpec), - Map.empty, - Some(partitionSpec.partitionColumns)) { -} - -private[hive] object MetaStorePartitionedTableFileCatalog { - /** Get the list of paths to list files in the for a metastore table */ - def getPaths(tableBasePath: Path, partitionSpec: PartitionSpec): Seq[Path] = { - // If there are no partitions currently specified then use base path, - // otherwise use the paths corresponding to the partitions. - if (partitionSpec.partitions.isEmpty) { - Seq(tableBasePath) - } else { - partitionSpec.partitions.map(_.path) - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 9ee3d629c9977..569a9c11398ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -172,15 +172,24 @@ private[hive] trait HiveClient { * Returns the partitions for the given table that match the supplied partition spec. * If no partition spec is specified, all partitions are returned. */ - def getPartitions( + final def getPartitions( db: String, table: String, + partialSpec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = { + getPartitions(getTable(db, table), partialSpec) + } + + /** + * Returns the partitions for the given table that match the supplied partition spec. + * If no partition spec is specified, all partitions are returned. + */ + def getPartitions( + catalogTable: CatalogTable, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] /** Returns partitions filtered by predicates for the given table. */ def getPartitionsByFilter( - db: String, - table: String, + catalogTable: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] /** Loads a static partition into an existing table. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 5c8f7ff1af9fa..e745a8c5b3589 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -37,6 +37,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} @@ -525,22 +526,24 @@ private[hive] class HiveClientImpl( * If no partition spec is specified, all partitions are returned. 
*/ override def getPartitions( - db: String, - table: String, + table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(getTable(db, table)) - spec match { + val hiveTable = toHiveTable(table) + val parts = spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) } + HiveCatalogMetrics.incrementFetchedPartitions(parts.length) + parts } override def getPartitionsByFilter( - db: String, - table: String, + table: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(getTable(db, table)) - shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + val hiveTable = toHiveTable(table) + val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + HiveCatalogMetrics.incrementFetchedPartitions(parts.length) + parts } override def listTables(dbName: String): Seq[String] = withHiveState { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index e94f49ea81177..1af3280e18a89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -313,7 +313,17 @@ private[orc] object OrcRelation extends HiveInspectors { def setRequiredColumns( conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + val caseInsensitiveFieldMap: Map[String, Int] = physicalSchema.fieldNames + .zipWithIndex + .map(f => (f._1.toLowerCase, f._2)) + .toMap + val ids = requestedSchema.map { a => + val exactMatch: Option[Int] = physicalSchema.getFieldIndex(a.name) + val res = exactMatch.getOrElse( + caseInsensitiveFieldMap.getOrElse(a.name, + throw new IllegalArgumentException(s"""Field "$a.name" does not exist."""))) + res: Integer + } val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala index 96e9054cd4876..f65e74de87a57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.hive +import java.io.File + +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.QueryTest -class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("table name with schema") { // regression test for SPARK-11778 spark.sql("create schema usrdb") @@ -34,4 +38,107 @@ class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client assert(hiveClient.getConf("hive.in.test", "") == "true") } + + private def setupPartitionedTable(tableName: String, dir: File): Unit = { + 
spark.range(5).selectExpr("id", "id as partCol1", "id as partCol2").write + .partitionBy("partCol1", "partCol2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create external table $tableName (id long) + |partitioned by (partCol1 int, partCol2 int) + |stored as parquet + |location "${dir.getAbsolutePath}"""".stripMargin) + spark.sql(s"msck repair table $tableName") + } + + test("partitioned pruned table reports only selected files") { + assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") + withTable("test") { + withTempDir { dir => + setupPartitionedTable("test", dir) + val df = spark.sql("select * from test") + assert(df.count() == 5) + assert(df.inputFiles.length == 5) // unpruned + + val df2 = spark.sql("select * from test where partCol1 = 3 or partCol2 = 4") + assert(df2.count() == 2) + assert(df2.inputFiles.length == 2) // pruned, so we have less files + + val df3 = spark.sql("select * from test where PARTCOL1 = 3 or partcol2 = 4") + assert(df3.count() == 2) + assert(df3.inputFiles.length == 2) + + val df4 = spark.sql("select * from test where partCol1 = 999") + assert(df4.count() == 0) + assert(df4.inputFiles.length == 0) + } + } + } + + test("lazy partition pruning reads only necessary partition data") { + withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedTable("test", dir) + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 = 999").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 < 2").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 < 3").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 3) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 3) + + // should read all + HiveCatalogMetrics.reset() + spark.sql("select * from test").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + // read all should be cached + HiveCatalogMetrics.reset() + spark.sql("select * from test").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + } + } + } + } + + test("all partitions read and cached when filesource partition pruning is off") { + withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedTable("test", dir) + + // We actually query the partitions from hive each time the table is resolved in this + // mode. This is kind of terrible, but is needed to preserve the legacy behavior + // of doing plan cache validation based on the entire partition set. 
+ HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 = 999").count() + // 5 from table resolution, another 5 from ListingFileCatalog + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 10) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + HiveCatalogMetrics.reset() + spark.sql("select * from test where partCol1 < 2").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + spark.sql("select * from test").count() + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 5) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 3414f5e0409a1..7af81a3a90504 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -59,4 +59,45 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi } } } + + def testCaching(pruningEnabled: Boolean): Unit = { + test(s"partitioned table is cached when partition pruning is $pruningEnabled") { + withSQLConf("spark.sql.hive.filesourcePartitionPruning" -> pruningEnabled.toString) { + withTable("test") { + withTempDir { dir => + spark.range(5).selectExpr("id", "id as f1", "id as f2").write + .partitionBy("f1", "f2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create external table test (id long) + |partitioned by (f1 int, f2 int) + |stored as parquet + |location "${dir.getAbsolutePath}"""".stripMargin) + spark.sql("msck repair table test") + + val df = spark.sql("select * from test") + assert(sql("select * from test").count() == 5) + + // Delete a file, then assert that we tried to read it. This means the table was cached. + val p = new Path(spark.table("test").inputFiles.head) + assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, true)) + val e = intercept[SparkException] { + sql("select * from test").count() + } + assert(e.getMessage.contains("FileNotFoundException")) + + // Test refreshing the cache. 
+ spark.catalog.refreshTable("test") + assert(sql("select * from test").count() == 4) + } + } + } + } + } + + for (pruningEnabled <- Seq(true, false)) { + testCaching(pruningEnabled) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c158bf1ab09cb..9a10957c8efa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -295,12 +295,12 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: getPartitions(catalogTable)") { - assert(2 == client.getPartitions("default", "src_part").size) + assert(2 == client.getPartitions(client.getTable("default", "src_part")).size) } test(s"$version: getPartitionsByFilter") { // Only one partition [1, 1] for key2 == 1 - val result = client.getPartitionsByFilter("default", "src_part", + val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b2ee49c441ef2..ecb5972984523 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -474,6 +474,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("converted ORC table supports resolving mixed case field") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { + withTable("dummy_orc") { + withTempPath { dir => + val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue") + df.write + .partitionBy("partitionValue") + .mode("overwrite") + .orc(dir.getAbsolutePath) + + spark.sql(s""" + |create external table dummy_orc (id long, valueField long) + |partitioned by (partitionValue int) + |stored as orc + |location "${dir.getAbsolutePath}"""".stripMargin) + spark.sql(s"msck repair table dummy_orc") + checkAnswer(spark.sql("select * from dummy_orc"), df) + } + } + } + } + test("SPARK-14962 Produce correct results on array type with isnotnull") { withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { val data = (0 until 10).map(i => Tuple1(Array(i))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 2f6d9fb96b825..9fc62a389db4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -175,7 +175,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a") .createOrReplaceTempView("jt_array") - setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true) + assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") } override def afterAll(): Unit = { @@ -187,7 +187,6 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { "jt", "jt_array", "test_parquet") - setConf(HiveUtils.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { @@ -586,6 +585,23 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest 
{
         checkAnswer(
           sql("SELECT * FROM test_added_partitions"),
           Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b"))
+
+        // Check it with pruning predicates
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 0"),
+          Seq(("foo", 0), ("bar", 0)).toDF("a", "b"))
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 1"),
+          Seq(("baz", 1)).toDF("a", "b"))
+        checkAnswer(
+          sql("SELECT * FROM test_added_partitions where b = 2"),
+          Seq[(String, Int)]().toDF("a", "b"))
+
+        // Also verify the inputFiles implementation
+        assert(sql("select * from test_added_partitions").inputFiles.length == 2)
+        assert(sql("select * from test_added_partitions where b = 0").inputFiles.length == 1)
+        assert(sql("select * from test_added_partitions where b = 1").inputFiles.length == 1)
+        assert(sql("select * from test_added_partitions where b = 2").inputFiles.length == 0)
       }
     }
   }

From 36d81c2c68ef4114592b069287743eb5cb078318 Mon Sep 17 00:00:00 2001
From: Jun Kim
Date: Sat, 15 Oct 2016 00:36:55 -0700
Subject: [PATCH 301/306] [SPARK-17953][DOCUMENTATION] Fix typo in SparkSession scaladoc

## What changes were proposed in this pull request?

### Before:
```scala
SparkSession.builder()
  .master("local")
  .appName("Word Count")
  .config("spark.some.config.option", "some-value").
  .getOrCreate()
```

### After:
```scala
SparkSession.builder()
  .master("local")
  .appName("Word Count")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()
```

There was one unexpected dot!

Author: Jun Kim

Closes #15498 from tae-jun/SPARK-17953.
---
 sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 137c426b4b88d..baae55013787d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -64,7 +64,7 @@ import org.apache.spark.util.Utils
  * SparkSession.builder()
  *   .master("local")
  *   .appName("Word Count")
- *   .config("spark.some.config.option", "some-value").
+ *   .config("spark.some.config.option", "some-value")
  *   .getOrCreate()
  * }}}
  */

From ed1463341455830b8867b721a1b34f291139baf3 Mon Sep 17 00:00:00 2001
From: Zhan Zhang
Date: Sat, 15 Oct 2016 18:45:04 -0700
Subject: [PATCH 302/306] [SPARK-17637][SCHEDULER] Packed scheduling for Spark tasks across executors

## What changes were proposed in this pull request?

Restructure the code and implement two new task assigners.

PackedAssigner: tries to allocate tasks to the executors with the fewest available cores, so that Spark can release reserved executors when dynamic allocation is enabled.

BalancedAssigner: tries to allocate tasks to the executors with the most available cores, in order to balance the workload across all executors.

By default, the original round-robin assigner is used.

We tested a pipeline, and the new PackedAssigner saved around 45% of the reserved CPU and memory with dynamic allocation enabled.

## How was this patch tested?

Both unit tests in TaskSchedulerImplSuite and manual tests in a production pipeline.

Author: Zhan Zhang

Closes #15218 from zhzhan/packed-scheduler.
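
For context, the assigner is selected through the `spark.task.assigner` configuration read by `TaskSchedulerImpl` in this patch, falling back to the round-robin assigner if the configured class cannot be constructed. A minimal sketch of opting into packed scheduling follows; the config key and assigner class names come from the diff below, while the application setup around them is illustrative and not part of the patch:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext

// Sketch only: choose the packed assigner added by this patch. The master URL is
// normally supplied by spark-submit; the dynamic allocation settings are shown
// because packing is most useful when idle executors can be released.
val conf = new SparkConf()
  .setAppName("packed-scheduling-example")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true")
  // Alternatives: org.apache.spark.scheduler.BalancedAssigner, or leave unset
  // to keep the default RoundRobinAssigner.
  .set("spark.task.assigner", "org.apache.spark.scheduler.PackedAssigner")

val sc = new SparkContext(conf)
```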
--- .../apache/spark/scheduler/TaskAssigner.scala | 154 ++++++++++++++++++ .../spark/scheduler/TaskSchedulerImpl.scala | 53 +++--- .../scheduler/TaskSchedulerImplSuite.scala | 67 ++++++++ docs/configuration.md | 11 ++ 4 files changed, 266 insertions(+), 19 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala new file mode 100644 index 0000000000000..62df9657a6ac6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.PriorityQueue +import scala.util.Random + +import org.apache.spark.SparkConf + +case class OfferState(workOffer: WorkerOffer, var cores: Int) { + // Build a list of tasks to assign to each worker. + val tasks = new ArrayBuffer[TaskDescription](cores) +} + +abstract class TaskAssigner(conf: SparkConf) { + var offer: Seq[OfferState] = _ + val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) + + // The final assigned offer returned to TaskScheduler. + def tasks(): Seq[ArrayBuffer[TaskDescription]] = offer.map(_.tasks) + + // construct the assigner by the workoffer. + def construct(workOffer: Seq[WorkerOffer]): Unit = { + offer = workOffer.map(o => OfferState(o, o.cores)) + } + + // Invoked in each round of Taskset assignment to initialize the internal structure. + def init(): Unit + + // Indicating whether there is offer available to be used by one round of Taskset assignment. + def hasNext(): Boolean + + // Next available offer returned to one round of Taskset assignment. + def getNext(): OfferState + + // Called by the TaskScheduler to indicate whether the current offer is accepted + // In order to decide whether the current is valid for the next offering. + def taskAssigned(assigned: Boolean): Unit + + // Release internally maintained resources. Subclass is responsible to + // release its own private resources. 
+ def reset: Unit = { + offer = null + } +} + +class RoundRobinAssigner(conf: SparkConf) extends TaskAssigner(conf) { + var i = 0 + override def construct(workOffer: Seq[WorkerOffer]): Unit = { + offer = Random.shuffle(workOffer.map(o => OfferState(o, o.cores))) + } + override def init(): Unit = { + i = 0 + } + override def hasNext: Boolean = { + i < offer.size + } + override def getNext(): OfferState = { + offer(i) + } + override def taskAssigned(assigned: Boolean): Unit = { + i += 1 + } + override def reset: Unit = { + super.reset + i = 0 + } +} + +class BalancedAssigner(conf: SparkConf) extends TaskAssigner(conf) { + var maxHeap: PriorityQueue[OfferState] = _ + var current: OfferState = _ + + override def construct(workOffer: Seq[WorkerOffer]): Unit = { + offer = Random.shuffle(workOffer.map(o => OfferState(o, o.cores))) + } + implicit val ord: Ordering[OfferState] = new Ordering[OfferState] { + def compare(x: OfferState, y: OfferState): Int = { + return Ordering[Int].compare(x.cores, y.cores) + } + } + def init(): Unit = { + maxHeap = new PriorityQueue[OfferState]() + offer.filter(_.cores >= CPUS_PER_TASK).foreach(maxHeap.enqueue(_)) + } + override def hasNext: Boolean = { + maxHeap.size > 0 + } + override def getNext(): OfferState = { + current = maxHeap.dequeue() + current + } + + override def taskAssigned(assigned: Boolean): Unit = { + if (current.cores >= CPUS_PER_TASK && assigned) { + maxHeap.enqueue(current) + } + } + override def reset: Unit = { + super.reset + maxHeap = null + current = null + } +} + +class PackedAssigner(conf: SparkConf) extends TaskAssigner(conf) { + + var sorted: Seq[OfferState] = _ + var i = 0 + var current: OfferState = _ + + override def init(): Unit = { + i = 0 + sorted = offer.filter(_.cores >= CPUS_PER_TASK).sortBy(_.cores) + } + + override def hasNext: Boolean = { + i < sorted.size + } + + override def getNext(): OfferState = { + current = sorted(i) + current + } + + def taskAssigned(assigned: Boolean): Unit = { + if (current.cores < CPUS_PER_TASK || !assigned) { + i += 1 + } + } + + override def reset: Unit = { + super.reset + sorted = null + current = null + i = 0 + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 3e3f1ad031e66..fb732ea8e5a3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,9 +22,7 @@ import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.collection.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -61,6 +59,21 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf + val DEFAULT_TASK_ASSIGNER = classOf[RoundRobinAssigner].getName + lazy val taskAssigner: TaskAssigner = { + val className = conf.get("spark.task.assigner", DEFAULT_TASK_ASSIGNER) + try { + logInfo(s"""constructing assigner as $className""") + val ctor = Utils.classForName(className).getConstructor(classOf[SparkConf]) + ctor.newInstance(conf).asInstanceOf[TaskAssigner] + } catch { + case _: Throwable => + logWarning( + s"""$className cannot be constructed fallback to default + | $DEFAULT_TASK_ASSIGNER""".stripMargin) + new RoundRobinAssigner(conf) + } + } // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = 
conf.getTimeAsMs("spark.speculation.interval", "100ms") @@ -250,24 +263,26 @@ private[spark] class TaskSchedulerImpl( private def resourceOfferSingleTaskSet( taskSet: TaskSetManager, maxLocality: TaskLocality, - shuffledOffers: Seq[WorkerOffer], - availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + taskAssigner: TaskAssigner) : Boolean = { var launchedTask = false - for (i <- 0 until shuffledOffers.size) { - val execId = shuffledOffers(i).executorId - val host = shuffledOffers(i).host - if (availableCpus(i) >= CPUS_PER_TASK) { + taskAssigner.init() + while(taskAssigner.hasNext()) { + var assigned = false + val current = taskAssigner.getNext() + val execId = current.workOffer.executorId + val host = current.workOffer.host + if (current.cores >= CPUS_PER_TASK) { try { for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - tasks(i) += task + current.tasks += task val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorIdToTaskCount(execId) += 1 - availableCpus(i) -= CPUS_PER_TASK - assert(availableCpus(i) >= 0) + current.cores = current.cores - CPUS_PER_TASK + assert(current.cores >= 0) launchedTask = true + assigned = true } } catch { case e: TaskNotSerializableException => @@ -277,8 +292,10 @@ private[spark] class TaskSchedulerImpl( return launchedTask } } + taskAssigner.taskAssigned(assigned) } return launchedTask + } /** @@ -305,12 +322,8 @@ private[spark] class TaskSchedulerImpl( hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host } } + taskAssigner.construct(offers) - // Randomly shuffle offers to avoid always placing tasks on the same set of workers. - val shuffledOffers = Random.shuffle(offers) - // Build a list of tasks to assign to each worker. 
- val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -329,7 +342,7 @@ private[spark] class TaskSchedulerImpl( for (currentMaxLocality <- taskSet.myLocalityLevels) { do { launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) + taskSet, currentMaxLocality, taskAssigner) launchedAnyTask |= launchedTaskAtCurrentMaxLocality } while (launchedTaskAtCurrentMaxLocality) } @@ -337,10 +350,12 @@ private[spark] class TaskSchedulerImpl( taskSet.abortIfCompletelyBlacklisted(hostToExecutors) } } - + val tasks = taskAssigner.tasks + taskAssigner.reset if (tasks.size > 0) { hasLaunchedTask = true } + return tasks } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index f5f1947661d9a..2584f85bc553a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -109,6 +109,72 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!failedTaskSet) } + test("Scheduler balance the assignment to the worker with more free cores") { + val taskScheduler = setupScheduler(("spark.task.assigner", classOf[BalancedAssigner].getName)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), + new WorkerOffer("executor1", "host1", 4)) + val selectedExecutorIds = { + val taskSet = FakeTask.createTaskSet(2) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + taskDescriptions.map(_.executorId) + } + val count = selectedExecutorIds.count(_ == workerOffers(1).executorId) + assert(count == 2) + assert(!failedTaskSet) + } + + test("Scheduler balance the assignment across workers with same free cores") { + val taskScheduler = setupScheduler(("spark.task.assigner", classOf[BalancedAssigner].getName)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), + new WorkerOffer("executor1", "host1", 2)) + val selectedExecutorIds = { + val taskSet = FakeTask.createTaskSet(2) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + taskDescriptions.map(_.executorId) + } + val count = selectedExecutorIds.count(_ == workerOffers(1).executorId) + assert(count == 1) + assert(!failedTaskSet) + } + + test("Scheduler packs the assignment to workers with less free cores") { + val taskScheduler = setupScheduler(("spark.task.assigner", classOf[PackedAssigner].getName)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), + new WorkerOffer("executor1", "host1", 4)) + val selectedExecutorIds = { + val taskSet = FakeTask.createTaskSet(2) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + taskDescriptions.map(_.executorId) + } + val count = selectedExecutorIds.count(_ == workerOffers(0).executorId) + assert(count == 2) + assert(!failedTaskSet) + } + + test("Scheduler keeps packing the assignment to the same worker") { + val 
taskScheduler = setupScheduler(("spark.task.assigner", classOf[PackedAssigner].getName)) + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 4), + new WorkerOffer("executor1", "host1", 4)) + val selectedExecutorIds = { + val taskSet = FakeTask.createTaskSet(4) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(4 === taskDescriptions.length) + taskDescriptions.map(_.executorId) + } + + val count = selectedExecutorIds.count(_ == workerOffers(0).executorId) + assert(count == 4) + assert(!failedTaskSet) + } + + test("Scheduler correctly accounts for multiple CPUs per task") { val taskCpus = 2 val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) @@ -408,4 +474,5 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(thirdTaskDescs.size === 0) assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } + } diff --git a/docs/configuration.md b/docs/configuration.md index 373e22d71a872..6f3fbeb76cc24 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1334,6 +1334,17 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. + + spark.task.assigner + org.apache.spark.scheduler.RoundRobinAssigner + + The strategy of how to allocate tasks among workers with free cores. + By default, round robin with randomness is used. + org.apache.spark.scheduler.BalancedAssigner tries to balance the task across all workers (allocating tasks to + workers with most free cores). org.apache.spark.scheduler.PackedAssigner tries to allocate tasks to workers + with the least free cores, which may help releasing the resources when dynamic allocation is enabled. + + #### Dynamic Allocation From 72a6e7a57a63aba69f26c84bf68a5fb213d2a521 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 15 Oct 2016 22:31:37 -0700 Subject: [PATCH 303/306] Revert "[SPARK-17637][SCHEDULER] Packed scheduling for Spark tasks across executors" This reverts commit ed1463341455830b8867b721a1b34f291139baf3. The patch merged had obvious quality and documentation issue. The idea is useful, and we should work towards improving its quality and merging it in again. --- .../apache/spark/scheduler/TaskAssigner.scala | 154 ------------------ .../spark/scheduler/TaskSchedulerImpl.scala | 53 +++--- .../scheduler/TaskSchedulerImplSuite.scala | 67 -------- docs/configuration.md | 11 -- 4 files changed, 19 insertions(+), 266 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala deleted file mode 100644 index 62df9657a6ac6..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskAssigner.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.PriorityQueue -import scala.util.Random - -import org.apache.spark.SparkConf - -case class OfferState(workOffer: WorkerOffer, var cores: Int) { - // Build a list of tasks to assign to each worker. - val tasks = new ArrayBuffer[TaskDescription](cores) -} - -abstract class TaskAssigner(conf: SparkConf) { - var offer: Seq[OfferState] = _ - val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) - - // The final assigned offer returned to TaskScheduler. - def tasks(): Seq[ArrayBuffer[TaskDescription]] = offer.map(_.tasks) - - // construct the assigner by the workoffer. - def construct(workOffer: Seq[WorkerOffer]): Unit = { - offer = workOffer.map(o => OfferState(o, o.cores)) - } - - // Invoked in each round of Taskset assignment to initialize the internal structure. - def init(): Unit - - // Indicating whether there is offer available to be used by one round of Taskset assignment. - def hasNext(): Boolean - - // Next available offer returned to one round of Taskset assignment. - def getNext(): OfferState - - // Called by the TaskScheduler to indicate whether the current offer is accepted - // In order to decide whether the current is valid for the next offering. - def taskAssigned(assigned: Boolean): Unit - - // Release internally maintained resources. Subclass is responsible to - // release its own private resources. 
- def reset: Unit = { - offer = null - } -} - -class RoundRobinAssigner(conf: SparkConf) extends TaskAssigner(conf) { - var i = 0 - override def construct(workOffer: Seq[WorkerOffer]): Unit = { - offer = Random.shuffle(workOffer.map(o => OfferState(o, o.cores))) - } - override def init(): Unit = { - i = 0 - } - override def hasNext: Boolean = { - i < offer.size - } - override def getNext(): OfferState = { - offer(i) - } - override def taskAssigned(assigned: Boolean): Unit = { - i += 1 - } - override def reset: Unit = { - super.reset - i = 0 - } -} - -class BalancedAssigner(conf: SparkConf) extends TaskAssigner(conf) { - var maxHeap: PriorityQueue[OfferState] = _ - var current: OfferState = _ - - override def construct(workOffer: Seq[WorkerOffer]): Unit = { - offer = Random.shuffle(workOffer.map(o => OfferState(o, o.cores))) - } - implicit val ord: Ordering[OfferState] = new Ordering[OfferState] { - def compare(x: OfferState, y: OfferState): Int = { - return Ordering[Int].compare(x.cores, y.cores) - } - } - def init(): Unit = { - maxHeap = new PriorityQueue[OfferState]() - offer.filter(_.cores >= CPUS_PER_TASK).foreach(maxHeap.enqueue(_)) - } - override def hasNext: Boolean = { - maxHeap.size > 0 - } - override def getNext(): OfferState = { - current = maxHeap.dequeue() - current - } - - override def taskAssigned(assigned: Boolean): Unit = { - if (current.cores >= CPUS_PER_TASK && assigned) { - maxHeap.enqueue(current) - } - } - override def reset: Unit = { - super.reset - maxHeap = null - current = null - } -} - -class PackedAssigner(conf: SparkConf) extends TaskAssigner(conf) { - - var sorted: Seq[OfferState] = _ - var i = 0 - var current: OfferState = _ - - override def init(): Unit = { - i = 0 - sorted = offer.filter(_.cores >= CPUS_PER_TASK).sortBy(_.cores) - } - - override def hasNext: Boolean = { - i < sorted.size - } - - override def getNext(): OfferState = { - current = sorted(i) - current - } - - def taskAssigned(assigned: Boolean): Unit = { - if (current.cores < CPUS_PER_TASK || !assigned) { - i += 1 - } - } - - override def reset: Unit = { - super.reset - sorted = null - current = null - i = 0 - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index fb732ea8e5a3b..3e3f1ad031e66 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,7 +22,9 @@ import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong +import scala.collection.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -59,21 +61,6 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf - val DEFAULT_TASK_ASSIGNER = classOf[RoundRobinAssigner].getName - lazy val taskAssigner: TaskAssigner = { - val className = conf.get("spark.task.assigner", DEFAULT_TASK_ASSIGNER) - try { - logInfo(s"""constructing assigner as $className""") - val ctor = Utils.classForName(className).getConstructor(classOf[SparkConf]) - ctor.newInstance(conf).asInstanceOf[TaskAssigner] - } catch { - case _: Throwable => - logWarning( - s"""$className cannot be constructed fallback to default - | $DEFAULT_TASK_ASSIGNER""".stripMargin) - new RoundRobinAssigner(conf) - } - } // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = 
conf.getTimeAsMs("spark.speculation.interval", "100ms") @@ -263,26 +250,24 @@ private[spark] class TaskSchedulerImpl( private def resourceOfferSingleTaskSet( taskSet: TaskSetManager, maxLocality: TaskLocality, - taskAssigner: TaskAssigner) : Boolean = { + shuffledOffers: Seq[WorkerOffer], + availableCpus: Array[Int], + tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { var launchedTask = false - taskAssigner.init() - while(taskAssigner.hasNext()) { - var assigned = false - val current = taskAssigner.getNext() - val execId = current.workOffer.executorId - val host = current.workOffer.host - if (current.cores >= CPUS_PER_TASK) { + for (i <- 0 until shuffledOffers.size) { + val execId = shuffledOffers(i).executorId + val host = shuffledOffers(i).host + if (availableCpus(i) >= CPUS_PER_TASK) { try { for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - current.tasks += task + tasks(i) += task val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorIdToTaskCount(execId) += 1 - current.cores = current.cores - CPUS_PER_TASK - assert(current.cores >= 0) + availableCpus(i) -= CPUS_PER_TASK + assert(availableCpus(i) >= 0) launchedTask = true - assigned = true } } catch { case e: TaskNotSerializableException => @@ -292,10 +277,8 @@ private[spark] class TaskSchedulerImpl( return launchedTask } } - taskAssigner.taskAssigned(assigned) } return launchedTask - } /** @@ -322,8 +305,12 @@ private[spark] class TaskSchedulerImpl( hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host } } - taskAssigner.construct(offers) + // Randomly shuffle offers to avoid always placing tasks on the same set of workers. + val shuffledOffers = Random.shuffle(offers) + // Build a list of tasks to assign to each worker. 
+ val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -342,7 +329,7 @@ private[spark] class TaskSchedulerImpl( for (currentMaxLocality <- taskSet.myLocalityLevels) { do { launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, taskAssigner) + taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) launchedAnyTask |= launchedTaskAtCurrentMaxLocality } while (launchedTaskAtCurrentMaxLocality) } @@ -350,12 +337,10 @@ private[spark] class TaskSchedulerImpl( taskSet.abortIfCompletelyBlacklisted(hostToExecutors) } } - val tasks = taskAssigner.tasks - taskAssigner.reset + if (tasks.size > 0) { hasLaunchedTask = true } - return tasks } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 2584f85bc553a..f5f1947661d9a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -109,72 +109,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!failedTaskSet) } - test("Scheduler balance the assignment to the worker with more free cores") { - val taskScheduler = setupScheduler(("spark.task.assigner", classOf[BalancedAssigner].getName)) - val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), - new WorkerOffer("executor1", "host1", 4)) - val selectedExecutorIds = { - val taskSet = FakeTask.createTaskSet(2) - taskScheduler.submitTasks(taskSet) - val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten - assert(2 === taskDescriptions.length) - taskDescriptions.map(_.executorId) - } - val count = selectedExecutorIds.count(_ == workerOffers(1).executorId) - assert(count == 2) - assert(!failedTaskSet) - } - - test("Scheduler balance the assignment across workers with same free cores") { - val taskScheduler = setupScheduler(("spark.task.assigner", classOf[BalancedAssigner].getName)) - val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), - new WorkerOffer("executor1", "host1", 2)) - val selectedExecutorIds = { - val taskSet = FakeTask.createTaskSet(2) - taskScheduler.submitTasks(taskSet) - val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten - assert(2 === taskDescriptions.length) - taskDescriptions.map(_.executorId) - } - val count = selectedExecutorIds.count(_ == workerOffers(1).executorId) - assert(count == 1) - assert(!failedTaskSet) - } - - test("Scheduler packs the assignment to workers with less free cores") { - val taskScheduler = setupScheduler(("spark.task.assigner", classOf[PackedAssigner].getName)) - val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 2), - new WorkerOffer("executor1", "host1", 4)) - val selectedExecutorIds = { - val taskSet = FakeTask.createTaskSet(2) - taskScheduler.submitTasks(taskSet) - val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten - assert(2 === taskDescriptions.length) - taskDescriptions.map(_.executorId) - } - val count = selectedExecutorIds.count(_ == workerOffers(0).executorId) - assert(count == 2) - assert(!failedTaskSet) - } - - test("Scheduler keeps packing the assignment to the same worker") { - val 
taskScheduler = setupScheduler(("spark.task.assigner", classOf[PackedAssigner].getName)) - val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 4), - new WorkerOffer("executor1", "host1", 4)) - val selectedExecutorIds = { - val taskSet = FakeTask.createTaskSet(4) - taskScheduler.submitTasks(taskSet) - val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten - assert(4 === taskDescriptions.length) - taskDescriptions.map(_.executorId) - } - - val count = selectedExecutorIds.count(_ == workerOffers(0).executorId) - assert(count == 4) - assert(!failedTaskSet) - } - - test("Scheduler correctly accounts for multiple CPUs per task") { val taskCpus = 2 val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) @@ -474,5 +408,4 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(thirdTaskDescs.size === 0) assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } - } diff --git a/docs/configuration.md b/docs/configuration.md index 6f3fbeb76cc24..373e22d71a872 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1334,17 +1334,6 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. - - spark.task.assigner - org.apache.spark.scheduler.RoundRobinAssigner - - The strategy of how to allocate tasks among workers with free cores. - By default, round robin with randomness is used. - org.apache.spark.scheduler.BalancedAssigner tries to balance the task across all workers (allocating tasks to - workers with most free cores). org.apache.spark.scheduler.PackedAssigner tries to allocate tasks to workers - with the least free cores, which may help releasing the resources when dynamic allocation is enabled. - - #### Dynamic Allocation From 59e3eb5af8d0969bbb785af77b66343bda7acc38 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 16 Oct 2016 20:15:32 -0700 Subject: [PATCH 304/306] [SPARK-17819][SQL] Support default database in connection URIs for Spark Thrift Server ## What changes were proposed in this pull request? Currently, Spark Thrift Server ignores the default database in URI. This PR supports that like the following. ```sql $ bin/beeline -u jdbc:hive2://localhost:10000 -e "create database testdb" $ bin/beeline -u jdbc:hive2://localhost:10000/testdb -e "create table t(a int)" $ bin/beeline -u jdbc:hive2://localhost:10000/testdb -e "show tables" ... +------------+--------------+--+ | tableName | isTemporary | +------------+--------------+--+ | t | false | +------------+--------------+--+ 1 row selected (0.347 seconds) $ bin/beeline -u jdbc:hive2://localhost:10000 -e "show tables" ... +------------+--------------+--+ | tableName | isTemporary | +------------+--------------+--+ +------------+--------------+--+ No rows selected (0.098 seconds) ``` ## How was this patch tested? Manual. Note: I tried to add a test case for this, but I cannot found a suitable testsuite for this. I'll add the testcase if some advice is given. Author: Dongjoon Hyun Closes #15399 from dongjoon-hyun/SPARK-17819. 
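Beyond beeline, the same behaviour can be exercised from plain JDBC, close in spirit to the JdbcConnectionUriSuite test added by this patch. The sketch below is illustrative only: the host, port, database name, and credentials are assumptions, and it presumes a running Thrift Server with the Hive JDBC driver on the classpath.

```scala
// A minimal sketch, not part of the patch: open a connection whose URI names a
// default database and confirm the session starts there. "testdb",
// localhost:10000, and the empty password are assumptions for illustration.
import java.sql.DriverManager

object DefaultDatabaseExample {
  def main(args: Array[String]): Unit = {
    // Register the Hive JDBC driver, as the new JdbcConnectionUriSuite does.
    Class.forName("org.apache.hive.jdbc.HiveDriver")

    val url = "jdbc:hive2://localhost:10000/testdb"
    val connection = DriverManager.getConnection(url, System.getProperty("user.name"), "")
    val statement = connection.createStatement()
    try {
      val resultSet = statement.executeQuery("SELECT current_database()")
      resultSet.next()
      // With this patch the reported database is "testdb" rather than "default".
      println(resultSet.getString(1))
    } finally {
      statement.close()
      connection.close()
    }
  }
}
```

The server-side half is the small change to SparkSQLSessionManager in the diff below, which runs a `use <database>` statement when the session configuration carries a `use:database` entry.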
--- .../thriftserver/SparkSQLSessionManager.scala | 3 + .../thriftserver/JdbcConnectionUriSuite.scala | 70 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 6a5117aea492d..226b7e175a9d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -79,6 +79,9 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: sqlContext.newSession() } ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) + if (sessionConf != null && sessionConf.containsKey("use:database")) { + ctx.sql(s"use ${sessionConf.get("use:database")}") + } sparkSqlOperationManager.sessionToContexts.put(sessionHandle, ctx) sessionHandle } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala new file mode 100644 index 0000000000000..fb8a7e273ae44 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.DriverManager + +import org.apache.hive.jdbc.HiveDriver + +import org.apache.spark.util.Utils + +class JdbcConnectionUriSuite extends HiveThriftServer2Test { + Utils.classForName(classOf[HiveDriver].getCanonicalName) + + override def mode: ServerMode.Value = ServerMode.binary + + val JDBC_TEST_DATABASE = "jdbc_test_database" + val USER = System.getProperty("user.name") + val PASSWORD = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"CREATE DATABASE $JDBC_TEST_DATABASE") + connection.close() + } + + override protected def afterAll(): Unit = { + try { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"DROP DATABASE $JDBC_TEST_DATABASE") + connection.close() + } finally { + super.afterAll() + } + } + + test("SPARK-17819 Support default database in connection URIs") { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/$JDBC_TEST_DATABASE" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + try { + val resultSet = statement.executeQuery("select current_database()") + resultSet.next() + assert(resultSet.getString(1) === JDBC_TEST_DATABASE) + } finally { + statement.close() + connection.close() + } + } +} From e18d02c5a8f8af2e42079ab414f5d84b3e1a279e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 17 Oct 2016 12:08:25 +0800 Subject: [PATCH 305/306] [SPARK-17947][SQL] Add Doc and Comment about spark.sql.debug ### What changes were proposed in this pull request? Just document the impact of `spark.sql.debug`: When enabling the debug, Spark SQL internal table properties are not filtered out; however, some related DDL commands (e.g., Analyze Table and CREATE TABLE LIKE) might not work properly. ### How was this patch tested? N/A Author: gatorsmile Closes #15494 from gatorsmile/addDocForSQLDebug. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e73d0187b584b..a055e0135c136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -934,8 +934,11 @@ object StaticSQLConf { .intConf .createWithDefault(4000) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, + // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildConf("spark.sql.debug") .internal() + .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") .booleanConf .createWithDefault(false) } From 56b0f5f4d1d7826737b81ebc4ec5dad83b6463e3 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sun, 16 Oct 2016 22:38:30 -0700 Subject: [PATCH 306/306] [MINOR][SQL] Add prettyName for current_database function ## What changes were proposed in this pull request? Added a `prettyname` for current_database function. ## How was this patch tested? Manually. 
Before: ``` scala> sql("select current_database()").show +-----------------+ |currentdatabase()| +-----------------+ | default| +-----------------+ ``` After: ``` scala> sql("select current_database()").show +------------------+ |current_database()| +------------------+ | default| +------------------+ ``` Author: Weiqing Yang Closes #15506 from weiqingy/prettyName. --- .../scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 138ef2a1dcc01..5ead16908732f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -618,6 +618,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = StringType override def foldable: Boolean = true override def nullable: Boolean = false + override def prettyName: String = "current_database" } /**