diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4fb713b8108c3..b4d159eab4508 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1696,8 +1696,8 @@ class Analyzer(
       // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
       case q: UnaryNode if q.childrenResolved =>
         resolveSubQueries(q, q.children)
-      case d: DeleteFromTable if d.childrenResolved =>
-        resolveSubQueries(d, d.children)
+      case s: SupportsSubquery if s.childrenResolved =>
+        resolveSubQueries(s, s.children)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index db4ed47fa54c6..e053d73c59d46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -593,19 +593,19 @@ trait CheckAnalysis extends PredicateHelper {
             // Only certain operators are allowed to host subquery expression containing
             // outer references.
             plan match {
-              case _: Filter | _: Aggregate | _: Project | _: DeleteFromTable => // Ok
+              case _: Filter | _: Aggregate | _: Project | _: SupportsSubquery => // Ok
               case other => failAnalysis(
                 "Correlated scalar sub-queries can only be used in a " +
-                  s"Filter/Aggregate/Project: $plan")
+                  s"Filter/Aggregate/Project and a few commands: $plan")
             }
           }

       case inSubqueryOrExistsSubquery =>
         plan match {
-          case _: Filter | _: DeleteFromTable => // Ok
+          case _: Filter | _: SupportsSubquery => // Ok
           case _ =>
             failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
-              s" Filter/DeleteFromTable: $plan")
+              s" Filter and a few commands: $plan")
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index 3757569443e74..12c92ced0e09a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -96,8 +96,12 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
       val aliased = tableAlias.map(SubqueryAlias(_, r)).getOrElse(r)
       DeleteFromTable(aliased, condition)

-    case update: UpdateTableStatement =>
-      throw new AnalysisException(s"UPDATE TABLE is not supported temporarily.")
+    case u @ UpdateTableStatement(
+        nameParts @ CatalogAndIdentifierParts(catalog, tableName), _, _, _, _) =>
+      val r = UnresolvedV2Relation(nameParts, catalog.asTableCatalog, tableName.asIdentifier)
+      val aliased = u.tableAlias.map(SubqueryAlias(_, r)).getOrElse(r)
+      val columns = u.columns.map(UnresolvedAttribute(_))
+      UpdateTable(aliased, columns, u.values, u.condition)

     case DescribeTableStatement(
         nameParts @ NonSessionCatalog(catalog, tableName), partitionSpec, isExtended) =>
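Note on the `ResolveCatalogs` change: instead of rejecting `UpdateTableStatement` outright, the rule now lowers it to the new `UpdateTable` plan (added below in `basicLogicalOperators.scala`), leaving the SET targets as `UnresolvedAttribute`s for the analyzer to resolve against the table schema. For orientation, here is a sketch (not part of the patch) of the parsed statement this rule consumes; it assumes the pre-existing UPDATE grammar in `AstBuilder` and uses the field names after this patch's `attrs` → `columns` rename:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.sql.UpdateTableStatement

// What the parser hands to ResolveCatalogs for a typical UPDATE.
val stmt = CatalystSqlParser
  .parsePlan("UPDATE testcat.tbl AS t SET name='Robert', age=32 WHERE p=1")
  .asInstanceOf[UpdateTableStatement]

assert(stmt.tableName == Seq("testcat", "tbl"))      // multi-part table name
assert(stmt.tableAlias == Some("t"))
assert(stmt.columns == Seq(Seq("name"), Seq("age"))) // multi-part column names
```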
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 4793b5942a79e..f03174babcd9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -233,6 +233,16 @@ object IntegerLiteral {
   }
 }

+/**
+ * Extractor for retrieving String literals.
+ */
+object StringLiteral {
+  def unapply(a: Any): Option[String] = a match {
+    case Literal(s: UTF8String, StringType) => Some(s.toString)
+    case _ => None
+  }
+}
+
 /**
  * Extractor for and other utility methods for decimal literals.
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c0d53104874ab..d66371dd89e2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -599,10 +599,17 @@ case class DescribeTable(table: NamedRelation, isExtended: Boolean) extends Comm
 }

 case class DeleteFromTable(
-    child: LogicalPlan,
-    condition: Option[Expression]) extends Command {
+    table: LogicalPlan,
+    condition: Option[Expression]) extends Command with SupportsSubquery {
+  override def children: Seq[LogicalPlan] = table :: Nil
+}

-  override def children: Seq[LogicalPlan] = child :: Nil
+case class UpdateTable(
+    table: LogicalPlan,
+    columns: Seq[Expression],
+    values: Seq[Expression],
+    condition: Option[Expression]) extends Command with SupportsSubquery {
+  override def children: Seq[LogicalPlan] = table :: Nil
 }

 /**
@@ -1241,6 +1248,12 @@ case class Deduplicate(
   override def output: Seq[Attribute] = child.output
 }

+/**
+ * A trait to represent the commands that support subqueries.
+ * This is used to whitelist such commands in the subquery-related checks.
+ */
+trait SupportsSubquery extends LogicalPlan
+
 /**
  * A trait used for logical plan nodes that create or replace V2 table definitions.
  */
 trait V2CreateTablePlan extends LogicalPlan {
   def tableName: Identifier
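The new `UpdateTable` node and the `StringLiteral`/`IntegerLiteral` extractors compose as one would hope. A minimal, self-contained sketch (not code from the patch; `LocalRelation()` merely stands in for the resolved target relation):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{IntegerLiteral, Literal, StringLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, UpdateTable}

// Build the node by hand, roughly what `UPDATE t SET name='Robert', age=32`
// resolves to before column resolution.
val update = UpdateTable(
  LocalRelation(),  // placeholder for the target relation
  Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
  Seq(Literal("Robert"), Literal(32)),
  None)

// Destructure it with the literal extractors; Literal("Robert") is backed by a
// UTF8String with StringType, so the new StringLiteral extractor matches it.
update match {
  case UpdateTable(_, Seq(c0: UnresolvedAttribute, c1: UnresolvedAttribute),
      Seq(StringLiteral(n), IntegerLiteral(a)), None) =>
    assert(c0.name == "name" && c1.name == "age" && n == "Robert" && a == 32)
  case other => sys.error(s"unexpected plan: $other")
}
```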
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala
index 954374c15b932..84b6d3d5d0b83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala
@@ -22,6 +22,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 case class UpdateTableStatement(
     tableName: Seq[String],
     tableAlias: Option[String],
-    attrs: Seq[Seq[String]],
+    columns: Seq[Seq[String]],
     values: Seq[Expression],
     condition: Option[Expression]) extends ParsedStatement
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index f0356f5a42d67..3dabbca9deeee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -532,7 +532,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()),
       LocalRelation(a))
     assertAnalysisError(plan, "Predicate sub-queries can only be used" +
-      " in Filter/DeleteFromTable" :: Nil)
+      " in Filter" :: Nil)
   }

   test("PredicateSubQuery is used is a nested condition") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a2f45898d273f..6e43c9b8bd80b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -745,6 +745,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
+      case _: UpdateTable =>
+        throw new UnsupportedOperationException(s"UPDATE TABLE is not supported temporarily.")
       case _ => Nil
     }
   }
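The `SparkStrategies` stub moves the failure point: an UPDATE now fails at planning time rather than analysis time, which is exactly what the reworked `DataSourceV2SQLSuite` test below pins down. A test-style sketch of the resulting behaviour (assumes a `SparkSession` named `spark` with the `testcat` catalog and the `foo` provider registered as in the suite, plus ScalaTest's `intercept`):

```scala
import org.apache.spark.sql.AnalysisException
import org.scalatest.Assertions._

spark.sql("CREATE TABLE testcat.ns.tbl (id bigint, name string, age int, p int) USING foo")

// Unresolved relations and columns now surface as proper analysis errors...
val e1 = intercept[AnalysisException] {
  spark.sql("UPDATE dummy SET name='abc'")
}
assert(e1.getMessage.contains("Table not found"))

// ...while a well-formed UPDATE passes analysis and only hits the planner stub
// (UpdateTable is a Command, so spark.sql() triggers planning eagerly).
val e2 = intercept[UnsupportedOperationException] {
  spark.sql("UPDATE testcat.ns.tbl SET name='abc' WHERE p = 1")
}
assert(e2.getMessage.contains("UPDATE TABLE is not supported temporarily"))
```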
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index ddb8938cea901..d353e6b3f56d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -938,19 +938,19 @@ class DataSourceV2SQLSuite
     val errorMsg = "Found duplicate column(s) in the table definition of `t`"
     Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
           errorMsg
         )
@@ -962,19 +962,19 @@ class DataSourceV2SQLSuite
     val errorMsg = "Found duplicate column(s) in the table definition of `t`"
     Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
           errorMsg
         )
@@ -984,20 +984,20 @@ class DataSourceV2SQLSuite

   test("tableCreation: bucket column names not in table definition") {
     val errorMsg = "Couldn't find column c in"
-    testCreateAnalysisError(
+    assertAnalysisError(
       s"CREATE TABLE tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
       errorMsg
     )
-    testCreateAnalysisError(
+    assertAnalysisError(
       s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
       errorMsg
     )
-    testCreateAnalysisError(
+    assertAnalysisError(
       s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source " +
         "CLUSTERED BY (c) INTO 4 BUCKETS",
       errorMsg
     )
-    testCreateAnalysisError(
+    assertAnalysisError(
       s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source " +
         "CLUSTERED BY (c) INTO 4 BUCKETS",
       errorMsg
@@ -1008,19 +1008,19 @@ class DataSourceV2SQLSuite
     val errorMsg = "Found duplicate column(s) in the partitioning"
     Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
           errorMsg
         )
@@ -1032,22 +1032,22 @@ class DataSourceV2SQLSuite
     val errorMsg = "Found duplicate column(s) in the bucket definition"
     Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE t ($c0 INT) USING $v2Source " +
             s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source " +
             s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source " +
             s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
           errorMsg
         )
-        testCreateAnalysisError(
+        assertAnalysisError(
           s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source " +
             s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
           errorMsg
@@ -1120,7 +1120,7 @@ class DataSourceV2SQLSuite
     }
   }

-  test("Update: basic - update all") {
+  test("UPDATE TABLE") {
TABLE") { val t = "testcat.ns1.ns2.tbl" withTable(t) { sql( @@ -1129,23 +1129,29 @@ class DataSourceV2SQLSuite |USING foo |PARTITIONED BY (id, p) """.stripMargin) - sql( - s""" - |INSERT INTO $t - |VALUES (1L, 'Herry', 26, 1), - |(2L, 'Jack', 31, 2), - |(3L, 'Lisa', 28, 3), - |(4L, 'Frank', 33, 3) - """.stripMargin) + + // UPDATE non-existing table + assertAnalysisError( + "UPDATE dummy SET name='abc'", + "Table not found") + + // UPDATE non-existing column + assertAnalysisError( + s"UPDATE $t SET dummy='abc'", + "cannot resolve") + assertAnalysisError( + s"UPDATE $t SET name='abc' WHERE dummy=1", + "cannot resolve") + + // UPDATE is not implemented yet. + val e = intercept[UnsupportedOperationException] { + sql(s"UPDATE $t SET name='Robert', age=32 WHERE p=1") + } + assert(e.getMessage.contains("UPDATE TABLE is not supported temporarily")) } - val errMsg = "UPDATE TABLE is not supported temporarily" - testCreateAnalysisError( - s"UPDATE $t SET name='Robert', age=32", - errMsg - ) } - private def testCreateAnalysisError(sqlStatement: String, expectedError: String): Unit = { + private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = { val errMsg = intercept[AnalysisException] { sql(sqlStatement) }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 104c845bfcc12..0f4fe656dd20a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -25,11 +25,12 @@ import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolveCatalogs, ResolveSessionCatalog, UnresolvedV2Relation} +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolveCatalogs, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedV2Relation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, IntegerLiteral, StringLiteral} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelect, CreateV2Table, DescribeTable, DropTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelect, CreateV2Table, DescribeTable, DropTable, LogicalPlan, SubqueryAlias, UpdateTable} import org.apache.spark.sql.connector.InMemoryTableProvider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.execution.datasources.CreateTable @@ -67,7 +68,9 @@ class PlanResolutionSuite extends AnalysisTest { when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => { invocation.getArgument[Identifier](0).name match { case "v1Table" => - mock(classOf[V1Table]) + val v1Table = mock(classOf[V1Table]) + when(v1Table.schema).thenReturn(new 
+          v1Table
         case "v2Table" =>
           table
         case name =>
@@ -736,11 +739,11 @@ class PlanResolutionSuite extends AnalysisTest {
     // For non-existing tables, we convert it to v2 command with `UnresolvedV2Table`
     parsed4 match {
       case AlterTable(_, _, _: UnresolvedV2Relation, _) => // OK
-      case _ => fail("unexpected plan:\n" + parsed4.treeString)
+      case _ => fail("Expect AlterTable, but got:\n" + parsed4.treeString)
     }
     parsed5 match {
       case AlterTable(_, _, _: UnresolvedV2Relation, _) => // OK
-      case _ => fail("unexpected plan:\n" + parsed5.treeString)
+      case _ => fail("Expect AlterTable, but got:\n" + parsed5.treeString)
     }
   }

@@ -767,7 +770,7 @@ class PlanResolutionSuite extends AnalysisTest {
             TableChange.setProperty("a", "1"),
             TableChange.setProperty("b", "0.1"),
             TableChange.setProperty("c", "true")))
-      case _ => fail("expect AlterTable")
+      case _ => fail("Expect AlterTable, but got:\n" + parsed.treeString)
       }
     }
   }
@@ -788,13 +791,13 @@ class PlanResolutionSuite extends AnalysisTest {
       parsed match {
         case AlterTable(_, _, _: DataSourceV2Relation, changes) =>
           assert(changes == Seq(TableChange.setProperty("location", "new location")))
-        case _ => fail("expect AlterTable")
+        case _ => fail("Expect AlterTable, but got:\n" + parsed.treeString)
       }
     }
   }

-  test("describe table") {
+  test("DESCRIBE TABLE") {
     Seq("v1Table" -> true, "v2Table" -> false, "testcat.tab" -> false).foreach {
       case (tblName, useV1Command) =>
         val sql1 = s"DESC TABLE $tblName"
@@ -811,13 +814,13 @@ class PlanResolutionSuite extends AnalysisTest {
     parsed1 match {
       case DescribeTable(_: DataSourceV2Relation, isExtended) =>
         assert(!isExtended)
-      case _ => fail("expect DescribeTable")
+      case _ => fail("Expect DescribeTable, but got:\n" + parsed1.treeString)
     }

     parsed2 match {
       case DescribeTable(_: DataSourceV2Relation, isExtended) =>
         assert(isExtended)
-      case _ => fail("expect DescribeTable")
+      case _ => fail("Expect DescribeTable, but got:\n" + parsed2.treeString)
     }
   }

@@ -839,5 +842,62 @@ class PlanResolutionSuite extends AnalysisTest {
     assert(parsed4.isInstanceOf[DescribeTableCommand])
   }

+  test("UPDATE TABLE") {
+    Seq("v1Table", "v2Table", "testcat.tab").foreach { tblName =>
+      val sql1 = s"UPDATE $tblName SET name='Robert', age=32"
+      val sql2 = s"UPDATE $tblName AS t SET name='Robert', age=32"
+      val sql3 = s"UPDATE $tblName AS t SET name='Robert', age=32 WHERE p=1"
+
+      val parsed1 = parseAndResolve(sql1)
+      val parsed2 = parseAndResolve(sql2)
+      val parsed3 = parseAndResolve(sql3)
+
+      parsed1 match {
+        case u @ UpdateTable(
+            _: DataSourceV2Relation,
+            Seq(name: UnresolvedAttribute, age: UnresolvedAttribute),
+            Seq(StringLiteral("Robert"), IntegerLiteral(32)),
+            None) =>
+          assert(name.name == "name")
+          assert(age.name == "age")
+
+        case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
+      }
+
+      parsed2 match {
+        case UpdateTable(
+            SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation),
+            Seq(name: UnresolvedAttribute, age: UnresolvedAttribute),
+            Seq(StringLiteral("Robert"), IntegerLiteral(32)),
+            None) =>
+          assert(name.name == "name")
+          assert(age.name == "age")
+
+        case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
+      }
+
+      parsed3 match {
+        case UpdateTable(
+            SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation),
+            Seq(name: UnresolvedAttribute, age: UnresolvedAttribute),
+            Seq(StringLiteral("Robert"), IntegerLiteral(32)),
+            Some(EqualTo(p: UnresolvedAttribute, IntegerLiteral(1)))) =>
+          assert(name.name == "name")
+          assert(age.name == "age")
+          assert(p.name == "p")
"p") + + case _ => fail("Expect UpdateTable, but got:\n" + parsed3.treeString) + } + } + + val sql = "UPDATE non_existing SET id=1" + val parsed = parseAndResolve(sql) + parsed match { + case u: UpdateTable => + assert(u.table.isInstanceOf[UnresolvedV2Relation]) + case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString) + } + } + // TODO: add tests for more commands. }