diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index b64b1d0ab08a..25c6076c14fc 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1655,6 +1655,7 @@ def __init__(self, child: "LogicalPlan") -> None:
self.mode: Optional[str] = None
self.sort_cols: List[str] = []
self.partitioning_cols: List[str] = []
+ self.clustering_cols: List[str] = []
self.options: Dict[str, Optional[str]] = {}
self.num_buckets: int = -1
self.bucket_cols: List[str] = []
@@ -1668,6 +1669,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
plan.write_operation.source = self.source
plan.write_operation.sort_column_names.extend(self.sort_cols)
plan.write_operation.partitioning_columns.extend(self.partitioning_cols)
+ plan.write_operation.clustering_columns.extend(self.clustering_cols)
if self.num_buckets > 0:
plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols)
@@ -1727,6 +1729,7 @@ def print(self, indent: int = 0) -> str:
f"mode='{self.mode}' "
f"sort_cols='{self.sort_cols}' "
f"partitioning_cols='{self.partitioning_cols}' "
+ f"clustering_cols='{self.clustering_cols}' "
f"num_buckets='{self.num_buckets}' "
f"bucket_cols='{self.bucket_cols}' "
f"options='{self.options}'>"
@@ -1741,6 +1744,7 @@ def _repr_html_(self) -> str:
f"mode: '{self.mode}'
"
f"sort_cols: '{self.sort_cols}'
"
f"partitioning_cols: '{self.partitioning_cols}'
"
+ f"clustering_cols: '{self.clustering_cols}'
"
f"num_buckets: '{self.num_buckets}'
"
f"bucket_cols: '{self.bucket_cols}'
"
f"options: '{self.options}'
"
@@ -1754,6 +1758,7 @@ def __init__(self, child: "LogicalPlan", table_name: str) -> None:
self.table_name: Optional[str] = table_name
self.provider: Optional[str] = None
self.partitioning_columns: List[Column] = []
+ self.clustering_columns: List[str] = []
self.options: dict[str, Optional[str]] = {}
self.table_properties: dict[str, Optional[str]] = {}
self.mode: Optional[str] = None
@@ -1771,6 +1776,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
plan.write_operation_v2.partitioning_columns.extend(
[c.to_plan(session) for c in self.partitioning_columns]
)
+ plan.write_operation_v2.clustering_columns.extend(self.clustering_columns)
for k in self.options:
if self.options[k] is None:
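The plan-level changes above only store the new column lists and copy them into the Connect command. A minimal sketch of how the stored `clustering_cols` surface on a `WriteOperation` plan object; building plans by hand like this is for illustration only, and the `Read` relation with its single-argument constructor is an assumption about `pyspark.sql.connect.plan`:

```python
# Illustration only: normally DataFrameWriter populates these fields.
from pyspark.sql.connect.plan import Read, WriteOperation

plan = WriteOperation(Read("demo_table"))  # assumed Read(table_name) constructor
plan.clustering_cols = ["name"]
plan.partitioning_cols = ["age"]

# print() renders the writer settings, now including clustering_cols.
print(plan.print())
```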
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 0a86dd22b9de..626b27a9d2ad 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -643,6 +643,23 @@ def sortBy(
sortBy.__doc__ = PySparkDataFrameWriter.sortBy.__doc__
+ @overload
+ def clusterBy(self, *cols: str) -> "DataFrameWriter":
+ ...
+
+ @overload
+ def clusterBy(self, *cols: List[str]) -> "DataFrameWriter":
+ ...
+
+ def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter":
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ assert len(cols) > 0, "clusterBy needs one or more clustering columns."
+ self._write.clustering_cols = cast(List[str], cols)
+ return self
+
+ clusterBy.__doc__ = PySparkDataFrameWriter.clusterBy.__doc__
+
def save(
self,
path: Optional[str] = None,
@@ -900,6 +917,13 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram
partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__
+ def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2":
+ self._write.clustering_columns = [col]
+ self._write.clustering_columns.extend(cols)
+ return self
+
+ clusterBy.__doc__ = PySparkDataFrameWriterV2.clusterBy.__doc__
+
def create(self) -> None:
self._write.mode = "create"
_, _, ei = self._spark.client.execute_command(
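The Connect `DataFrameWriter.clusterBy` above (like the classic one further below) accepts either individual column names or a single list/tuple of names. A standalone sketch of that normalization; the helper name is invented for illustration:

```python
from typing import List, Tuple, Union

def normalize_cluster_cols(*cols: Union[str, List[str], Tuple[str, ...]]) -> List[str]:
    # A single list or tuple argument is unpacked into individual column names,
    # mirroring the check at the top of clusterBy above.
    if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
        cols = tuple(cols[0])
    assert len(cols) > 0, "clusterBy needs one or more clustering columns."
    return [str(c) for c in cols]

print(normalize_cluster_cols("x", "y"))    # ['x', 'y']
print(normalize_cluster_cols(["y", "z"]))  # ['y', 'z']
print(normalize_cluster_cols(("x", "z")))  # ['x', 'z']
```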
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 32284bd45a7a..6052db1f2905 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -1638,6 +1638,42 @@ def sortBy(
)
return self
+ @overload
+ def clusterBy(self, *cols: str) -> "DataFrameWriter":
+ ...
+
+ @overload
+ def clusterBy(self, *cols: List[str]) -> "DataFrameWriter":
+ ...
+
+ def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter":
+ """Clusters the data by the given columns to optimize query performance.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ cols : str or list
+ name of columns
+
+ Examples
+ --------
+ Write a DataFrame into a Parquet file with clustering.
+
+ >>> import tempfile
+ >>> with tempfile.TemporaryDirectory(prefix="clusterBy") as d:
+ ... spark.createDataFrame(
+ ... [{"age": 100, "name": "Hyukjin Kwon"}, {"age": 120, "name": "Ruifeng Zheng"}]
+ ... ).write.clusterBy("name").mode("overwrite").format("parquet").save(d)
+ """
+ from pyspark.sql.classic.column import _to_seq
+
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ assert len(cols) > 0, "clusterBy needs one or more clustering columns."
+ self._jwrite = self._jwrite.clusterBy(cols[0], _to_seq(self._spark._sc, cols[1:]))
+ return self
+
def save(
self,
path: Optional[str] = None,
@@ -2397,6 +2433,15 @@ def partitionedBy(self, col: Column, *cols: Column) -> "DataFrameWriterV2":
self._jwriter.partitionedBy(col, cols)
return self
+ def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2":
+ """
+ Clusters the data by the given columns to optimize query performance.
+ """
+ from pyspark.sql.classic.column import _to_seq
+
+ self._jwriter.clusterBy(col, _to_seq(self._spark._sc, cols))
+ return self
+
def create(self) -> None:
"""
Create a new table from the contents of the data frame.
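Taken together, the two additions give both writer APIs a clustering hook. A usage sketch, assuming an active `SparkSession` named `spark` and a catalog/data source that supports clustered tables; table names are illustrative:

```python
# Assumes an active SparkSession `spark`; table names are illustrative.
df = spark.createDataFrame(
    [(1, "foo", 3.0), (2, "foo", 5.0), (3, "bar", -1.0)], ["x", "y", "z"]
)

# DataFrameWriter: cluster by one column, several columns, or a list of columns.
df.write.clusterBy("x").mode("overwrite").saveAsTable("clustered_v1")
df.write.clusterBy(["x", "y"]).mode("overwrite").option(
    "overwriteSchema", "true"
).saveAsTable("clustered_v1")

# DataFrameWriterV2: clusterBy joins the builder chain before create().
df.writeTo("clustered_v2").using("parquet").clusterBy("x").create()
```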
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 8060a9ae8bc7..f4f32dea9060 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -154,6 +154,47 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+ def test_cluster_by(self):
+ data = [
+ (1, "foo", 3.0),
+ (2, "foo", 5.0),
+ (3, "bar", -1.0),
+ (4, "bar", 6.0),
+ ]
+ df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+ def get_cluster_by_cols(table="pyspark_cluster_by"):
+ cols = self.spark.catalog.listColumns(table)
+ return [c.name for c in cols if c.isCluster]
+
+ table_name = "pyspark_cluster_by"
+ with self.table(table_name):
+ # Test write with one clustering column
+ df.write.clusterBy("x").mode("overwrite").saveAsTable(table_name)
+ self.assertEqual(get_cluster_by_cols(), ["x"])
+ self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+ # Test write with two clustering columns
+ df.write.clusterBy("x", "y").mode("overwrite").option(
+ "overwriteSchema", "true"
+ ).saveAsTable(table_name)
+ self.assertEqual(get_cluster_by_cols(), ["x", "y"])
+ self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+ # Test write with a list of columns
+ df.write.clusterBy(["y", "z"]).mode("overwrite").option(
+ "overwriteSchema", "true"
+ ).saveAsTable(table_name)
+ self.assertEqual(get_cluster_by_cols(), ["y", "z"])
+ self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
+ # Test write with a tuple of columns
+ df.write.clusterBy(("x", "z")).mode("overwrite").option(
+ "overwriteSchema", "true"
+ ).saveAsTable(table_name)
+ self.assertEqual(get_cluster_by_cols(), ["x", "z"])
+ self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
def test_insert_into(self):
df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
with self.table("test_table"):
@@ -250,6 +291,28 @@ def test_table_overwrite(self):
with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
df.writeTo("test_table").overwrite(lit(True))
+ def test_cluster_by(self):
+ data = [
+ (1, "foo", 3.0),
+ (2, "foo", 5.0),
+ (3, "bar", -1.0),
+ (4, "bar", 6.0),
+ ]
+ df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+ def get_cluster_by_cols(table="pyspark_cluster_by"):
+ # Note that listColumns only marks top-level clustering columns with isCluster;
+ # nested clustering columns are not reported. This is fine for this test.
+ cols = self.spark.catalog.listColumns(table)
+ return [c.name for c in cols if c.isCluster]
+
+ table_name = "pyspark_cluster_by"
+ with self.table(table_name):
+ # Test write with one clustering column
+ df.writeTo(table_name).using("parquet").clusterBy("x").create()
+ self.assertEqual(get_cluster_by_cols(), ["x"])
+ self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))
+
class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
pass
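The test helpers above rely on `Catalog.listColumns` flagging top-level clustering columns via `isCluster`; the same check works interactively on any clustered table. A sketch, reusing the table name from the test:

```python
# Sketch: list the clustering columns of an existing clustered table.
cols = spark.catalog.listColumns("pyspark_cluster_by")
print([c.name for c in cols if c.isCluster])  # e.g. ['x']
```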
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
index d96b06789c40..9b9c569ad250 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala
@@ -60,6 +60,10 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table {
partitions += spec.asTransform
}
+ v1Table.clusterBySpec.foreach { spec =>
+ partitions += spec.asTransform
+ }
+
partitions.toArray
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 30374b847ea5..f4b0c232c25f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -229,10 +229,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite {
catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
val newTbl1 = catalog.getTable("db2", "tbl1")
assert(!tbl1.properties.contains("toh"))
- // clusteringColumns property is injected during newTable, so we need
- // to filter it out before comparing the properties.
- assert(newTbl1.properties.size ==
- tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
+ assert(newTbl1.properties.size == tbl1.properties.size + 1)
assert(newTbl1.properties.get("toh") == Some("frem"))
}
@@ -1076,7 +1073,8 @@ abstract class CatalogTestUtils {
def newTable(
name: String,
database: Option[String] = None,
- defaultColumns: Boolean = false): CatalogTable = {
+ defaultColumns: Boolean = false,
+ clusterBy: Boolean = false): CatalogTable = {
CatalogTable(
identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL,
@@ -1113,10 +1111,14 @@ abstract class CatalogTestUtils {
.add("b", "string")
},
provider = Some(defaultProvider),
- partitionColumnNames = Seq("a", "b"),
- bucketSpec = Some(BucketSpec(4, Seq("col1"), Nil)),
- properties = Map(
- ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2")))))
+ partitionColumnNames = if (clusterBy) Seq.empty else Seq("a", "b"),
+ bucketSpec = if (clusterBy) None else Some(BucketSpec(4, Seq("col1"), Nil)),
+ properties = if (clusterBy) {
+ Map(
+ ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2"))))
+ } else {
+ Map.empty
+ })
}
def newView(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index b4a50736b0c7..f5f6fac96872 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -529,10 +529,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
val newTbl1 = catalog.getTableRawMetadata(TableIdentifier("tbl1", Some("db2")))
assert(!tbl1.properties.contains("toh"))
- // clusteringColumns property is injected during newTable, so we need
- // to filter it out before comparing the properties.
- assert(newTbl1.properties.size ==
- tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
+ assert(newTbl1.properties.size == tbl1.properties.size + 1)
assert(newTbl1.properties.get("toh") == Some("frem"))
// Alter table without explicitly specifying database
catalog.setCurrentDatabase("db2")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 973303f84acc..7c929b5da872 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -67,6 +67,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
sessionCatalog.createTable(utils.newTable(name, db), ignoreIfExists = false)
}
+ private def createClusteredTable(name: String, db: Option[String] = None): Unit = {
+ sessionCatalog.createTable(utils.newTable(name, db, clusterBy = true), ignoreIfExists = false)
+ }
+
private def createTable(name: String, db: String, catalog: String, source: String,
schema: StructType, option: Map[String, String], description: String): DataFrame = {
spark.catalog.createTable(Array(catalog, db, name).mkString("."), source,
@@ -106,9 +110,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
.map { db => spark.catalog.listColumns(db, tableName) }
.getOrElse { spark.catalog.listColumns(tableName) }
assert(tableMetadata.schema.nonEmpty, "bad test")
- assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
- assert(tableMetadata.bucketSpec.isDefined, "bad test")
- assert(tableMetadata.clusterBySpec.isDefined, "bad test")
+ if (tableMetadata.clusterBySpec.isEmpty) {
+ assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
+ assert(tableMetadata.bucketSpec.isDefined, "bad test")
+ }
assert(columns.collect().map(_.name).toSet == tableMetadata.schema.map(_.name).toSet)
val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet
val clusteringColumnNames = tableMetadata.clusterBySpec.map { clusterBySpec =>
@@ -409,6 +414,12 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
testListColumns("tab1", dbName = Some("db1"))
}
+ test("list columns in clustered table") {
+ createDatabase("db1")
+ createClusteredTable("tab1", Some("db1"))
+ testListColumns("tab1", dbName = Some("db1"))
+ }
+
test("SPARK-39615: qualified name with catalog - listColumns") {
val answers = Map(
"col1" -> ("int", true, false, true, false),