@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.command.DDLUtils

@@ -409,6 +409,42 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) {
   }
 }
 
+
+/**
+ * A rule to do various checks before reading a table.
+ */
+object PreReadCheck extends (LogicalPlan => Unit) {
+  def apply(plan: LogicalPlan): Unit = {
+    plan.foreach {
+      case operator: LogicalPlan =>
+        operator transformExpressionsUp {
+          case e @ (_: InputFileName | _: InputFileBlockLength | _: InputFileBlockStart) =>
+            checkNumInputFileBlockSources(e, operator)
+            e
+        }
+    }
+  }
+
+  private def checkNumInputFileBlockSources(e: Expression, operator: LogicalPlan): Int = {
+    operator match {
+      case _: CatalogRelation => 1
+      case _ @ LogicalRelation(_: HadoopFsRelation, _, _) => 1
+      case _: LeafNode => 0
+      // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock.
+      case u: Union =>
+        if (u.children.map(checkNumInputFileBlockSources(e, _)).sum >= 1) 1 else 0
+      case o =>
+        val numInputFileBlockSources = o.children.map(checkNumInputFileBlockSources(e, _)).sum
+        if (numInputFileBlockSources > 1) {
+          e.failAnalysis(s"'${e.prettyName}' does not support more than one source")

    [Review comment, Member Author]: Need to check it as early as possible; otherwise, Union might eat it.

+        } else {
+          numInputFileBlockSources
+        }
+    }
+  }
+}
+
+
 /**
  * A rule to do various checks before inserting into or writing to a data source table.
  */
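
For a sense of how the new rule surfaces to users, here is a minimal sketch (not part of this diff; it assumes a local SparkSession, and the table names t1/t2 are illustrative):

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    object PreReadCheckDemo extends App {
      val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
      import spark.implicits._

      val df = spark.sparkContext.parallelize(0 to 9).toDF("id")
      df.write.saveAsTable("t1")
      df.write.saveAsTable("t2")

      // A join reads two file-backed relations at once, so input_file_name()
      // would be ambiguous; PreReadCheck now rejects the plan during analysis.
      try {
        spark.sql("SELECT *, input_file_name() FROM t1 JOIN t2 ON t1.id = t2.id").collect()
      } catch {
        case e: AnalysisException =>
          assert(e.getMessage.contains("'input_file_name' does not support more than one source"))
      }

      // UNION ALL is still allowed: each output row comes from exactly one file.
      val n = spark.sql(
        "SELECT *, input_file_name() FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2)").count()
      assert(n == 20)
    }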

@@ -168,6 +168,7 @@ abstract class BaseSessionStateBuilder(

     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
+        PreReadCheck +:
         HiveOnlyCheck +:
         customCheckRules
   }
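
As an aside, third-party code can plug a check into the same extendedCheckRules pipeline without modifying Spark, via the public SparkSessionExtensions API; a minimal sketch with a placeholder rule body:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions { extensions =>
        // Check rules run after analysis, like PreWriteCheck and PreReadCheck;
        // they return Unit and reject a plan by throwing an exception.
        extensions.injectCheckRule { _ => (plan: LogicalPlan) =>
          require(plan.resolved, "check rules should only see resolved plans")
        }
      }
      .getOrCreate()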

@@ -530,6 +530,63 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     )
   }
 
test("input_file_name, input_file_block_start, input_file_block_length - more than one source") {
withTempView("tempView1") {
withTable("tab1", "tab2") {
val data = sparkContext.parallelize(0 to 9).toDF("id")
data.write.saveAsTable("tab1")
data.write.saveAsTable("tab2")
data.createOrReplaceTempView("tempView1")
Seq("input_file_name", "input_file_block_start", "input_file_block_length").foreach { f =>
val e = intercept[AnalysisException] {
sql(s"SELECT *, $f() FROM tab1 JOIN tab2 ON tab1.id = tab2.id")
}.getMessage
assert(e.contains(s"'$f' does not support more than one source"))
}

def checkResult(
fromClause: String,
exceptionExpected: Boolean,
numExpectedRows: Int = 0): Unit = {
val stmt = s"SELECT *, input_file_name() FROM ($fromClause)"
if (exceptionExpected) {
val e = intercept[AnalysisException](sql(stmt)).getMessage
assert(e.contains("'input_file_name' does not support more than one source"))
} else {
assert(sql(stmt).count() == numExpectedRows)
}
}

checkResult(
"SELECT * FROM tab1 UNION ALL SELECT * FROM tab2 UNION ALL SELECT * FROM tab2",
exceptionExpected = false,
numExpectedRows = 30)

checkResult(
"(SELECT * FROM tempView1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2",
exceptionExpected = false,
numExpectedRows = 20)

checkResult(
"(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tempView1",
exceptionExpected = false,
numExpectedRows = 20)

checkResult(
"(SELECT * FROM tempView1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2",
exceptionExpected = true)

checkResult(
"(SELECT * FROM tab1 NATURAL JOIN tab2) UNION ALL SELECT * FROM tab2",
exceptionExpected = true)

checkResult(
"(SELECT * FROM tab1 UNION ALL SELECT * FROM tab2) NATURAL JOIN tab2",
exceptionExpected = true)
}
}
}

test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
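
The pass/fail pattern in the checkResult cases above follows directly from the counting recursion in checkNumInputFileBlockSources. A standalone toy model of that logic (invented names, not Spark code):

    sealed trait Plan
    case object FileLeaf extends Plan     // file-backed relation: 1 source
    case object OtherLeaf extends Plan    // e.g. a temp view over an RDD: 0 sources
    case class UnionAll(children: Seq[Plan]) extends Plan
    case class Join(left: Plan, right: Plan) extends Plan

    def countSources(p: Plan): Int = p match {
      case FileLeaf => 1
      case OtherLeaf => 0
      // UNION ALL: each row comes from exactly one branch, so any number of
      // file-backed branches collapses to a single source.
      case UnionAll(cs) => if (cs.map(countSources).sum >= 1) 1 else 0
      // Other multi-child operators sum their children; the real rule calls
      // failAnalysis the moment a sum exceeds 1, before an enclosing Union
      // could collapse it (the "check it as early as possible" comment above).
      case Join(l, r) =>
        val n = countSources(l) + countSources(r)
        require(n <= 1, "does not support more than one source")
        n
    }

    // tab1 UNION ALL tab2 UNION ALL tab2: allowed (one logical source).
    assert(countSources(UnionAll(Seq(FileLeaf, FileLeaf, FileLeaf))) == 1)
    // (tab1 UNION ALL tab2) NATURAL JOIN tempView1: allowed (1 + 0 sources).
    assert(countSources(Join(UnionAll(Seq(FileLeaf, FileLeaf)), OtherLeaf)) == 1)
    // (tab1 UNION ALL tab2) NATURAL JOIN tab2: rejected (1 + 1 sources).
    // countSources(Join(UnionAll(Seq(FileLeaf, FileLeaf)), FileLeaf))  // throws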

@@ -126,8 +126,7 @@ class FileStreamSinkSuite extends StreamTest {
     // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
     // been inferred
     val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
-      case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] =>
-        baseRelation.asInstanceOf[HadoopFsRelation]
+      case LogicalRelation(baseRelation: HadoopFsRelation, _, _) => baseRelation
     }
     assert(hadoopdFsRelations.size === 1)
     assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
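
The change above is a behavior-preserving refactor: a type test plus cast becomes a single typed pattern. A self-contained illustration of the idiom (stand-in classes, not Spark's):

    sealed trait BaseRelation
    case class FsLike(path: String) extends BaseRelation
    case class Other(name: String) extends BaseRelation
    case class Rel(base: BaseRelation)

    val rels = Seq(Rel(FsLike("/data/part-0")), Rel(Other("jdbc")))

    // Before: an explicit instance test followed by a cast.
    val before = rels.collect {
      case Rel(b) if b.isInstanceOf[FsLike] => b.asInstanceOf[FsLike]
    }

    // After: the typed pattern does the test and the cast in one step.
    val after = rels.collect { case Rel(b: FsLike) => b }

    assert(before == after)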

@@ -69,22 +69,23 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
   override protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new ResolveHiveSerdeTable(session) +:
-        new FindDataSourceTable(session) +:
-        new ResolveSQLOnFile(session) +:
-        customResolutionRules
+      new FindDataSourceTable(session) +:
+      new ResolveSQLOnFile(session) +:
+      customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
       new DetermineTableStats(session) +:
-        RelationConversions(conf, catalog) +:
-        PreprocessTableCreation(session) +:
-        PreprocessTableInsertion(conf) +:
-        DataSourceAnalysis(conf) +:
-        HiveAnalysis +:
-        customPostHocResolutionRules
+      RelationConversions(conf, catalog) +:
+      PreprocessTableCreation(session) +:
+      PreprocessTableInsertion(conf) +:
+      DataSourceAnalysis(conf) +:
+      HiveAnalysis +:
+      customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
-        customCheckRules
+      PreReadCheck +:
+      customCheckRules
   }
 
   /**