From 04ff406c5a82f9454419d0f0054b1aa75ea97aa0 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 9 Jul 2017 15:52:55 -0700
Subject: [PATCH 1/4] fix.

---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  8 ++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 16 ++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 85c52792ef65..0c8d4b263bbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -74,6 +74,10 @@ trait CheckAnalysis extends PredicateHelper {
     }
   }
 
+  private def getNumLeafNodes(operator: LogicalPlan): Int = {
+    operator.collect { case _: LeafNode => 1 }.sum
+  }
+
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
@@ -100,6 +104,10 @@ trait CheckAnalysis extends PredicateHelper {
             failAnalysis(
               s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
 
+          case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart)
+              if getNumLeafNodes(operator) > 1 =>
+            e.failAnalysis(s"'${e.prettyName}' does not support more than one sources")
+
           case g: Grouping =>
             failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup")
           case g: GroupingID =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index bc708ca88d7e..a40ad431ca6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -530,11 +530,27 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("input_file_name, input_file_block_start, input_file_block_length - more than one sources") {
+    withTable("tab1", "tab2") {
+      val data = sparkContext.parallelize(0 to 10).toDF("id")
+      data.write.saveAsTable("tab1")
+      data.write.saveAsTable("tab2")
+      Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { func =>
+        val e = intercept[AnalysisException] {
+          sql(s"SELECT *, $func() FROM tab1 JOIN tab2 ON tab1.id = tab2.id")
+        }.getMessage
+        assert(e.contains(s"'$func' does not support more than one sources"))
+      }
+    }
+  }
+
   test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") {
     withTempPath { dir =>
       val data = sparkContext.parallelize(0 to 10).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
 
+      spark.read.parquet(dir.getCanonicalPath).explain(true)
+
       // Test the 3 expressions when reading from files
       val q = spark.read.parquet(dir.getCanonicalPath).select(
         input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()"))

From 596ea17bc99a703004ab7bef657603a9db57d5f2 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 9 Jul 2017 16:02:36 -0700
Subject: [PATCH 2/4] fix.

---
 .../test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a40ad431ca6b..502a0d16cc3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -549,8 +549,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
       val data = sparkContext.parallelize(0 to 10).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
 
-      spark.read.parquet(dir.getCanonicalPath).explain(true)
-
       // Test the 3 expressions when reading from files
       val q = spark.read.parquet(dir.getCanonicalPath).select(
         input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()"))

From 6b48a9e52ded62715b32aef4ee31b121d3e7aee9 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 10 Jul 2017 18:49:06 -0700
Subject: [PATCH 3/4] fix.

---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 11 +++++---
 .../spark/sql/ColumnExpressionSuite.scala          | 27 ++++++++++++++++++-
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 0c8d4b263bbf..74c68a360cba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -74,8 +74,13 @@ trait CheckAnalysis extends PredicateHelper {
     }
   }
 
-  private def getNumLeafNodes(operator: LogicalPlan): Int = {
-    operator.collect { case _: LeafNode => 1 }.sum
+  private def getNumInputFileBlockSources(operator: LogicalPlan): Int = {
+    operator match {
+      case _: LeafNode => 1
+      // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock.
+      case u: Union => u.children.map(getNumInputFileBlockSources).sum - u.children.length + 1
+      case o => o.children.map(getNumInputFileBlockSources).sum
+    }
   }
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -105,7 +110,7 @@ trait CheckAnalysis extends PredicateHelper {
             s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
 
           case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart)
-              if getNumLeafNodes(operator) > 1 =>
+              if getNumInputFileBlockSources(operator) > 1 =>
             e.failAnalysis(s"'${e.prettyName}' does not support more than one sources")
 
           case g: Grouping =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 502a0d16cc3d..a2f007506f24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -532,7 +532,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
 
   test("input_file_name, input_file_block_start, input_file_block_length - more than one sources") {
     withTable("tab1", "tab2") {
-      val data = sparkContext.parallelize(0 to 10).toDF("id")
+      val data = sparkContext.parallelize(0 to 9).toDF("id")
       data.write.saveAsTable("tab1")
       data.write.saveAsTable("tab2")
       Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { func =>
@@ -541,6 +541,31 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
         }.getMessage
         assert(e.contains(s"'$func' does not support more than one sources"))
       }
+
+      val df = sql(
+        """
+          |SELECT *, input_file_name()
+          |FROM (SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2)
+        """.stripMargin)
+      assert(df.count() == 30)
+
+      var e = intercept[AnalysisException] {
+        sql(
+          """
+            |SELECT *, input_file_name()
+            |FROM (SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2
+          """.stripMargin)
+      }.getMessage
+      assert(e.contains("'input_file_name' does not support more than one sources"))
+
+      e = intercept[AnalysisException] {
+        sql(
+          """
+            |SELECT *, input_file_name()
+            |FROM (SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2
+          """.stripMargin)
+      }.getMessage
+      assert(e.contains("'input_file_name' does not support more than one sources"))
     }
   }
 

From c4de2b8e2583c55f1b761569050d2c21506c2291 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 16 Jul 2017 19:31:04 -0700
Subject: [PATCH 4/4] fix.

---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 13 ---
 .../sql/execution/datasources/rules.scala          | 38 +++++++-
 .../internal/BaseSessionStateBuilder.scala         |  1 +
 .../spark/sql/ColumnExpressionSuite.scala          | 88 +++++++++++--------
 .../sql/streaming/FileStreamSinkSuite.scala        |  3 +-
 .../sql/hive/HiveSessionStateBuilder.scala         | 21 ++---
 6 files changed, 103 insertions(+), 61 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 74c68a360cba..85c52792ef65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -74,15 +74,6 @@ trait CheckAnalysis extends PredicateHelper {
     }
   }
 
-  private def getNumInputFileBlockSources(operator: LogicalPlan): Int = {
-    operator match {
-      case _: LeafNode => 1
-      // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock.
-      case u: Union => u.children.map(getNumInputFileBlockSources).sum - u.children.length + 1
-      case o => o.children.map(getNumInputFileBlockSources).sum
-    }
-  }
-
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
@@ -109,10 +100,6 @@ trait CheckAnalysis extends PredicateHelper {
             failAnalysis(
               s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
 
-          case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart)
-              if getNumInputFileBlockSources(operator) > 1 =>
-            e.failAnalysis(s"'${e.prettyName}' does not support more than one sources")
-
           case g: Grouping =>
             failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup")
           case g: GroupingID =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 3f4a78580f1e..bf914ffb25bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.command.DDLUtils
@@ -409,6 +409,42 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) {
   }
 }
 
+
+/**
+ * A rule to do various checks before reading a table.
+ */
+object PreReadCheck extends (LogicalPlan => Unit) {
+  def apply(plan: LogicalPlan): Unit = {
+    plan.foreach {
+      case operator: LogicalPlan =>
+        operator transformExpressionsUp {
+          case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart) =>
+            checkNumInputFileBlockSources(e, operator)
+            e
+        }
+    }
+  }
+
+  private def checkNumInputFileBlockSources(e: Expression, operator: LogicalPlan): Int = {
+    operator match {
+      case _: CatalogRelation => 1
+      case _ @ LogicalRelation(_: HadoopFsRelation, _, _) => 1
+      case _: LeafNode => 0
+      // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock.
+      case u: Union =>
+        if (u.children.map(checkNumInputFileBlockSources(e, _)).sum >= 1) 1 else 0
+      case o =>
+        val numInputFileBlockSources = o.children.map(checkNumInputFileBlockSources(e, _)).sum
+        if (numInputFileBlockSources > 1) {
+          e.failAnalysis(s"'${e.prettyName}' does not support more than one sources")
+        } else {
+          numInputFileBlockSources
+        }
+    }
+  }
+}
+
+
 /**
  * A rule to do various checks before inserting into or writing to a data source table.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 72d0ddc62303..eb005ebc7aa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -168,6 +168,7 @@ abstract class BaseSessionStateBuilder(
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
+        PreReadCheck +:
        HiveOnlyCheck +:
        customCheckRules
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a2f007506f24..7c45be21961d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -530,42 +530,60 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("input_file_name, input_file_block_start, input_file_block_length - more than one sources") {
-    withTable("tab1", "tab2") {
-      val data = sparkContext.parallelize(0 to 9).toDF("id")
-      data.write.saveAsTable("tab1")
-      data.write.saveAsTable("tab2")
-      Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { func =>
-        val e = intercept[AnalysisException] {
-          sql(s"SELECT *, $func() FROM tab1 JOIN tab2 ON tab1.id = tab2.id")
-        }.getMessage
-        assert(e.contains(s"'$func' does not support more than one sources"))
+  test("input_file_name, input_file_block_start, input_file_block_length - more than one source") {
+    withTempView("tempView1") {
+      withTable("tab1", "tab2") {
+        val data = sparkContext.parallelize(0 to 9).toDF("id")
+        data.write.saveAsTable("tab1")
+        data.write.saveAsTable("tab2")
+        data.createOrReplaceTempView("tempView1")
+        Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { f =>
+          val e = intercept[AnalysisException] {
+            sql(s"SELECT *, $f() FROM tab1 JOIN tab2 ON tab1.id = tab2.id")
+          }.getMessage
+          assert(e.contains(s"'$f' does not support more than one source"))
+        }
+
+        def checkResult(
+            fromClause: String,
+            exceptionExpected: Boolean,
+            numExpectedRows: Int = 0): Unit = {
+          val stmt = s"SELECT *, input_file_name() FROM ($fromClause)"
+          if (exceptionExpected) {
+            val e = intercept[AnalysisException](sql(stmt)).getMessage
+            assert(e.contains("'input_file_name' does not support more than one source"))
+          } else {
+            assert(sql(stmt).count() == numExpectedRows)
+          }
+        }
+
+        checkResult(
+          "SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2",
+          exceptionExpected = false,
+          numExpectedRows = 30)
+
+        checkResult(
+          "(SELECT * FROM tempView1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2",
+          exceptionExpected = false,
+          numExpectedRows = 20)
+
+        checkResult(
+          "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tempView1",
+          exceptionExpected = false,
+          numExpectedRows = 20)
+
+        checkResult(
+          "(SELECT * FROM tempView1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2",
+          exceptionExpected = true)
+
+        checkResult(
+          "(SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2",
+          exceptionExpected = true)
+
+        checkResult(
+          "(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2",
+          exceptionExpected = true)
       }
-
-      val df = sql(
-        """
-          |SELECT *, input_file_name()
-          |FROM (SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2)
-        """.stripMargin)
-      assert(df.count() == 30)
-
-      var e = intercept[AnalysisException] {
-        sql(
-          """
-            |SELECT *, input_file_name()
-            |FROM (SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2
-          """.stripMargin)
-      }.getMessage
-      assert(e.contains("'input_file_name' does not support more than one sources"))
-
-      e = intercept[AnalysisException] {
-        sql(
-          """
-            |SELECT *, input_file_name()
-            |FROM (SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2
-          """.stripMargin)
-      }.getMessage
-      assert(e.contains("'input_file_name' does not support more than one sources"))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index bb6a27803bb2..7a185e8c9711 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -126,8 +126,7 @@ class FileStreamSinkSuite extends StreamTest {
     // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
     // been inferred
     val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
-      case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] =>
-        baseRelation.asInstanceOf[HadoopFsRelation]
+      case LogicalRelation(baseRelation: HadoopFsRelation, _, _) => baseRelation
     }
     assert(hadoopdFsRelations.size === 1)
     assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index e16c9e46b772..92cb4ef11c9e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -69,22 +69,23 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
   override protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new ResolveHiveSerdeTable(session) +:
-      new FindDataSourceTable(session) +:
-      new ResolveSQLOnFile(session) +:
-      customResolutionRules
+        new FindDataSourceTable(session) +:
+        new ResolveSQLOnFile(session) +:
+        customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
       new DetermineTableStats(session) +:
-      RelationConversions(conf, catalog) +:
-      PreprocessTableCreation(session) +:
-      PreprocessTableInsertion(conf) +:
-      DataSourceAnalysis(conf) +:
-      HiveAnalysis +:
-      customPostHocResolutionRules
+        RelationConversions(conf, catalog) +:
+        PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        HiveAnalysis +:
+        customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
-      customCheckRules
+        PreReadCheck +:
+        customCheckRules
   }
 
   /**
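
A quick usage sketch of the behavior the final patch enforces, based on the tests above. It is not part of the patches; the table names t1/t2 are illustrative, and it assumes a build that includes the PreReadCheck rule:

    // Two file-backed tables, each with an "id" column.
    spark.range(10).write.saveAsTable("t1")
    spark.range(10).write.saveAsTable("t2")

    // Allowed: under UNION ALL each output row still comes from a single file source,
    // so input_file_name() stays well defined.
    spark.sql(
      "SELECT *, input_file_name() FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2)").show()

    // Rejected by PreReadCheck: a join scans two file sources at once, so the call is ambiguous.
    // This fails analysis with: 'input_file_name' does not support more than one sources
    spark.sql("SELECT *, input_file_name() FROM t1 JOIN t2 ON t1.id = t2.id").show()

The same check covers input_file_block_start() and input_file_block_length().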