Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,53 @@ case class ResolveDefaultColumns(

// This field stores the enclosing INSERT INTO command, once we find one.
var enclosingInsert: Option[InsertIntoStatement] = None
// This field stores the schema of the target table of the above command.
var insertTableSchemaWithoutPartitionColumns: Option[StructType] = None

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
(_ => SQLConf.get.enableDefaultColumns), ruleId) {
case i@InsertIntoStatement(_, _, _, _, _, _)
if i.query.collectFirst { case u: UnresolvedInlineTable => u }.isDefined =>
enclosingInsert = Some(i)
i
override def apply(plan: LogicalPlan): LogicalPlan = {
// Initialize by clearing our reference to the enclosing INSERT INTO command.
enclosingInsert = None
insertTableSchemaWithoutPartitionColumns = None
// Traverse the logical query plan in preorder (top-down).
plan.resolveOperatorsWithPruning(
(_ => SQLConf.get.enableDefaultColumns), ruleId) {
case i@InsertIntoStatement(_, _, _, _, _, _)
if i.query.collectFirst { case u: UnresolvedInlineTable
if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => u
}.isDefined =>
enclosingInsert = Some(i)
insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i)
regenerated

case table: UnresolvedInlineTable
if enclosingInsert.isDefined =>
val expanded: UnresolvedInlineTable = addMissingDefaultColumnValues(table).getOrElse(table)
val replaced: LogicalPlan =
replaceExplicitDefaultColumnValues(analyzer, expanded).getOrElse(table)
replaced

case i@InsertIntoStatement(_, _, _, project: Project, _, _) =>
enclosingInsert = Some(i)
insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns
val expanded: Project = addMissingDefaultColumnValues(project).getOrElse(project)
val replaced: Option[LogicalPlan] = replaceExplicitDefaultColumnValues(analyzer, expanded)
val updated: InsertIntoStatement =
if (replaced.isDefined) i.copy(query = replaced.get) else i
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(updated)
regenerated
}
}

case table: UnresolvedInlineTable
if enclosingInsert.isDefined &&
table.rows.nonEmpty && table.rows.forall(_.size == table.rows(0).size) =>
val expanded: UnresolvedInlineTable = addMissingDefaultColumnValues(table).getOrElse(table)
replaceExplicitDefaultColumnValues(analyzer, expanded).getOrElse(table)

case i@InsertIntoStatement(_, _, _, project: Project, _, _) =>
enclosingInsert = Some(i)
val expanded: Project = addMissingDefaultColumnValues(project).getOrElse(project)
val replaced: Option[LogicalPlan] = replaceExplicitDefaultColumnValues(analyzer, expanded)
if (replaced.isDefined) i.copy(query = replaced.get) else i
// Helper method to regenerate user-specified columns of an InsertIntoStatement based on the names
// in the insertTableSchemaWithoutPartitionColumns field of this class.
private def regenerateUserSpecifiedCols(i: InsertIntoStatement): InsertIntoStatement = {
if (i.userSpecifiedCols.nonEmpty && insertTableSchemaWithoutPartitionColumns.isDefined) {
i.copy(
userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.get.fields.map(_.name))
} else {
i
}
}

// Helper method to check if an expression is an explicit DEFAULT column reference.
Expand All @@ -87,7 +115,7 @@ case class ResolveDefaultColumns(
table: UnresolvedInlineTable): Option[UnresolvedInlineTable] = {
assert(enclosingInsert.isDefined)
val numQueryOutputs: Int = table.rows(0).size
val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema)
val newNames: Seq[String] = schema.fields.drop(numQueryOutputs).map { _.name }
if (newDefaultExpressions.nonEmpty) {
Expand All @@ -104,7 +132,7 @@ case class ResolveDefaultColumns(
*/
private def addMissingDefaultColumnValues(project: Project): Option[Project] = {
val numQueryOutputs: Int = project.projectList.size
val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema)
val newAliases: Seq[NamedExpression] =
newDefaultExpressions.zip(schema.fields).map {
Expand All @@ -122,24 +150,29 @@ case class ResolveDefaultColumns(
*/
private def getDefaultExpressions(numQueryOutputs: Int, schema: StructType): Seq[Expression] = {
val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
val numDefaultExpressionsToAdd: Int = {
if (SQLConf.get.useNullsForMissingDefaultColumnValues) {
remainingFields.size
} else {
remainingFields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)).size
}
}
val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
}

/**
* This is a helper for the getDefaultExpressions methods above.
*/
private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = {
if (SQLConf.get.useNullsForMissingDefaultColumnValues) {
fields
} else {
fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))
}
}

