diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index b64b1d0ab08a..25c6076c14fc 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1655,6 +1655,7 @@ def __init__(self, child: "LogicalPlan") -> None:
         self.mode: Optional[str] = None
         self.sort_cols: List[str] = []
         self.partitioning_cols: List[str] = []
+        self.clustering_cols: List[str] = []
         self.options: Dict[str, Optional[str]] = {}
         self.num_buckets: int = -1
         self.bucket_cols: List[str] = []
@@ -1668,6 +1669,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
         plan.write_operation.source = self.source
         plan.write_operation.sort_column_names.extend(self.sort_cols)
         plan.write_operation.partitioning_columns.extend(self.partitioning_cols)
+        plan.write_operation.clustering_columns.extend(self.clustering_cols)
 
         if self.num_buckets > 0:
             plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols)
@@ -1727,6 +1729,7 @@ def print(self, indent: int = 0) -> str:
             f"mode='{self.mode}' "
             f"sort_cols='{self.sort_cols}' "
             f"partitioning_cols='{self.partitioning_cols}' "
+            f"clustering_cols='{self.clustering_cols}' "
             f"num_buckets='{self.num_buckets}' "
             f"bucket_cols='{self.bucket_cols}' "
             f"options='{self.options}'>"
@@ -1741,6 +1744,7 @@ def _repr_html_(self) -> str:
             f"mode: '{self.mode}' <br>"
" f"sort_cols: '{self.sort_cols}'
" f"partitioning_cols: '{self.partitioning_cols}'
" + f"clustering_cols: '{self.clustering_cols}'
" f"num_buckets: '{self.num_buckets}'
" f"bucket_cols: '{self.bucket_cols}'
" f"options: '{self.options}'
" @@ -1754,6 +1758,7 @@ def __init__(self, child: "LogicalPlan", table_name: str) -> None: self.table_name: Optional[str] = table_name self.provider: Optional[str] = None self.partitioning_columns: List[Column] = [] + self.clustering_columns: List[str] = [] self.options: dict[str, Optional[str]] = {} self.table_properties: dict[str, Optional[str]] = {} self.mode: Optional[str] = None @@ -1771,6 +1776,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command: plan.write_operation_v2.partitioning_columns.extend( [c.to_plan(session) for c in self.partitioning_columns] ) + plan.write_operation_v2.clustering_columns.extend(self.clustering_columns) for k in self.options: if self.options[k] is None: diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 0a86dd22b9de..626b27a9d2ad 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -643,6 +643,23 @@ def sortBy( sortBy.__doc__ = PySparkDataFrameWriter.sortBy.__doc__ + @overload + def clusterBy(self, *cols: str) -> "DataFrameWriter": + ... + + @overload + def clusterBy(self, *cols: List[str]) -> "DataFrameWriter": + ... + + def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter": + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] # type: ignore[assignment] + assert len(cols) > 0, "clusterBy needs one or more clustering columns." + self._write.clustering_cols = cast(List[str], cols) + return self + + clusterBy.__doc__ = PySparkDataFrameWriter.clusterBy.__doc__ + def save( self, path: Optional[str] = None, @@ -900,6 +917,13 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__ + def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2": + self._write.clustering_columns = [col] + self._write.clustering_columns.extend(cols) + return self + + clusterBy.__doc__ = PySparkDataFrameWriterV2.clusterBy.__doc__ + def create(self) -> None: self._write.mode = "create" _, _, ei = self._spark.client.execute_command( diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 32284bd45a7a..6052db1f2905 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -1638,6 +1638,42 @@ def sortBy( ) return self + @overload + def clusterBy(self, *cols: str) -> "DataFrameWriter": + ... + + @overload + def clusterBy(self, *cols: List[str]) -> "DataFrameWriter": + ... + + def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter": + """Clusters the data by the given columns to optimize query performance. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + cols : str or list + name of columns + + Examples + -------- + Write a DataFrame into a Parquet file with clustering. + + >>> import tempfile + >>> with tempfile.TemporaryDirectory(prefix="clusterBy") as d: + ... spark.createDataFrame( + ... [{"age": 100, "name": "Hyukjin Kwon"}, {"age": 120, "name": "Ruifeng Zheng"}] + ... ).write.clusterBy("name").mode("overwrite").format("parquet").save(d) + """ + from pyspark.sql.classic.column import _to_seq + + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] # type: ignore[assignment] + assert len(cols) > 0, "clusterBy needs one or more clustering columns." 
+        self._jwrite = self._jwrite.clusterBy(cols[0], _to_seq(self._spark._sc, cols[1:]))
+        return self
+
     def save(
         self,
         path: Optional[str] = None,
@@ -2397,6 +2433,15 @@ def partitionedBy(self, col: Column, *cols: Column) -> "DataFrameWriterV2":
         self._jwriter.partitionedBy(col, cols)
         return self
 
+    def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2":
+        """
+        Clusters the data by the given columns to optimize query performance.
+        """
+        from pyspark.sql.classic.column import _to_seq
+
+        self._jwriter.clusterBy(col, _to_seq(self._spark._sc, cols))
+        return self
+
     def create(self) -> None:
         """
         Create a new table from the contents of the data frame.
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 8060a9ae8bc7..f4f32dea9060 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -154,6 +154,47 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
         )
         self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
 
+    def test_cluster_by(self):
+        data = [
+            (1, "foo", 3.0),
+            (2, "foo", 5.0),
+            (3, "bar", -1.0),
+            (4, "bar", 6.0),
+        ]
+        df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+        def get_cluster_by_cols(table="pyspark_cluster_by"):
+            cols = self.spark.catalog.listColumns(table)
+            return [c.name for c in cols if c.isCluster]
+
+        table_name = "pyspark_cluster_by"
+        with self.table(table_name):
+            # Test write with one clustering column
+            df.write.clusterBy("x").mode("overwrite").saveAsTable(table_name)
+            self.assertEqual(get_cluster_by_cols(), ["x"])
+            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+            # Test write with two clustering columns
+            df.write.clusterBy("x", "y").mode("overwrite").option(
+                "overwriteSchema", "true"
+            ).saveAsTable(table_name)
+            self.assertEqual(get_cluster_by_cols(), ["x", "y"])
+            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+            # Test write with a list of columns
+            df.write.clusterBy(["y", "z"]).mode("overwrite").option(
+                "overwriteSchema", "true"
+            ).saveAsTable(table_name)
+            self.assertEqual(get_cluster_by_cols(), ["y", "z"])
+            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+            # Test write with a tuple of columns
+            df.write.clusterBy(("x", "z")).mode("overwrite").option(
+                "overwriteSchema", "true"
+            ).saveAsTable(table_name)
+            self.assertEqual(get_cluster_by_cols(), ["x", "z"])
+            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
     def test_insert_into(self):
         df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
         with self.table("test_table"):
@@ -250,6 +291,28 @@ def test_table_overwrite(self):
         with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
             df.writeTo("test_table").overwrite(lit(True))
 
+    def test_cluster_by(self):
+        data = [
+            (1, "foo", 3.0),
+            (2, "foo", 5.0),
+            (3, "bar", -1.0),
+            (4, "bar", 6.0),
+        ]
+        df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+        def get_cluster_by_cols(table="pyspark_cluster_by"):
+            # Note that listColumns only returns top-level clustering columns and doesn't consider
+            # nested clustering columns as isCluster. This is fine for this test.
+            cols = self.spark.catalog.listColumns(table)
+            return [c.name for c in cols if c.isCluster]
+
+        table_name = "pyspark_cluster_by"
+        with self.table(table_name):
+            # Test write with one clustering column
+            df.writeTo(table_name).using("parquet").clusterBy("x").create()
+            self.assertEqual(get_cluster_by_cols(), ["x"])
+            self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
 
 class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
index d96b06789c40..9b9c569ad250 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
@@ -60,6 +60,10 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table {
       partitions += spec.asTransform
     }
 
+    v1Table.clusterBySpec.foreach { spec =>
+      partitions += spec.asTransform
+    }
+
     partitions.toArray
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 30374b847ea5..f4b0c232c25f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -229,10 +229,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite {
     catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
     val newTbl1 = catalog.getTable("db2", "tbl1")
     assert(!tbl1.properties.contains("toh"))
-    // clusteringColumns property is injected during newTable, so we need
-    // to filter it out before comparing the properties.
-    assert(newTbl1.properties.size ==
-      tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
+    assert(newTbl1.properties.size == tbl1.properties.size + 1)
     assert(newTbl1.properties.get("toh") == Some("frem"))
   }
 
@@ -1076,7 +1073,8 @@ abstract class CatalogTestUtils {
   def newTable(
       name: String,
       database: Option[String] = None,
-      defaultColumns: Boolean = false): CatalogTable = {
+      defaultColumns: Boolean = false,
+      clusterBy: Boolean = false): CatalogTable = {
     CatalogTable(
       identifier = TableIdentifier(name, database),
       tableType = CatalogTableType.EXTERNAL,
@@ -1113,10 +1111,14 @@ abstract class CatalogTestUtils {
           .add("b", "string")
       },
       provider = Some(defaultProvider),
-      partitionColumnNames = Seq("a", "b"),
-      bucketSpec = Some(BucketSpec(4, Seq("col1"), Nil)),
-      properties = Map(
-        ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2")))))
+      partitionColumnNames = if (clusterBy) Seq.empty else Seq("a", "b"),
+      bucketSpec = if (clusterBy) None else Some(BucketSpec(4, Seq("col1"), Nil)),
+      properties = if (clusterBy) {
+        Map(
+          ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2"))))
+      } else {
+        Map.empty
+      })
   }
 
   def newView(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index b4a50736b0c7..f5f6fac96872 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -529,10 +529,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
     catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
     val newTbl1 = catalog.getTableRawMetadata(TableIdentifier("tbl1", Some("db2")))
     assert(!tbl1.properties.contains("toh"))
-    // clusteringColumns property is injected during newTable, so we need
-    // to filter it out before comparing the properties.
-    assert(newTbl1.properties.size ==
-      tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
+    assert(newTbl1.properties.size == tbl1.properties.size + 1)
     assert(newTbl1.properties.get("toh") == Some("frem"))
     // Alter table without explicitly specifying database
     catalog.setCurrentDatabase("db2")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 973303f84acc..7c929b5da872 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -67,6 +67,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     sessionCatalog.createTable(utils.newTable(name, db), ignoreIfExists = false)
   }
 
+  private def createClusteredTable(name: String, db: Option[String] = None): Unit = {
+    sessionCatalog.createTable(utils.newTable(name, db, clusterBy = true), ignoreIfExists = false)
+  }
+
   private def createTable(name: String, db: String, catalog: String, source: String,
       schema: StructType, option: Map[String, String], description: String): DataFrame = {
     spark.catalog.createTable(Array(catalog, db, name).mkString("."), source,
@@ -106,9 +110,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
       .map { db => spark.catalog.listColumns(db, tableName) }
       .getOrElse { spark.catalog.listColumns(tableName) }
     assert(tableMetadata.schema.nonEmpty, "bad test")
-    assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
-    assert(tableMetadata.bucketSpec.isDefined, "bad test")
-    assert(tableMetadata.clusterBySpec.isDefined, "bad test")
+    if (tableMetadata.clusterBySpec.isEmpty) {
+      assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
+      assert(tableMetadata.bucketSpec.isDefined, "bad test")
+    }
     assert(columns.collect().map(_.name).toSet == tableMetadata.schema.map(_.name).toSet)
     val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet
     val clusteringColumnNames = tableMetadata.clusterBySpec.map { clusterBySpec =>
@@ -409,6 +414,12 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     testListColumns("tab1", dbName = Some("db1"))
   }
 
+  test("list columns in clustered table") {
+    createDatabase("db1")
+    createClusteredTable("tab1", Some("db1"))
+    testListColumns("tab1", dbName = Some("db1"))
+  }
+
   test("SPARK-39615: qualified name with catalog - listColumns") {
     val answers = Map(
       "col1" -> ("int", true, false, true, false),
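
For reference, a minimal usage sketch of the two Python APIs this patch adds, written only against the behavior exercised by the tests above. The table names and the active `spark` session are illustrative; clustering metadata is only recorded for catalog tables, which is why the sketch uses saveAsTable and writeTo(...).create() rather than a plain path write.

# Usage sketch (illustrative table names; assumes an active SparkSession `spark`).
df = spark.createDataFrame([(1, "foo", 3.0), (2, "bar", 5.0)], ["x", "y", "z"])

# DataFrameWriter.clusterBy accepts varargs, a list, or a tuple of column names.
df.write.clusterBy("x", "y").mode("overwrite").saveAsTable("demo_clustered_v1")

# DataFrameWriterV2.clusterBy takes one or more column names before create().
df.writeTo("demo_clustered_v2").using("parquet").clusterBy("x").create()

# The clustering columns surface through the catalog, as the tests assert via isCluster.
print([c.name for c in spark.catalog.listColumns("demo_clustered_v1") if c.isCluster])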