[SPARK-51834][SQL] Support end-to-end table constraint management #50631

Closed · 19 commits
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4070,6 +4070,12 @@
],
"sqlState" : "HV091"
},
"NON_DETERMINISTIC_CHECK_CONSTRAINT" : {
"message" : [
"The check constraint `<checkCondition>` is non-deterministic. Check constraints must only contain deterministic expressions."
],
"sqlState" : "42621"
Contributor: The error code seems consistent with DB2 and what we use for generated columns, +1.

},
"NON_FOLDABLE_ARGUMENT" : {
"message" : [
"The function <funcName> requires the parameter <paramName> to be a foldable expression of the type <paramType>, but the actual argument is a non-foldable."
@@ -1138,6 +1138,12 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case RenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
checkColumnNotExists("rename", col.path :+ newName, table.schema)

case AddConstraint(_: ResolvedTable, check: CheckConstraint) if !check.deterministic =>
check.child.failAnalysis(
errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
messageParameters = Map("checkCondition" -> check.condition)
)

case AlterColumns(table: ResolvedTable, specs) =>
val groupedColumns = specs.groupBy(_.column.name)
groupedColumns.collect {
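To make the new rule concrete, here is a minimal sketch of the kind of statement it rejects. The catalog, table, and constraint names are made up, the constraint syntax is assumed, and a SparkSession named `spark` is presumed; the CREATE TABLE path receives the same check in ResolveTableSpec later in this diff.

```scala
// rand() is non-deterministic, so adding this constraint should now fail analysis
// with NON_DETERMINISTIC_CHECK_CONSTRAINT (SQLSTATE 42621) instead of being accepted.
spark.sql("ALTER TABLE testcat.events ADD CONSTRAINT sampled CHECK (rand() > 0.5)")
```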
@@ -22,9 +22,11 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces}
import org.apache.spark.sql.catalyst.util.SparkCharVarcharUtils.replaceCharVarcharWithString
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.util.ArrayImplicits._
@@ -77,14 +79,19 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
assertValidSessionVariableNameParts(nameParts, resolved)
d.copy(name = resolved)

// For CREATE TABLE and REPLACE TABLE statements, resolve the table identifier and include
// the table columns as output. This allows expressions (e.g., constraints) referencing these
// columns to be resolved correctly.
case c @ CreateTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
c.copy(name = resolvedIdentifier)

case r @ ReplaceTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
r.copy(name = resolvedIdentifier)

case UnresolvedIdentifier(nameParts, allowTemp) =>
if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
ResolvedIdentifier(FakeSystemCatalog, ident)
} else {
val CatalogAndIdentifier(catalog, identifier) = nameParts
ResolvedIdentifier(catalog, identifier)
}
resolveIdentifier(nameParts, allowTemp, Nil)

case CurrentNamespace =>
ResolvedNamespace(currentCatalog, catalogManager.currentNamespace.toImmutableArraySeq)
Expand All @@ -94,6 +101,27 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
resolveNamespace(catalog, ns, fetchMetadata)
}

private def resolveIdentifier(
nameParts: Seq[String],
allowTemp: Boolean,
columns: Seq[ColumnDefinition]): ResolvedIdentifier = {
val columnOutput = columns.map { col =>
Member Author: @cloud-fan I made this new change to bypass 38c6ef4#diff-583171e935b2dc349378063a5841c5b98b30a2d57ac3743a9eccfe7bffcb8f2aR286. Does this look good to you?

val dataType = if (conf.preserveCharVarcharTypeInfo) {
col.dataType
} else {
replaceCharVarcharWithString(col.dataType)
}
AttributeReference(col.name, dataType, col.nullable, col.metadata)()
}
if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
ResolvedIdentifier(FakeSystemCatalog, ident, columnOutput)
} else {
val CatalogAndIdentifier(catalog, identifier) = nameParts
ResolvedIdentifier(catalog, identifier, columnOutput)
}
}

private def resolveNamespace(
catalog: CatalogPlugin,
ns: Seq[String],
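To make the motivation for the new `resolveIdentifier` concrete: because the declared columns are now carried as the output of `ResolvedIdentifier`, expressions in the table definition can be resolved against them. A sketch with made-up catalog and table names, assumed constraint syntax, and a SparkSession named `spark`:

```scala
// The CHECK expression references columns `a` and `b` of the table being created;
// they resolve because ResolveCatalogs exposes the declared columns as the output
// of the resolved identifier (with char/varchar normalized to string unless the
// preserveCharVarcharTypeInfo conf is enabled).
spark.sql("""
  CREATE TABLE testcat.orders (
    a INT,
    b INT,
    CONSTRAINT a_gt_b CHECK (a > b)
  )
""")
```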
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.expressions._
Contributor: What is the agreement in the community on wildcard imports? Are they permitted after a given number of elements are imported directly?

Member Author: As per https://github.com/databricks/scala-style-guide?tab=readme-ov-file#imports, "Avoid using wildcard imports, unless you are importing more than 6 entities".

import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -61,7 +61,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
input: LogicalPlan,
tableSpec: TableSpecBase,
withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match {
case u: UnresolvedTableSpec if u.optionExpression.resolved =>
case u: UnresolvedTableSpec if u.childrenResolved =>
val newOptions: Seq[(String, String)] = u.optionExpression.options.map {
case (key: String, null) =>
(key, null)
@@ -86,6 +86,18 @@
}
(key, newValue)
}

u.constraints.foreach {
case check: CheckConstraint =>
if (!check.child.deterministic) {
check.child.failAnalysis(
errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
messageParameters = Map("checkCondition" -> check.condition)
)
}
case _ =>
}

val newTableSpec = TableSpec(
properties = u.properties,
provider = u.provider,
@@ -94,7 +106,8 @@
comment = u.comment,
collation = u.collation,
serde = u.serde,
external = u.external)
external = u.external,
constraints = u.constraints.map(_.toV2Constraint))
withNewSpec(newTableSpec)
case _ =>
input
@@ -252,8 +252,13 @@ case class ResolvedNonPersistentFunc(
*/
case class ResolvedIdentifier(
catalog: CatalogPlugin,
identifier: Identifier) extends LeafNodeWithoutStats {
override def output: Seq[Attribute] = Nil
identifier: Identifier,
override val output: Seq[Attribute] = Nil) extends LeafNodeWithoutStats

object ResolvedIdentifier {
def unapply(ri: ResolvedIdentifier): Option[(CatalogPlugin, Identifier)] = {
Some((ri.catalog, ri.identifier))
}
}

// A fake v2 catalog to hold temp views.
@@ -18,11 +18,17 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.UUID

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.types.DataType

trait TableConstraint {
trait TableConstraint extends Expression with Unevaluable {
/** Convert to a data source v2 constraint */
def toV2Constraint: Constraint

/** Returns the user-provided name of the constraint */
def userProvidedName: String
@@ -92,6 +98,10 @@ trait TableConstraint {
)
}
}

override def nullable: Boolean = throw new UnresolvedException("nullable")

override def dataType: DataType = throw new UnresolvedException("dataType")
}

case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean])
@@ -108,10 +118,25 @@ case class CheckConstraint(
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends UnaryExpression
with Unevaluable
with TableConstraint {
// scalastyle:on line.size.limit

def toV2Constraint: Constraint = {
val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull
val enforced = userProvidedCharacteristic.enforced.getOrElse(true)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
// TODO(SPARK-51903): Change the status to VALIDATED when we support validation on ALTER TABLE
val validateStatus = Constraint.ValidationStatus.UNVALIDATED
Constraint
.check(name)
.predicateSql(condition)
.predicate(predicate)
.rely(rely)
.enforced(enforced)
.validationStatus(validateStatus)
.build()
}

override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)

@@ -121,8 +146,6 @@ case class CheckConstraint(

override def sql: String = s"CONSTRAINT $userProvidedName CHECK ($condition)"

override def dataType: DataType = StringType

override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)

override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
@@ -137,9 +160,20 @@ case class PrimaryKeyConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

override def toV2Constraint: Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.primaryKey(name, columns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String = s"${tableName}_pk"

override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
@@ -158,9 +192,20 @@ case class UniqueConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

override def toV2Constraint: Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.unique(name, columns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String = {
s"${tableName}_uniq_$randomSuffix"
}
@@ -183,9 +228,25 @@ case class ForeignKeyConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

override def toV2Constraint: Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.foreignKey(name,
childColumns.map(FieldReference.column).toArray,
parentTableId.asIdentifier,
parentColumns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String =
s"${tableName}_${parentTableId.last}_fk_$randomSuffix"

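A minimal sketch of what `toV2Constraint` yields for a simple CHECK constraint, mirroring the builder calls added above. The constraint name and predicate SQL are made up; the defaults shown are the ones encoded in this file: CHECK is enforced by default, PRIMARY KEY / UNIQUE / FOREIGN KEY are not, and every constraint starts out UNVALIDATED pending SPARK-51903.

```scala
import org.apache.spark.sql.connector.catalog.constraints.Constraint

// DSv2-side result of converting a CHECK constraint on `id > 0` with no
// user-provided characteristic; the optional V2 predicate (built via
// V2ExpressionBuilder above) is omitted from this sketch.
val v2Check = Constraint
  .check("orders_chk_id")   // constraint name (made up for the example)
  .predicateSql("id > 0")   // original SQL text of the condition
  .rely(false)              // RELY not specified, so false
  .enforced(true)           // CHECK defaults to ENFORCED
  .validationStatus(Constraint.ValidationStatus.UNVALIDATED)
  .build()
```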
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, ResolvedTable, UnresolvedException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable}
@@ -295,7 +295,16 @@ case class AlterTableCollation(
case class AddConstraint(
table: LogicalPlan,
tableConstraint: TableConstraint) extends AlterTableCommand {
override def changes: Seq[TableChange] = Seq.empty
override def changes: Seq[TableChange] = {
val constraint = tableConstraint.toV2Constraint
val validatedTableVersion = table match {
case t: ResolvedTable if constraint.enforced() =>
t.table.currentVersion()
Member Author: Created a follow-up https://issues.apache.org/jira/browse/SPARK-51835 for testing the table version.

case _ =>
null
}
Seq(TableChange.addConstraint(constraint, validatedTableVersion))
aokolnychyi (Contributor, Apr 22, 2025): CHECK constraints must optionally validate existing data in ALTER. Am I right this PR doesn't have this? What would be our plan?

Member Author (replying to "must optionally validate"): Makes sense. Do you mean CHECK ... NOT ENFORCED?

Contributor: ENFORCED/NOT ENFORCED impacts subsequent writes. I was referring to ALTER TABLE ... ADD CONSTRAINT, which must scan the existing data.

Member: Just for my understanding. Anton's comment was about how to validate the existing data in ALTER TABLE ... ADD CONSTRAINT. Is it addressed in this PR, @gengliangwang? The above follow-up JIRA (SPARK-51905, "Disallow NOT ENFORCED CHECK constraint") is not about that, is it?

Contributor: Yeah, I think we need one more JIRA to add the scan capability to ALTER TABLE ... ADD CONSTRAINT.

aokolnychyi (Contributor, Apr 29, 2025): Actually, @gengliangwang already created it: SPARK-51903.

Member: Got it. Thank you, @aokolnychyi. Ya, SPARK-51903 is what I expected.

}

protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
}
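As a usage sketch (made-up table and constraint names, assumed ALTER syntax, and a SparkSession named `spark`): adding an enforced constraint now produces a real `TableChange.addConstraint`, carrying the table's current version so that follow-up work (SPARK-51903) can validate existing rows as of that version, and dropping a constraint is forwarded as `TableChange.dropConstraint`.

```scala
// Enforced CHECK: the change records the current table version for later validation.
spark.sql("ALTER TABLE testcat.orders ADD CONSTRAINT positive_qty CHECK (quantity > 0)")

// Dropping now emits TableChange.dropConstraint(name, ifExists, cascade) to the connector.
spark.sql("ALTER TABLE testcat.orders DROP CONSTRAINT positive_qty")
```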
@@ -308,7 +317,8 @@ case class DropConstraint(
name: String,
ifExists: Boolean,
cascade: Boolean) extends AlterTableCommand {
override def changes: Seq[TableChange] = Seq.empty
override def changes: Seq[TableChange] =
Seq(TableChange.dropConstraint(name, ifExists, cascade))

protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
}
@@ -1505,19 +1505,25 @@ case class UnresolvedTableSpec(
serde: Option[SerdeInfo],
external: Boolean,
constraints: Seq[TableConstraint])
extends UnaryExpression with Unevaluable with TableSpecBase {
extends Expression with Unevaluable with TableSpecBase {

override def dataType: DataType =
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113")

override def child: Expression = optionExpression

override protected def withNewChildInternal(newChild: Expression): Expression =
this.copy(optionExpression = newChild.asInstanceOf[OptionList])

override def simpleString(maxFields: Int): String = {
this.copy(properties = Utils.redact(properties).toMap).toString
}

override def nullable: Boolean = true

override def children: Seq[Expression] = optionExpression +: constraints

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
copy(
optionExpression = newChildren.head.asInstanceOf[OptionList],
constraints = newChildren.tail.asInstanceOf[Seq[TableConstraint]])
}
}

/**
@@ -47,6 +47,7 @@ case class CreateTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
.withConstraints(tableSpec.constraints.toArray)
.build()
catalog.createTable(identifier, tableInfo)
} catch {
@@ -52,6 +52,7 @@ case class ReplaceTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
.withConstraints(tableSpec.constraints.toArray)
.build()
catalog.createTable(ident, tableInfo)
Seq.empty