/**
* Replaces unresolved DEFAULT column references with corresponding values in a logical plan.
*/
private def replaceExplicitDefaultColumnValues(
analyzer: Analyzer,
input: LogicalPlan): Option[LogicalPlan] = {
assert(enclosingInsert.isDefined)
val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
val columnNames: Seq[String] = schema.fields.map { _.name }
val defaultExpressions: Seq[Expression] = schema.fields.map {
case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
Expand Down Expand Up @@ -193,7 +226,7 @@ case class ResolveDefaultColumns(
}

/**
* Replaces unresolved DEFAULT column references with corresponding values in an projection.
* Replaces unresolved DEFAULT column references with corresponding values in a projection.
*/
private def replaceExplicitDefaultColumnValues(
defaultExpressions: Seq[Expression],
Expand Down Expand Up @@ -230,21 +263,48 @@ case class ResolveDefaultColumns(
*/
private def getInsertTableSchemaWithoutPartitionColumns: Option[StructType] = {
assert(enclosingInsert.isDefined)
try {
val tableName = enclosingInsert.get.table match {
case r: UnresolvedRelation => TableIdentifier(r.name)
case r: UnresolvedCatalogRelation => r.tableMeta.identifier
case _ => return None
}
val lookup = catalog.lookupRelation(tableName)
lookup match {
case SubqueryAlias(_, r: UnresolvedCatalogRelation) =>
Some(StructType(r.tableMeta.schema.fields.dropRight(
enclosingInsert.get.partitionSpec.size)))
case _ => None
}
val tableName = enclosingInsert.get.table match {
case r: UnresolvedRelation => TableIdentifier(r.name)
case r: UnresolvedCatalogRelation => r.tableMeta.identifier
case _ => return None
}
// Lookup the relation from the catalog by name. This either succeeds or returns some "not
// found" error. In the latter cases, return out of this rule without changing anything and let
// the analyzer return a proper error message elsewhere.
val lookup: LogicalPlan = try {
catalog.lookupRelation(tableName)
} catch {
case _: NoSuchTableException => None
case _: AnalysisException => return None
}
val schema: StructType = lookup match {
case SubqueryAlias(_, r: UnresolvedCatalogRelation) =>
StructType(r.tableMeta.schema.fields.dropRight(
enclosingInsert.get.partitionSpec.size))
case _ => return None
}
// Rearrange the columns in the result schema to match the order of the explicit column list,
// if any.
val userSpecifiedCols: Seq[String] = enclosingInsert.get.userSpecifiedCols
if (userSpecifiedCols.isEmpty) {
return Some(schema)
}
def normalize(str: String) = {
if (SQLConf.get.caseSensitiveAnalysis) str else str.toLowerCase()
}
val colNamesToFields: Map[String, StructField] =
schema.fields.map {
field: StructField => normalize(field.name) -> field
}.toMap
val userSpecifiedFields: Seq[StructField] =
userSpecifiedCols.map {
name: String => colNamesToFields.getOrElse(normalize(name), return None)
}
val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet
val nonUserSpecifiedFields: Seq[StructField] =
schema.fields.filter {
field => !userSpecifiedColNames.contains(field.name)
}
Some(StructType(userSpecifiedFields ++
getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
}
}
Loading