Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMORO-3436] Fix #{3436} merge into statement cannot find primary keys #3439

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)

def checkConditionIsPrimaryKey(aliasedTable: LogicalPlan, cond: Expression): Unit = {
EliminateSubqueryAliases(aliasedTable) match {
case r @ DataSourceV2Relation(tbl, _, _, _, _) if isMixedFormatRelation(r) =>
case r@DataSourceV2Relation(tbl, _, _, _, _) if isMixedFormatRelation(r) =>
tbl match {
case mixedSparkTable: MixedSparkTable =>
if (mixedSparkTable.table().isKeyedTable) {
val primaryKeys = mixedSparkTable.table().asKeyedTable().primaryKeySpec().fieldNames()
val attributes = aliasedTable.output.filter(p => primaryKeys.contains(p.name))
val condRefs = cond.references.filter(f => attributes.contains(f))
//val condRefs = cond.references.filter(f => attributes.contains(f))
val condRefs = attributes.filter(attr => cond.references.contains(attr))
if (condRefs.isEmpty) {
throw new UnsupportedOperationException(
s"Condition ${cond.references}. is not allowed because is not a primary key")
Expand All @@ -55,12 +56,12 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
}

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case m @ UnresolvedMergeIntoMixedFormatTable(
aliasedTable,
source,
cond,
matchedActions,
notMatchedActions) =>
case m@UnresolvedMergeIntoMixedFormatTable(
aliasedTable,
source,
cond,
matchedActions,
notMatchedActions) =>
checkConditionIsPrimaryKey(aliasedTable, cond)

val resolvedMatchedActions = matchedActions.map {
Expand Down Expand Up @@ -122,9 +123,9 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
}

private def resolveLiteralFunction(
nameParts: Seq[String],
attribute: UnresolvedAttribute,
plan: LogicalPlan): Option[Expression] = {
nameParts: Seq[String],
attribute: UnresolvedAttribute,
plan: LogicalPlan): Option[Expression] = {
if (nameParts.length != 1) return None
val isNamedExpression = plan match {
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute)
Expand All @@ -142,14 +143,14 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
}

def resolveExpressionBottomUp(
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false): Expression = {
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false): Expression = {
if (expr.resolved) return expr
try {
expr transformUp {
case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
case u@UnresolvedAttribute(nameParts) =>
val result =
withPosition(u) {
plan.resolve(nameParts, resolver)
Expand Down Expand Up @@ -182,8 +183,8 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
def resolver: Resolver = conf.resolver

def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan): Expression = {
e: Expression,
q: LogicalPlan): Expression = {
resolveExpression(
e,
resolveColumnByName = nameParts => {
Expand All @@ -197,10 +198,10 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
}

private def resolveExpression(
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
throws: Boolean): Expression = {
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
throws: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = {
if (e.resolved) return e
e match {
Expand All @@ -211,7 +212,7 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
assert(ordinal >= 0 && ordinal < attrCandidates.length)
attrCandidates(ordinal)

case u @ UnresolvedAttribute(nameParts) =>
case u@UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts).map {
case Alias(child, _) if !isTopLevel => child
Expand All @@ -221,7 +222,7 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)
logDebug(s"Resolving $u to $result")
result

case u @ UnresolvedExtractValue(child, fieldName) =>
case u@UnresolvedExtractValue(child, fieldName) =>
val newChild = innerResolve(child, isTopLevel = false)
if (newChild.resolved) {
withOrigin(u.origin) {
Expand All @@ -244,9 +245,9 @@ case class ResolveMergeIntoMixedFormatTableReferences(spark: SparkSession)

// copied from ResolveReferences in Spark
private def resolveAssignments(
assignments: Seq[Assignment],
mergeInto: UnresolvedMergeIntoMixedFormatTable,
resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = {
assignments: Seq[Assignment],
mergeInto: UnresolvedMergeIntoMixedFormatTable,
resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = {
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
Expand Down
Loading