@@ -103,6 +103,7 @@ class Analyzer(
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
ResolveOutputColumns ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
@@ -451,7 +452,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved =>
// A partitioned relation's schema can be different from the input logicalPlan, since
// partition columns are all moved after data columns. We Project to adjust the ordering.
val input = if (parts.nonEmpty) {
@@ -516,6 +517,124 @@ class Analyzer(
}
}

object ResolveOutputColumns extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case ins @ InsertIntoTable(relation: LogicalPlan, partition, _, _, _, _)
if relation.resolved && ins.childrenResolved && !ins.resolved =>
resolveOutputColumns(ins, expectedColumns(relation, partition), relation.toString)
}

private def resolveOutputColumns(
insertInto: InsertIntoTable,
columns: Seq[Attribute],
relation: String) = {
val resolved = if (insertInto.isMatchByName) {
projectAndCastOutputColumns(columns, insertInto.child, relation)
} else {
castAndRenameOutputColumns(columns, insertInto.child, relation)
}

if (resolved == insertInto.child.output) {
insertInto
} else {
insertInto.copy(child = Project(resolved, insertInto.child))
}
}

/**
* Resolves output columns by input column name, adding casts if necessary.
*/
private def projectAndCastOutputColumns(
output: Seq[Attribute],
data: LogicalPlan,
relation: String): Seq[NamedExpression] = {
if (output.size > data.output.size) {
// always a problem
throw new AnalysisException(
s"""Not enough data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${output.mkString(",")}""".stripMargin)
} else if (output.size < data.output.size) {
// be conservative and fail if there are too many columns
throw new AnalysisException(
s"""Extra data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${output.mkString(",")}""".stripMargin)
}

output.map { col =>
data.resolveQuoted(col.name, resolver) match {
case Some(inCol) if !col.dataType.sameType(inCol.dataType) =>
Alias(UpCast(inCol, col.dataType, Seq()), col.name)()
case Some(inCol) => inCol
case None =>
throw new AnalysisException(
s"Cannot resolve ${col.name} in ${data.output.mkString(",")}")
}
}
}

private def castAndRenameOutputColumns(
output: Seq[Attribute],
data: LogicalPlan,
relation: String): Seq[NamedExpression] = {
val outputNames = output.map(_.name)
// incoming expressions may not have names
val inputNames = data.output.flatMap(col => Option(col.name))
if (output.size > data.output.size) {
// always a problem
throw new AnalysisException(
s"""Not enough data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${outputNames.mkString(",")}""".stripMargin)
} else if (output.size < data.output.size) {
// be conservative and fail if there are too many columns
throw new AnalysisException(
s"""Extra data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${outputNames.mkString(",")}""".stripMargin)
} else {
// check for reordered names and warn. this may be on purpose, so it isn't an error.
if (outputNames.toSet == inputNames.toSet && outputNames != inputNames) {
logWarning(
s"""Data column names match the table in a different order:
|Data columns: ${inputNames.mkString(",")}
|Table columns: ${outputNames.mkString(",")}""".stripMargin)
}
}

data.output.zip(output).map {
case (in, out) if !in.dataType.sameType(out.dataType) =>
Alias(Cast(in, out.dataType), out.name)()
case (in, out) if in.name != out.name =>
Alias(in, out.name)()
case (in, _) => in
}
}

private def expectedColumns(
data: LogicalPlan,
partitionData: Map[String, Option[String]]): Seq[Attribute] = {
data match {
case partitioned: CatalogRelation =>
val tablePartitionNames = partitioned.catalogTable.partitionColumns.map(_.name)
val (inputPartCols, dataColumns) = data.output.partition { attr =>
tablePartitionNames.contains(attr.name)
}
// Get the dynamic partition columns in partition order
val dynamicNames = tablePartitionNames.filter(
name => partitionData.getOrElse(name, None).isEmpty)
Contributor: partitionData.contains?

Contributor (Author): Using contains would include partitions that are set statically for this query, which have values like Some("static-val"). This doesn't happen through the DataFrameWriter, but it is valid HiveQL.
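
(Illustrative note, not part of the patch: a minimal sketch of why the rule filters with `getOrElse(name, None).isEmpty` rather than `contains`. The partition spec values below are made up.)

```scala
// Partition spec as it would arrive from e.g.
// INSERT INTO t PARTITION (region='us', day) SELECT ...
val partitionData: Map[String, Option[String]] = Map(
  "region" -> Some("us"), // static partition: value fixed in the query
  "day"    -> None        // dynamic partition: value comes from the data
)
val tablePartitionNames = Seq("region", "day")

// contains would select both entries, including the static partition:
tablePartitionNames.filter(partitionData.contains)
// => Seq("region", "day")

// getOrElse(name, None).isEmpty selects only the dynamic partition:
tablePartitionNames.filter(name => partitionData.getOrElse(name, None).isEmpty)
// => Seq("day")
```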

val dynamicPartCols = dynamicNames.map { name =>
inputPartCols.find(_.name == name).getOrElse(
throw new AnalysisException(s"Cannot find partition column $name"))
}

dataColumns ++ dynamicPartCols
case _ => data.output
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
@@ -313,7 +313,7 @@ trait CheckAnalysis extends PredicateHelper {
|${s.catalogTable.identifier}
""".stripMargin)

case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) =>
case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) =>
failAnalysis(
s"""
|Hive support is required to insert into the following tables:
@@ -367,7 +367,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
Map.empty, logicalPlan, overwrite, ifNotExists = false, Map.empty)

def as(alias: String): LogicalPlan = logicalPlan match {
case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
@@ -175,8 +175,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
UnresolvedRelation(tableIdent, None),
partitionKeys,
query,
ctx.OVERWRITE != null,
ctx.EXISTS != null)
overwrite = ctx.OVERWRITE != null,
ifNotExists = ctx.EXISTS != null,
Map.empty /* SQL always matches by position */)
}

/**
@@ -359,30 +359,43 @@ case class InsertIntoTable(
partition: Map[String, Option[String]],
child: LogicalPlan,
overwrite: Boolean,
ifNotExists: Boolean)
ifNotExists: Boolean,
options: Map[String, String])
Contributor: Does this options map only contain matchByName?

Contributor (Author): Right now, yes. But I changed from a boolean to a map because there are other options that will be added, like filesPerPartition. Changing the case class requires code changes in several places just to add _ to a match expression or pass the argument on. Using a map means we can extend InsertIntoTable without updating the signature all the time. Patches should be smaller and have fewer conflicts.
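
(Illustrative note: a hedged sketch of the intended flow, assuming the DataFrameWriter plumbing shown further down in this diff, where extraOptions is copied into InsertIntoTable.options. The table name and option value here are hypothetical.)

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("insert-options-sketch").getOrCreate()
val df = spark.range(10).selectExpr("id AS a", "id * 2 AS b")

// The writer's extraOptions map becomes InsertIntoTable.options, so adding
// a new behavior only needs a new key, not another case-class field.
df.write
  .option("matchByName", "true") // later read via InsertIntoTable.isMatchByName
  .insertInto("target_table")    // hypothetical target table
```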

Contributor (Author): I just opened another PR that requires this map, #13206. That adds a hint for the number of writers per output partition so the optimizer can add the appropriate repartition and sort operation.

Contributor: As this PR won't do the byName resolution, should we remove this map parameter?

extends LogicalPlan {

override def children: Seq[LogicalPlan] = child :: Nil
override def output: Seq[Attribute] = Seq.empty

private[spark] def isMatchByName: Boolean = {
options.get("matchByName").map(_.toBoolean).getOrElse(false)
}

private[spark] lazy val expectedColumns = {
if (table.output.isEmpty) {
None
} else {
val numDynamicPartitions = partition.values.count(_.isEmpty)
val dynamicPartitionNames = partition.filter {
case (name, Some(_)) => false
case (name, None) => true
}.keySet
val (partitionColumns, dataColumns) = table.output
.partition(a => partition.keySet.contains(a.name))
Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
Some(dataColumns ++ partitionColumns.filter(col => dynamicPartitionNames.contains(col.name)))
}
}

assert(overwrite || !ifNotExists)
override lazy val resolved: Boolean =
childrenResolved && table.resolved && expectedColumns.forall { expected =>
child.output.size == expected.size && child.output.zip(expected).forall {
case (childAttr, tableAttr) =>
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
childrenResolved && table.resolved && {
expectedColumns match {
case Some(expected) =>
child.output.size == expected.size && child.output.zip(expected).forall {
case (childAttr, tableAttr) =>
childAttr.name == tableAttr.name && // required by some relations
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
case None => true
}
}
}

@@ -178,7 +178,7 @@ class PlanParserSuite extends PlanTest {
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifNotExists: Boolean = false): LogicalPlan =
InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, Map.empty)

// Single inserts
assertEqual(s"insert overwrite table s $sql",
@@ -196,9 +196,11 @@
val plan2 = table("t").where('x > 5).select(star())
assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
InsertIntoTable(
table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false,
Map.empty).union(
InsertIntoTable(
table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false,
Map.empty)))
}

test("aggregation") {
@@ -512,7 +512,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partitions.getOrElse(Map.empty[String, Option[String]]),
df.logicalPlan,
overwrite,
ifNotExists = false)).toRdd
ifNotExists = false,
options = extraOptions.toMap)).toRdd
}

private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
@@ -46,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false)
l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false, _)
if query.resolved && t.schema.asNullable == query.schema.asNullable =>

// Sanity checks
@@ -110,7 +110,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _)
case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _)
if DDLUtils.isDatasourceTable(s.metadata) =>
i.copy(table = readDataSourceTable(sparkSession, s.metadata))

@@ -152,7 +152,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil

case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
part, query, overwrite, false) if part.isEmpty =>
part, query, overwrite, false, _) if part.isEmpty =>
ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil

case _ => Nil
@@ -61,55 +61,6 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo
}
}

/**
* A rule to do pre-insert data type casting and field renaming. Before we insert into
* an [[InsertableRelation]], we will use this rule to make sure that
* the columns to be inserted have the correct data type and fields have the correct names.
*/
private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p

// We are inserting into an InsertableRelation or HadoopFsRelation.
case i @ InsertIntoTable(
l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) =>
// First, make sure the data to be inserted have the same number of fields with the
// schema of the relation.
if (l.output.size != child.output.size) {
sys.error(
s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " +
s"statement generates the same number of columns as its schema.")
}
castAndRenameChildOutput(i, l.output, child)
}

/** If necessary, cast data types and rename fields to the expected types and names. */
def castAndRenameChildOutput(
insertInto: InsertIntoTable,
expectedOutput: Seq[Attribute],
child: LogicalPlan): InsertIntoTable = {
val newChildOutput = expectedOutput.zip(child.output).map {
case (expected, actual) =>
val needCast = !expected.dataType.sameType(actual.dataType)
// We want to make sure the filed names in the data to be inserted exactly match
// names in the schema.
val needRename = expected.name != actual.name
(needCast, needRename) match {
case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
case (false, true) => Alias(actual, expected.name)()
case (_, _) => actual
}
}

if (newChildOutput == child.output) {
insertInto
} else {
insertInto.copy(child = Project(newChildOutput, child))
}
}
}

/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
@@ -122,7 +73,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
plan.foreach {
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation, _, _),
partition, query, overwrite, ifNotExists) =>
partition, query, overwrite, ifNotExists, _) =>
// Right now, we do not support insert into a data source table with partition specs.
if (partition.nonEmpty) {
failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.")
Expand All @@ -140,7 +91,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
}

case logical.InsertIntoTable(
LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) =>
LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _, _) =>
// We need to make sure the partition columns specified by users do match partition
// columns of the relation.
val existingPartitionColumns = r.partitionSchema.fieldNames.toSet
@@ -168,11 +119,11 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
// OK
}

case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) =>
case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _, _) =>
// The relation in l is not an InsertableRelation.
failAnalysis(s"$l does not allow insertion.")

case logical.InsertIntoTable(t, _, _, _, _) =>
case logical.InsertIntoTable(t, _, _, _, _, _) =>
if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) {
failAnalysis(s"Inserting into an RDD-based table is not allowed.")
} else {
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource}
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, ResolveDataSource}
import org.apache.spark.sql.streaming.{ContinuousQuery, ContinuousQueryManager}
import org.apache.spark.sql.util.ExecutionListenerManager

@@ -111,7 +111,6 @@ private[sql] class SessionState(sparkSession: SparkSession) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
PreInsertCastAndRename ::
new FindDataSourceTable(sparkSession) ::
DataSourceAnalysis ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)