Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ message Aggregate {
message Sort {
Relation input = 1;
repeated SortField sort_fields = 2;
bool is_global = 3;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, `global` is a Python keyword, so the field is named `is_global` instead.


message SortField {
Expression expression = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,45 @@ package object dsl {
.build()
}

/**
 * Builds a sort field for the column named `col`.
 *
 * The column is referenced through an unresolved attribute whose unparsed identifier is
 * `col`; the field is ordered ascending with nulls first.
 *
 * @param col name of the column to sort by
 * @return the populated `Sort.SortField` message
 */
def createDefaultSortField(col: String): Sort.SortField = {
  // Reference the column by name only; it is stored as an unresolved attribute.
  val attribute = Expression.UnresolvedAttribute.newBuilder.setUnparsedIdentifier(col).build()
  val columnExpression = Expression.newBuilder.setUnresolvedAttribute(attribute).build()

  val fieldBuilder = Sort.SortField.newBuilder()
  fieldBuilder.setNulls(Sort.SortNulls.SORT_NULLS_FIRST)
  fieldBuilder.setDirection(Sort.SortDirection.SORT_DIRECTION_ASCENDING)
  fieldBuilder.setExpression(columnExpression)
  fieldBuilder.build()
}

/**
 * Wraps `logicalPlan` in a Sort relation flagged as global (`is_global = true`).
 *
 * Each column name is turned into a sort field by `createDefaultSortField`.
 *
 * @param columns names of the columns to sort by, in order
 * @return a `Relation` whose sort input is `logicalPlan`
 */
def sort(columns: String*): Relation = {
  val sortFields = columns.map(createDefaultSortField).asJava
  val globalSort = Sort
    .newBuilder()
    .setInput(logicalPlan)
    .addAllSortFields(sortFields)
    .setIsGlobal(true)
    .build()
  Relation.newBuilder().setSort(globalSort).build()
}

/**
 * Wraps `logicalPlan` in a Sort relation flagged as non-global (`is_global = false`).
 *
 * Each column name is turned into a sort field by `createDefaultSortField`.
 *
 * @param columns names of the columns to sort by, in order
 * @return a `Relation` whose sort input is `logicalPlan`
 */
def sortWithinPartitions(columns: String*): Relation = {
  val sortFields = columns.map(createDefaultSortField).asJava
  val partitionSort = Sort
    .newBuilder()
    .setInput(logicalPlan)
    .addAllSortFields(sortFields)
    .setIsGlobal(false)
    .build()
  Relation.newBuilder().setSort(partitionSort).build()
}

def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): Relation = {
val agg = Aggregate.newBuilder()
agg.setInput(logicalPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.")
logical.Sort(
child = transformRelation(rel.getInput),
global = true,
global = rel.getIsGlobal,
order = rel.getSortFieldsList.asScala.map(transformSortOrderExpression).toSeq)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression.UnresolvedStar
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.SharedSparkSession

/**
Expand All @@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
*/
trait SparkConnectPlanTest extends SharedSparkSession {

def transform(rel: proto.Relation): LogicalPlan = {
def transform(rel: proto.Relation): logical.LogicalPlan = {
new SparkConnectPlanner(rel, spark).transform()
}

Expand Down Expand Up @@ -149,9 +149,25 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

val res = transform(
proto.Relation.newBuilder
.setSort(proto.Sort.newBuilder.addAllSortFields(Seq(f).asJava).setInput(readRel))
.setSort(
proto.Sort.newBuilder
.addAllSortFields(Seq(f).asJava)
.setInput(readRel)
.setIsGlobal(true))
.build())
assert(res.nodeName == "Sort")
assert(res.asInstanceOf[logical.Sort].global)

val res2 = transform(
proto.Relation.newBuilder
.setSort(
proto.Sort.newBuilder
.addAllSortFields(Seq(f).asJava)
.setInput(readRel)
.setIsGlobal(false))
.build())
assert(res2.nodeName == "Sort")
assert(!res2.asInstanceOf[logical.Sort].global)
}

test("Simple Union") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan, sparkPlan)
}

test("Test sort") {
  // Case 1: sort(...) built through the Connect DSL vs. the Spark relation.
  val connectGlobalSort = connectTestRelation.sort("id", "name")
  val sparkGlobalSort = sparkTestRelation.sort("id", "name")
  comparePlans(connectGlobalSort, sparkGlobalSort)

  // Case 2: sortWithinPartitions(...) built through the Connect DSL vs. the Spark relation.
  val connectPartitionSort = connectTestRelation.sortWithinPartitions("id", "name")
  val sparkPartitionSort = sparkTestRelation.sortWithinPartitions("id", "name")
  comparePlans(connectPartitionSort, sparkPartitionSort)
}

test("column alias") {
val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
Expand Down
10 changes: 9 additions & 1 deletion python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,15 @@ def offset(self, n: int) -> "DataFrame":

def sort(self, *cols: "ColumnOrString") -> "DataFrame":
"""Sort by a specific column"""
return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)
return DataFrame.withPlan(
plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
)

def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
    """Return a DataFrame whose plan sorts by the given columns with is_global=False."""
    # Build the Sort plan node first, then wrap it using the current session.
    partition_sort = plan.Sort(self._plan, columns=list(cols), is_global=False)
    return DataFrame.withPlan(partition_sort, session=self._session)

def sample(
self,
Expand Down
18 changes: 12 additions & 6 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,19 @@ def _repr_html_(self) -> str:

class Sort(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str]
self,
child: Optional["LogicalPlan"],
columns: List[Union[SortOrder, ColumnRef, str]],
is_global: bool,
) -> None:
super().__init__(child)
self.columns = list(columns)
self.columns = columns
self.is_global = is_global

def col_to_sort_field(
self, col: Union[SortOrder, ColumnRef, str], session: Optional["RemoteSparkSession"]
) -> proto.Sort.SortField:
if type(col) is SortOrder:
if isinstance(col, SortOrder):
sf = proto.Sort.SortField()
sf.expression.CopyFrom(col.ref.to_plan(session))
sf.direction = (
Expand All @@ -385,10 +389,10 @@ def col_to_sort_field(
else:
sf = proto.Sort.SortField()
# Check string
if type(col) is ColumnRef:
if isinstance(col, ColumnRef):
sf.expression.CopyFrom(col.to_plan(session))
else:
sf.expression.CopyFrom(self.unresolved_attr(cast(str, col)))
sf.expression.CopyFrom(self.unresolved_attr(col))
sf.direction = proto.Sort.SortDirection.SORT_DIRECTION_ASCENDING
sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST
return sf
Expand All @@ -398,18 +402,20 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
plan = proto.Relation()
plan.sort.input.CopyFrom(self._child.plan(session))
plan.sort.sort_fields.extend([self.col_to_sort_field(x, session) for x in self.columns])
plan.sort.is_global = self.is_global
return plan

def print(self, indent: int = 0) -> str:
c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
return f"{' ' * indent}<Sort columns={self.columns}>\n{c_buf}"
return f"{' ' * indent}<Sort columns={self.columns}, global={self.is_global}>\n{c_buf}"

def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>Sort</b><br />
{", ".join([str(c) for c in self.columns])}
global: {self.is_global} <br />
{self._child_repr_()}
</li>
</uL>
Expand Down
Loading