diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index a7027e89e187..e72aaaf06f35 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -471,6 +471,20 @@ ], "sqlState" : "0A000" }, + "CLUSTERING_COLUMNS_MISMATCH" : { + "message" : [ + "Specified clustering does not match that of the existing table <tableName>.", + "Specified clustering columns: [<specifiedClusteringString>].", + "Existing clustering columns: [<existingClusteringString>]." + ], + "sqlState" : "42P10" + }, + "CLUSTERING_NOT_SUPPORTED" : { + "message" : [ + "'<operation>' does not support clustering." + ], + "sqlState" : "42000" + }, "CODEC_NOT_AVAILABLE" : { "message" : [ "The codec <codecName> is not available." diff --git a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 4702f09a14c2..0bb3c3a6ecb8 100644 --- a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3072,6 +3072,11 @@ class SparkConnectPlanner( w.partitionBy(names.toSeq: _*) } + if (writeOperation.getClusteringColumnsCount > 0) { + val names = writeOperation.getClusteringColumnsList.asScala + w.clusterBy(names.head, names.tail.toSeq: _*) + } + if (writeOperation.hasSource) { w.format(writeOperation.getSource) } @@ -3135,6 +3140,11 @@ class SparkConnectPlanner( w.partitionedBy(names.head, names.tail: _*) } + if (writeOperation.getClusteringColumnsCount > 0) { + val names = writeOperation.getClusteringColumnsList.asScala + w.clusterBy(names.head, names.tail.toSeq: _*) + } + writeOperation.getMode match { case proto.WriteOperationV2.Mode.MODE_CREATE => if (writeOperation.hasProvider) { diff --git a/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala b/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala index 3edb63ee8e81..fdbfc39cc9a5 100644 --- a/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -219,6 +219,7 @@ package object dsl { mode: Option[String] = None, sortByColumns: Seq[String] = Seq.empty, partitionByCols: Seq[String] = Seq.empty, + clusterByCols: Seq[String] = Seq.empty, bucketByCols: Seq[String] = Seq.empty, numBuckets: Option[Int] = None): Command = { val writeOp = WriteOperation.newBuilder() @@ -242,6 +243,7 @@ package object dsl { } sortByColumns.foreach(writeOp.addSortColumnNames(_)) partitionByCols.foreach(writeOp.addPartitioningColumns(_)) + clusterByCols.foreach(writeOp.addClusteringColumns(_)) if (numBuckets.nonEmpty && bucketByCols.nonEmpty) { val op = WriteOperation.BucketBy.newBuilder() @@ -272,6 +274,7 @@ package object dsl { options: Map[String, String] = Map.empty, tableProperties: Map[String, String] = Map.empty, partitionByCols: Seq[Expression] = Seq.empty, + clusterByCols: Seq[String] = Seq.empty, mode: Option[String] = None, overwriteCondition: Option[Expression] = None): Command = { val writeOp = WriteOperationV2.newBuilder() @@ -279,6 +282,7 @@ package object dsl { tableName.foreach(writeOp.setTableName) provider.foreach(writeOp.setProvider) partitionByCols.foreach(writeOp.addPartitioningColumns) + clusterByCols.foreach(writeOp.addClusteringColumns) options.foreach { case (k, v) =>
writeOp.putOptions(k, v) } diff --git a/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 6721555220fe..190f8cde16f5 100644 --- a/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -596,6 +596,48 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } } + test("Write with clustering") { + // Cluster by existing column. + withTable("testtable") { + transform( + localRelation.write( + tableName = Some("testtable"), + tableSaveMethod = Some("save_as_table"), + format = Some("parquet"), + clusterByCols = Seq("id"))) + } + + // Cluster by non-existing column. + assertThrows[AnalysisException]( + transform( + localRelation + .write( + tableName = Some("testtable"), + tableSaveMethod = Some("save_as_table"), + format = Some("parquet"), + clusterByCols = Seq("noid")))) + } + + test("Write V2 with clustering") { + // Cluster by existing column. + withTable("testtable") { + transform( + localRelation.writeV2( + tableName = Some("testtable"), + mode = Some("MODE_CREATE"), + clusterByCols = Seq("id"))) + } + + // Cluster by non-existing column. + assertThrows[AnalysisException]( + transform( + localRelation + .writeV2( + tableName = Some("testtable"), + mode = Some("MODE_CREATE"), + clusterByCols = Seq("noid")))) + } + test("Write with invalid bucketBy configuration") { val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0)) assertThrows[InvalidCommandInput] { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 563a9865e73f..616bc8151396 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -201,6 +201,22 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) { this } + /** + * Clusters the output by the given columns on the storage. The rows with matching values in the + * specified clustering columns will be consolidated within the same group. + * + * For instance, if you cluster a dataset by date, the data sharing the same date will be stored + * together in a file. This arrangement improves query efficiency when you apply selective + * filters to these clustering columns, thanks to data skipping. + * + * @since 4.0.0 + */ + @scala.annotation.varargs + def clusterBy(colName: String, colNames: String*): DataFrameWriter[T] = { + this.clusteringColumns = Option(colName +: colNames) + this + } + /** * Saves the content of the `DataFrame` at the specified path. 
* @@ -242,6 +258,7 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) { source.foreach(builder.setSource) sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava)) partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava)) + clusteringColumns.foreach(cols => builder.addAllClusteringColumns(cols.asJava)) numBuckets.foreach(n => { val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder() @@ -515,4 +532,6 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) { private var numBuckets: Option[Int] = None private var sortColumnNames: Option[Seq[String]] = None + + private var clusteringColumns: Option[Seq[String]] = None } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 7107895c0ad2..cb7e1f13bd01 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -41,6 +41,8 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) private var partitioning: Option[Seq[proto.Expression]] = None + private var clustering: Option[Seq[String]] = None + private var overwriteCondition: Option[proto.Expression] = None override def using(provider: String): CreateTableWriter[T] = { @@ -77,6 +79,12 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) this } + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { + this.clustering = Some(colName +: colNames) + this + } + override def create(): Unit = { executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) } @@ -145,6 +153,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) provider.foreach(builder.setProvider) partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava)) + clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava)) options.foreach { case (k, v) => builder.putOptions(k, v) @@ -272,8 +281,22 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * * @since 3.4.0 */ + @scala.annotation.varargs def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] + /** + * Clusters the output by the given columns on the storage. The rows with matching values in the + * specified clustering columns will be consolidated within the same group. + * + * For instance, if you cluster a dataset by date, the data sharing the same date will be stored + * together in a file. This arrangement improves query efficiency when you apply selective + * filters to these clustering columns, thanks to data skipping. + * + * @since 4.0.0 + */ + @scala.annotation.varargs + def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] + /** * Specifies a provider for the underlying output data source. Spark's default catalog supports * "parquet", "json", etc. 
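// Note (not part of the patch): a minimal usage sketch of the clusterBy() API added above,
// assuming a running SparkSession and a catalog that accepts clustered tables; the table and
// column names ("events", "events_v2", "date") are illustrative, not taken from this diff.
import org.apache.spark.sql.SparkSession

object ClusterByUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("clusterBy-usage-sketch").getOrCreate()
    val df = spark.range(10).selectExpr("id", "current_date() AS date")

    // DataFrameWriter (v1 path): the clustering columns are recorded on the created table.
    df.write.format("parquet").clusterBy("date").saveAsTable("events")

    // DataFrameWriterV2 path: clusterBy() participates in CREATE TABLE AS SELECT.
    df.writeTo("events_v2").using("parquet").clusterBy("date").create()

    spark.stop()
  }
}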
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index 9d6f07cf603a..c69cbcf6332e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -85,6 +85,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { .setNumBuckets(2) .addBucketColumnNames("col1") .addBucketColumnNames("col2")) + .addClusteringColumns("col3") val expectedPlan = proto.Plan .newBuilder() @@ -95,6 +96,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { .sortBy("col1") .partitionBy("col99") .bucketBy(2, "col1", "col2") + .clusterBy("col3") .parquet("my/test/path") val actualPlan = service.getAndClearLatestInputPlan() assert(actualPlan.equals(expectedPlan)) @@ -136,6 +138,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { .setTableName("t1") .addPartitioningColumns(col("col99").expr) .setProvider("json") + .addClusteringColumns("col3") .putTableProperties("key", "value") .putOptions("key2", "value2") .setMode(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) @@ -147,6 +150,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { df.writeTo("t1") .partitionedBy(col("col99")) + .clusterBy("col3") .using("json") .tableProperty("key", "value") .options(Map("key2" -> "value2")) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8cf872e4dd0f..c126d12b1473 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -100,7 +100,9 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext#implicits._sqlContext"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits._sqlContext"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.session"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext"), + // SPARK-48761: Add clusterBy() to CreateTableWriter. + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy") ) // Default exclude rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c281b0df8a6d..dcd1d3137da3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -197,10 +197,22 @@ object ClusterBySpec { ret } + /** + * Converts the clustering column property to a ClusterBySpec. + */ def fromProperty(columns: String): ClusterBySpec = { ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_))) } + /** + * Converts a ClusterBySpec to a clustering column property map entry, with validation + * of the column names against the schema. + * + * @param schema the schema of the table. + * @param clusterBySpec the ClusterBySpec to be converted to a property. + * @param resolver the resolver used to match the column names. + * @return a map entry for the clustering column property. 
+ */ def toProperty( schema: StructType, clusterBySpec: ClusterBySpec, @@ -209,10 +221,25 @@ object ClusterBySpec { normalizeClusterBySpec(schema, clusterBySpec, resolver).toJson } + /** + * Converts a ClusterBySpec to a clustering column property map entry, without validating + * the column names against the schema. + * + * @param clusterBySpec existing ClusterBySpec to be converted to properties. + * @return a map entry for the clustering column property. + */ + def toPropertyWithoutValidation(clusterBySpec: ClusterBySpec): (String, String) = { + (CatalogTable.PROP_CLUSTERING_COLUMNS -> clusterBySpec.toJson) + } + private def normalizeClusterBySpec( schema: StructType, clusterBySpec: ClusterBySpec, resolver: Resolver): ClusterBySpec = { + if (schema.isEmpty) { + return clusterBySpec + } + val normalizedColumns = clusterBySpec.columnNames.map { columnName => val position = SchemaUtils.findColumnPosition( columnName.fieldNames().toImmutableArraySeq, schema, resolver) @@ -239,6 +266,10 @@ object ClusterBySpec { val normalizedClusterBySpec = normalizeClusterBySpec(schema, clusterBySpec, resolver) ClusterByTransform(normalizedClusterBySpec.columnNames) } + + def fromColumnNames(names: Seq[String]): ClusterBySpec = { + ClusterBySpec(names.map(FieldReference(_))) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 73a98f9fe4be..75a9fdb1a6be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1866,6 +1866,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "existingBucketString" -> existingBucketString)) } + def mismatchedTableClusteringError( + tableName: String, + specifiedClusteringString: String, + existingClusteringString: String): Throwable = { + new AnalysisException( + errorClass = "CLUSTERING_COLUMNS_MISMATCH", + messageParameters = Map( + "tableName" -> tableName, + "specifiedClusteringString" -> specifiedClusteringString, + "existingClusteringString" -> existingClusteringString)) + } + def specifyPartitionNotAllowedWhenTableSchemaNotDefinedError(): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1165", @@ -4108,4 +4120,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("functionName" -> functionName) ) } + + def operationNotSupportClusteringError(operation: String): Throwable = { + new AnalysisException( + errorClass = "CLUSTERING_NOT_SUPPORTED", + messageParameters = Map("operation" -> operation)) + } + + def clusterByWithPartitionedBy(): Throwable = { + new AnalysisException( + errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", + messageParameters = Map.empty) + } + + def clusterByWithBucketing(): Throwable = { + new AnalysisException( + errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", + messageParameters = Map.empty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 505a5a616920..852e39931626 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ 
-194,6 +194,10 @@ abstract class InMemoryBaseTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } + case ClusterByTransform(columnNames) => + columnNames.map { colName => + extractor(colName.fieldNames, cleanedSchema, row)._1 + } }.toImmutableArraySeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2d6d5f0e8b2b..991487170f17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DDLUtils @@ -193,6 +193,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { @scala.annotation.varargs def partitionBy(colNames: String*): DataFrameWriter[T] = { this.partitioningColumns = Option(colNames) + validatePartitioning() this } @@ -210,6 +211,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = { this.numBuckets = Option(numBuckets) this.bucketColumnNames = Option(colName +: colNames) + validatePartitioning() this } @@ -227,6 +229,23 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * Clusters the output by the given columns on the storage. The rows with matching values in the + * specified clustering columns will be consolidated within the same group. + * + * For instance, if you cluster a dataset by date, the data sharing the same date will be stored + * together in a file. This arrangement improves query efficiency when you apply selective + * filters to these clustering columns, thanks to data skipping. + * + * @since 4.0.0 + */ + @scala.annotation.varargs + def clusterBy(colName: String, colNames: String*): DataFrameWriter[T] = { + this.clusteringColumns = Option(colName +: colNames) + validatePartitioning() + this + } + /** * Saves the content of the `DataFrame` at the specified path. * @@ -377,6 +396,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { DataSourceUtils.PARTITIONING_COLUMNS_KEY -> DataSourceUtils.encodePartitioningColumns(columns)) } + clusteringColumns.foreach { columns => + extraOptions = extraOptions + ( + DataSourceUtils.CLUSTERING_COLUMNS_KEY -> + DataSourceUtils.encodePartitioningColumns(columns)) + } val optionsWithPath = getOptionsWithPath(path) @@ -515,6 +539,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } + private def assertNotClustered(operation: String): Unit = { + if (clusteringColumns.isDefined) { + throw QueryCompilationErrors.operationNotSupportClusteringError(operation) + } + } + /** * Saves the content of the `DataFrame` as the specified table.
* @@ -688,6 +718,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CatalogTableType.MANAGED } + val properties = if (clusteringColumns.isEmpty) { + Map.empty[String, String] + } else { + Map(ClusterBySpec.toPropertyWithoutValidation( + ClusterBySpec.fromColumnNames(clusteringColumns.get))) + } + val tableDesc = CatalogTable( identifier = tableIdent, tableType = tableType, @@ -695,7 +732,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { schema = new StructType, provider = Some(source), partitionColumnNames = partitioningColumns.getOrElse(Nil), - bucketSpec = getBucketSpec) + bucketSpec = getBucketSpec, + properties = properties) runCommand(df.sparkSession)( CreateTable(tableDesc, mode, Some(df.logicalPlan))) @@ -708,7 +746,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { }.getOrElse(Seq.empty[Transform]) val bucketing = getBucketSpec.map(spec => CatalogV2Implicits.BucketSpecHelper(spec).asTransform).toSeq - partitioning ++ bucketing + val clustering = clusteringColumns.map { colNames => + ClusterByTransform(colNames.map(FieldReference(_))) + } + partitioning ++ bucketing ++ clustering } /** @@ -719,11 +760,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val v2Partitions = partitioningAsV2 if (v2Partitions.isEmpty) return require(v2Partitions.sameElements(existingTable.partitioning()), - "The provided partitioning does not match of the table.\n" + + "The provided partitioning or clustering columns do not match the existing table's.\n" + s" - provided: ${v2Partitions.mkString(", ")}\n" + s" - table: ${existingTable.partitioning().mkString(", ")}") } + /** + * Validate that clusterBy is not used with partitionBy or bucketBy. + */ + private def validatePartitioning(): Unit = { + if (clusteringColumns.nonEmpty) { + if (partitioningColumns.nonEmpty) { + throw QueryCompilationErrors.clusterByWithPartitionedBy() + } + if (getBucketSpec.nonEmpty) { + throw QueryCompilationErrors.clusterByWithBucketing() + } + } + } + /** * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the * table already exists in the external database, behavior of this function depends on the @@ -750,6 +805,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") + assertNotClustered("jdbc") // connectionProperties should override settings in extraOptions. 
this.extraOptions ++= connectionProperties.asScala // explicit url and dbtable should override all @@ -917,4 +973,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private var numBuckets: Option[Int] = None private var sortColumnNames: Option[Seq[String]] = None + + private var clusteringColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index b68a13ba2159..df1f8b5c6dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} -import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types.IntegerType @@ -54,6 +54,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) private var partitioning: Option[Seq[Transform]] = None + private var clustering: Option[ClusterByTransform] = None + override def using(provider: String): CreateTableWriter[T] = { this.provider = Some(provider) this @@ -104,9 +106,27 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) } this.partitioning = Some(asTransforms) + validatePartitioning() + this + } + + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { + this.clustering = + Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col)))) + validatePartitioning() this } + /** + * Validate that clusterBy is not used with partitionBy. 
+ */ + private def validatePartitioning(): Unit = { + if (partitioning.nonEmpty && clustering.nonEmpty) { + throw QueryCompilationErrors.clusterByWithPartitionedBy() + } + } + override def create(): Unit = { val tableSpec = UnresolvedTableSpec( properties = properties.toMap, @@ -119,7 +139,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) runCommand( CreateTableAsSelect( UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty), + partitioning.getOrElse(Seq.empty) ++ clustering, logicalPlan, tableSpec, options.toMap, @@ -207,7 +227,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) external = false) runCommand(ReplaceTableAsSelect( UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty), + partitioning.getOrElse(Seq.empty) ++ clustering, logicalPlan, tableSpec, writeOptions = options.toMap, @@ -328,8 +348,22 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * * @since 3.0.0 */ + @scala.annotation.varargs def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] + /** + * Clusters the output by the given columns on the storage. The rows with matching values in + * the specified clustering columns will be consolidated within the same group. + * + * For instance, if you cluster a dataset by date, the data sharing the same date will be stored + * together in a file. This arrangement improves query efficiency when you apply selective + * filters to these clustering columns, thanks to data skipping. + * + * @since 4.0.0 + */ + @scala.annotation.varargs + def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] + /** * Specifies a provider for the underlying output data source. Spark's default catalog supports * "parquet", "json", etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index c80dc8307967..81eadcc263c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -52,6 +52,11 @@ object DataSourceUtils extends PredicateHelper { */ val PARTITION_OVERWRITE_MODE = "partitionOverwriteMode" + /** + * The key to use for storing clusterBy columns as options. + */ + val CLUSTERING_COLUMNS_KEY = "__clustering_columns" + /** * Utility methods for converting partitionBy columns to options and back. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index e4c3cd20dedb..37ccf54d932b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -198,6 +198,18 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical tableName, specifiedBucketString, existingBucketString) } + // Check if the specified clustering columns match the existing table. 
+ val specifiedClusterBySpec = tableDesc.clusterBySpec + val existingClusterBySpec = existingTable.clusterBySpec + if (specifiedClusterBySpec != existingClusterBySpec) { + val specifiedClusteringString = + specifiedClusterBySpec.map(_.toString).getOrElse("") + val existingClusteringString = + existingClusterBySpec.map(_.toString).getOrElse("") + throw QueryCompilationErrors.mismatchedTableClusteringError( + tableName, specifiedClusteringString, existingClusteringString) + } + val newQuery = if (adjustedColumns != query.output) { Project(adjustedColumns, query) } else { @@ -319,7 +331,12 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical } } - table.copy(partitionColumnNames = normalizedPartCols, bucketSpec = normalizedBucketSpec) + val normalizedProperties = table.properties ++ table.clusterBySpec.map { spec => + ClusterBySpec.toProperty(schema, spec, conf.resolver) + } + + table.copy(partitionColumnNames = normalizedPartCols, bucketSpec = normalizedBucketSpec, + properties = normalizedProperties) } private def normalizePartitionColumns(schema: StructType, table: CatalogTable): Seq[String] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 44d47abc93fa..2275d8c21397 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Ove import org.apache.spark.sql.connector.InMemoryV1Provider import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME -import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming.MemoryStream @@ -524,6 +524,18 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo Seq(BucketTransform(LiteralValue(4, IntegerType), Seq(FieldReference("id"))))) } + test("Create: cluster by") { + spark.table("source") + .writeTo("testcat.table_name") + .clusterBy("id") + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(ClusterByTransform(Seq(FieldReference("id"))))) + } + test("Create: fail if table already exists") { spark.sql( "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") @@ -634,6 +646,42 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(replaced.properties === defaultOwnership.asJava) } + test("Replace: clustered table") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + 
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === (Map("provider" -> "foo") ++ defaultOwnership).asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").clusterBy("id").replace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType) + .add("data", StringType) + .add("even_or_odd", StringType)) + assert(replaced.partitioning === Seq(ClusterByTransform(Seq(FieldReference("id"))))) + assert(replaced.properties === defaultOwnership.asJava) + } + test("Replace: fail if table does not exist") { val exc = intercept[CannotReplaceMissingTableException] { spark.table("source").writeTo("testcat.table_name").replace() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableSuiteBase.scala index 02e8a5e68999..c4e9ff93ef85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableSuiteBase.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.types.{BooleanType, MetadataBuilder, StringType, StructType} /** @@ -242,4 +243,54 @@ trait DescribeTableSuiteBase extends QueryTest with DDLCommandTestUtils { Row("# col_name", "data_type", "comment"))) } } + + test("describe a clustered table - dataframe writer v1") { + withNamespaceAndTable("ns", "tbl") { tbl => + val df = spark.range(10).select( + col("id").cast("string").as("col1"), + struct(col("id").cast("int").as("x"), col("id").cast("int").as("y")).as("col2")) + df.write.mode("append").clusterBy("col1", "col2.x").saveAsTable(tbl) + val descriptionDf = sql(s"DESC $tbl") + + descriptionDf.show(false) + assert(descriptionDf.schema.map(field => (field.name, field.dataType)) === Seq( + ("col_name", StringType), + ("data_type", StringType), + ("comment", StringType))) + QueryTest.checkAnswer( + descriptionDf, + Seq( + Row("col1", "string", null), + Row("col2", "struct<x:int,y:int>", null), + Row("# Clustering Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("col2.x", "int", null), + Row("col1", "string", null))) + } + } + + test("describe a clustered table - dataframe writer v2") { + withNamespaceAndTable("ns", "tbl") { tbl => + val df = spark.range(10).select( + col("id").cast("string").as("col1"), + struct(col("id").cast("int").as("x"), col("id").cast("int").as("y")).as("col2")) + df.writeTo(tbl).clusterBy("col1", "col2.x").create() + val descriptionDf = sql(s"DESC $tbl") + + descriptionDf.show(false) + assert(descriptionDf.schema.map(field => (field.name,
field.dataType)) === Seq( + ("col_name", StringType), + ("data_type", StringType), + ("comment", StringType))) + QueryTest.checkAnswer( + descriptionDf, + Seq( + Row("col1", "string", null), + Row("col2", "struct<x:int,y:int>", null), + Row("# Clustering Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("col2.x", "int", null), + Row("col1", "string", null))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 603ee74ce333..377c422b22ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -286,6 +286,75 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with assert(DataSourceUtils.decodePartitioningColumns(partColumns) === Seq("col1", "col2")) } + test("pass clusterBy as options") { + Seq(1).toDF().write + .format("org.apache.spark.sql.test") + .clusterBy("col1", "col2") + .save() + + val clusteringColumns = LastOptions.parameters(DataSourceUtils.CLUSTERING_COLUMNS_KEY) + assert(DataSourceUtils.decodePartitioningColumns(clusteringColumns) === Seq("col1", "col2")) + } + + test("Clustering columns should match when appending to existing data source tables") { + import testImplicits._ + val df = Seq((1, 2, 3)).toDF("a", "b", "c") + withTable("clusteredTable") { + df.write.mode("overwrite").clusterBy("a", "b").saveAsTable("clusteredTable") + // Misses some clustering columns + checkError( + exception = intercept[AnalysisException] { + df.write.mode("append").clusterBy("a").saveAsTable("clusteredTable") + }, + errorClass = "CLUSTERING_COLUMNS_MISMATCH", + parameters = Map( + "tableName" -> "spark_catalog.default.clusteredtable", + "specifiedClusteringString" -> """[["a"]]""", + "existingClusteringString" -> """[["a"],["b"]]""") + ) + // Wrong order + checkError( + exception = intercept[AnalysisException] { + df.write.mode("append").clusterBy("b", "a").saveAsTable("clusteredTable") + }, + errorClass = "CLUSTERING_COLUMNS_MISMATCH", + parameters = Map( + "tableName" -> "spark_catalog.default.clusteredtable", + "specifiedClusteringString" -> """[["b"],["a"]]""", + "existingClusteringString" -> """[["a"],["b"]]""") + ) + // Clustering columns not specified + checkError( + exception = intercept[AnalysisException] { + df.write.mode("append").saveAsTable("clusteredTable") + }, + errorClass = "CLUSTERING_COLUMNS_MISMATCH", + parameters = Map( + "tableName" -> "spark_catalog.default.clusteredtable", + "specifiedClusteringString" -> "", "existingClusteringString" -> """[["a"],["b"]]""") + ) + assert(sql("select * from clusteredTable").collect().length == 1) + // Inserts new data successfully when clustering columns are correctly specified in + // clusterBy(...).
+ Seq((4, 5, 6)).toDF("a", "b", "c") + .write + .mode("append") + .clusterBy("a", "b") + .saveAsTable("clusteredTable") + + Seq((7, 8, 9)).toDF("a", "b", "c") + .write + .mode("append") + .clusterBy("a", "b") + .saveAsTable("clusteredTable") + + checkAnswer( + sql("select a, b, c from clusteredTable"), + Row(1, 2, 3) :: Row(4, 5, 6) :: Row(7, 8, 9) :: Nil + ) + } + } + test ("SPARK-29537: throw exception when user defined a wrong base path") { withTempPath { p => val path = new Path(p.toURI).toString @@ -490,7 +559,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning, bucketBy or sortBy") { + test("check jdbc() does not support partitioning, bucketBy, clusterBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -505,6 +574,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } + w = df.write.clusterBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "clustering").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + w = df.write.sortBy("value") e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("sortBy must be used together with bucketBy").foreach { s =>
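// Note (not part of the patch): a hedged sketch of the mutual-exclusion checks added by
// validatePartitioning() in DataFrameWriter above; clusterBy() cannot be combined with
// partitionBy() or bucketBy(). Table and column names are illustrative only.
import org.apache.spark.sql.{AnalysisException, SparkSession}

object ClusterByValidationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("clusterBy-validation-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1, "a"), (2, "b")).toDF("id", "data")

    // Expected to fail with SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED; the error
    // is raised as soon as the conflicting clusterBy() call is made.
    try df.write.partitionBy("data").clusterBy("id").saveAsTable("t_conflict_partition")
    catch { case e: AnalysisException => println(e.getMessage) }

    // Expected to fail with SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED.
    try df.write.bucketBy(4, "data").clusterBy("id").saveAsTable("t_conflict_bucket")
    catch { case e: AnalysisException => println(e.getMessage) }

    spark.stop()
  }
}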