6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -1655,6 +1655,7 @@ def __init__(self, child: "LogicalPlan") -> None:
self.mode: Optional[str] = None
self.sort_cols: List[str] = []
self.partitioning_cols: List[str] = []
self.clustering_cols: List[str] = []
self.options: Dict[str, Optional[str]] = {}
self.num_buckets: int = -1
self.bucket_cols: List[str] = []
@@ -1668,6 +1669,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
plan.write_operation.source = self.source
plan.write_operation.sort_column_names.extend(self.sort_cols)
plan.write_operation.partitioning_columns.extend(self.partitioning_cols)
plan.write_operation.clustering_columns.extend(self.clustering_cols)

if self.num_buckets > 0:
plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols)
@@ -1727,6 +1729,7 @@ def print(self, indent: int = 0) -> str:
f"mode='{self.mode}' "
f"sort_cols='{self.sort_cols}' "
f"partitioning_cols='{self.partitioning_cols}' "
f"clustering_cols='{self.clustering_cols}' "
f"num_buckets='{self.num_buckets}' "
f"bucket_cols='{self.bucket_cols}' "
f"options='{self.options}'>"
@@ -1741,6 +1744,7 @@ def _repr_html_(self) -> str:
f"mode: '{self.mode}' <br />"
f"sort_cols: '{self.sort_cols}' <br />"
f"partitioning_cols: '{self.partitioning_cols}' <br />"
f"clustering_cols: '{self.clustering_cols}' <br />"
f"num_buckets: '{self.num_buckets}' <br />"
f"bucket_cols: '{self.bucket_cols}' <br />"
f"options: '{self.options}'<br />"
@@ -1754,6 +1758,7 @@ def __init__(self, child: "LogicalPlan", table_name: str) -> None:
self.table_name: Optional[str] = table_name
self.provider: Optional[str] = None
self.partitioning_columns: List[Column] = []
self.clustering_columns: List[str] = []
self.options: dict[str, Optional[str]] = {}
self.table_properties: dict[str, Optional[str]] = {}
self.mode: Optional[str] = None
@@ -1771,6 +1776,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
plan.write_operation_v2.partitioning_columns.extend(
[c.to_plan(session) for c in self.partitioning_columns]
)
plan.write_operation_v2.clustering_columns.extend(self.clustering_columns)

for k in self.options:
if self.options[k] is None:
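For reference, a minimal sketch of how the new field above reaches the Connect command, assuming these hunks belong to the existing WriteOperation plan node and that a SparkConnectClient session and a child plan are already in scope (both are placeholders, not part of this change):

# Hedged sketch only: the Connect writer stores clustering_cols on the write plan,
# and command() copies it into the proto field shown in the diff above.
write_op = WriteOperation(child_plan)    # plan node defined in python/pyspark/sql/connect/plan.py
write_op.source = "parquet"
write_op.clustering_cols = ["name"]      # attribute added by this change
cmd = write_op.command(session)          # serializes into proto.Command
assert list(cmd.write_operation.clustering_columns) == ["name"]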
24 changes: 24 additions & 0 deletions python/pyspark/sql/connect/readwriter.py
@@ -643,6 +643,23 @@ def sortBy(

sortBy.__doc__ = PySparkDataFrameWriter.sortBy.__doc__

@overload
def clusterBy(self, *cols: str) -> "DataFrameWriter":
...

@overload
def clusterBy(self, *cols: List[str]) -> "DataFrameWriter":
...

def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter":
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0] # type: ignore[assignment]
assert len(cols) > 0, "clusterBy needs one or more clustering columns."
self._write.clustering_cols = cast(List[str], cols)
return self

clusterBy.__doc__ = PySparkDataFrameWriter.clusterBy.__doc__

def save(
self,
path: Optional[str] = None,
@@ -900,6 +917,13 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2":

partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__

def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2":
self._write.clustering_columns = [col]
self._write.clustering_columns.extend(cols)
return self

clusterBy.__doc__ = PySparkDataFrameWriterV2.clusterBy.__doc__

def create(self) -> None:
self._write.mode = "create"
_, _, ei = self._spark.client.execute_command(
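In user code, the Connect writer mirrors the classic DataFrameWriter API added further below: clusterBy accepts either individual column names or a single list/tuple of names, and both forms populate clustering_cols on the write plan. A short usage sketch, with df, the Spark session, and the table name as placeholders:

# Equivalent calls: the varargs form and the list form set the same clustering columns.
df.write.clusterBy("x", "y").format("parquet").mode("overwrite").saveAsTable("tbl")
df.write.clusterBy(["x", "y"]).format("parquet").mode("overwrite").saveAsTable("tbl")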
45 changes: 45 additions & 0 deletions python/pyspark/sql/readwriter.py
@@ -1638,6 +1638,42 @@ def sortBy(
)
return self

@overload
def clusterBy(self, *cols: str) -> "DataFrameWriter":
...

@overload
def clusterBy(self, *cols: List[str]) -> "DataFrameWriter":
...

def clusterBy(self, *cols: Union[str, List[str]]) -> "DataFrameWriter":
"""Clusters the data by the given columns to optimize query performance.

.. versionadded:: 4.0.0

Parameters
----------
cols : str or list
names of the columns to cluster the data by

Examples
--------
Write a DataFrame into a Parquet file with clustering.

>>> import tempfile
>>> with tempfile.TemporaryDirectory(prefix="clusterBy") as d:
... spark.createDataFrame(
... [{"age": 100, "name": "Hyukjin Kwon"}, {"age": 120, "name": "Ruifeng Zheng"}]
... ).write.clusterBy("name").mode("overwrite").format("parquet").save(d)
"""
from pyspark.sql.classic.column import _to_seq

if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0] # type: ignore[assignment]
assert len(cols) > 0, "clusterBy needs one or more clustering columns."
self._jwrite = self._jwrite.clusterBy(cols[0], _to_seq(self._spark._sc, cols[1:]))
return self

def save(
self,
path: Optional[str] = None,
@@ -2397,6 +2433,15 @@ def partitionedBy(self, col: Column, *cols: Column) -> "DataFrameWriterV2":
self._jwriter.partitionedBy(col, cols)
return self

def clusterBy(self, col: str, *cols: str) -> "DataFrameWriterV2":
"""
Clusters the data by the given columns to optimize query performance.
"""
from pyspark.sql.classic.column import _to_seq

self._jwriter.clusterBy(col, _to_seq(self._spark._sc, cols))
return self

def create(self) -> None:
"""
Create a new table from the contents of the data frame.
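One way to confirm that the clustering metadata was applied is the pattern used by the tests below; the table name is a placeholder and Column.isCluster is assumed to be available, as exercised in test_readwriter.py:

# Write a clustered table, then read the clustering columns back from the catalog.
df.write.clusterBy("x", "y").mode("overwrite").saveAsTable("clustered_tbl")
clustered = [c.name for c in spark.catalog.listColumns("clustered_tbl") if c.isCluster]
assert clustered == ["x", "y"]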
63 changes: 63 additions & 0 deletions python/pyspark/sql/tests/test_readwriter.py
@@ -154,6 +154,47 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

def test_cluster_by(self):
data = [
(1, "foo", 3.0),
(2, "foo", 5.0),
(3, "bar", -1.0),
(4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])

def get_cluster_by_cols(table="pyspark_cluster_by"):
cols = self.spark.catalog.listColumns(table)
return [c.name for c in cols if c.isCluster]

table_name = "pyspark_cluster_by"
with self.table(table_name):
# Test write with one clustering column
df.write.clusterBy("x").mode("overwrite").saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

# Test write with two clustering columns
df.write.clusterBy("x", "y").mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x", "y"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

# Test write with a list of columns
df.write.clusterBy(["y", "z"]).mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["y", "z"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

# Test write with a tuple of columns
df.write.clusterBy(("x", "z")).mode("overwrite").option(
"overwriteSchema", "true"
).saveAsTable(table_name)
self.assertEqual(get_cluster_by_cols(), ["x", "z"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))

def test_insert_into(self):
df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
with self.table("test_table"):
@@ -250,6 +291,28 @@ def test_table_overwrite(self):
with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
df.writeTo("test_table").overwrite(lit(True))

def test_cluster_by(self):
data = [
(1, "foo", 3.0),
(2, "foo", 5.0),
(3, "bar", -1.0),
(4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])

def get_cluster_by_cols(table="pyspark_cluster_by"):
# Note that listColumns only returns top-level clustering columns and doesn't consider
# nested clustering columns as isCluster. This is fine for this test.
cols = self.spark.catalog.listColumns(table)
return [c.name for c in cols if c.isCluster]

table_name = "pyspark_cluster_by"
with self.table(table_name):
# Test write with one clustering column
df.writeTo(table_name).using("parquet").clusterBy("x").create()
self.assertEqual(get_cluster_by_cols(), ["x"])
self.assertSetEqual(set(data), set(self.spark.table(table_name).collect()))


class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
pass
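The DataFrameWriterV2 path exercised by the test above follows the same pattern; a minimal sketch, with the table name as a placeholder:

# Create (or replace) a clustered table through the V2 writer and verify via the catalog.
df.writeTo("clustered_v2").using("parquet").clusterBy("x", "y").createOrReplace()
assert [c.name for c in spark.catalog.listColumns("clustered_v2") if c.isCluster] == ["x", "y"]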
@@ -60,6 +60,10 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table {
partitions += spec.asTransform
}

v1Table.clusterBySpec.foreach { spec =>
partitions += spec.asTransform
}

partitions.toArray
}

@@ -229,10 +229,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite {
catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
val newTbl1 = catalog.getTable("db2", "tbl1")
assert(!tbl1.properties.contains("toh"))
// clusteringColumns property is injected during newTable, so we need
// to filter it out before comparing the properties.
assert(newTbl1.properties.size ==
tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
assert(newTbl1.properties.size == tbl1.properties.size + 1)
assert(newTbl1.properties.get("toh") == Some("frem"))
}

@@ -1076,7 +1073,8 @@ abstract class CatalogTestUtils {
def newTable(
name: String,
database: Option[String] = None,
defaultColumns: Boolean = false): CatalogTable = {
defaultColumns: Boolean = false,
clusterBy: Boolean = false): CatalogTable = {
CatalogTable(
identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL,
@@ -1113,10 +1111,14 @@
.add("b", "string")
},
provider = Some(defaultProvider),
partitionColumnNames = Seq("a", "b"),
bucketSpec = Some(BucketSpec(4, Seq("col1"), Nil)),
properties = Map(
ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2")))))
partitionColumnNames = if (clusterBy) Seq.empty else Seq("a", "b"),
bucketSpec = if (clusterBy) None else Some(BucketSpec(4, Seq("col1"), Nil)),
properties = if (clusterBy) {
Map(
ClusterBySpec.toPropertyWithoutValidation(ClusterBySpec.fromColumnNames(Seq("c1", "c2"))))
} else {
Map.empty
})
}

def newView(
@@ -529,10 +529,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem")))
val newTbl1 = catalog.getTableRawMetadata(TableIdentifier("tbl1", Some("db2")))
assert(!tbl1.properties.contains("toh"))
// clusteringColumns property is injected during newTable, so we need
// to filter it out before comparing the properties.
assert(newTbl1.properties.size ==
tbl1.properties.filter { case (key, _) => key != "clusteringColumns" }.size + 1)
assert(newTbl1.properties.size == tbl1.properties.size + 1)
assert(newTbl1.properties.get("toh") == Some("frem"))
// Alter table without explicitly specifying database
catalog.setCurrentDatabase("db2")
@@ -67,6 +67,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
sessionCatalog.createTable(utils.newTable(name, db), ignoreIfExists = false)
}

private def createClusteredTable(name: String, db: Option[String] = None): Unit = {
sessionCatalog.createTable(utils.newTable(name, db, clusterBy = true), ignoreIfExists = false)
}

private def createTable(name: String, db: String, catalog: String, source: String,
schema: StructType, option: Map[String, String], description: String): DataFrame = {
spark.catalog.createTable(Array(catalog, db, name).mkString("."), source,
@@ -106,9 +110,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
.map { db => spark.catalog.listColumns(db, tableName) }
.getOrElse { spark.catalog.listColumns(tableName) }
assert(tableMetadata.schema.nonEmpty, "bad test")
assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
assert(tableMetadata.bucketSpec.isDefined, "bad test")
assert(tableMetadata.clusterBySpec.isDefined, "bad test")
if (tableMetadata.clusterBySpec.isEmpty) {
assert(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
assert(tableMetadata.bucketSpec.isDefined, "bad test")
}
assert(columns.collect().map(_.name).toSet == tableMetadata.schema.map(_.name).toSet)
val bucketColumnNames = tableMetadata.bucketSpec.map(_.bucketColumnNames).getOrElse(Nil).toSet
val clusteringColumnNames = tableMetadata.clusterBySpec.map { clusterBySpec =>
@@ -409,6 +414,12 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
testListColumns("tab1", dbName = Some("db1"))
}

test("list columns in clustered table") {
createDatabase("db1")
createClusteredTable("tab1", Some("db1"))
testListColumns("tab1", dbName = Some("db1"))
}

test("SPARK-39615: qualified name with catalog - listColumns") {
val answers = Map(
"col1" -> ("int", true, false, true, false),