Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ message Relation {
Offset offset = 13;
Deduplicate deduplicate = 14;
Range range = 15;
SubqueryAlias subquery_alias = 16;

Unknown unknown = 999;
}
Expand All @@ -56,7 +57,6 @@ message Unknown {}
// Common metadata of all relations.
message RelationCommon {
string source_info = 1;
string alias = 2;
}

// Relation that uses a SQL query to generate the output.
Expand Down Expand Up @@ -223,6 +223,7 @@ message Sample {
message Range {
// Optional. Default value = 0
int32 start = 1;
// Required.
int32 end = 2;
// Optional. Default value = 1
Step step = 3;
Expand All @@ -238,3 +239,13 @@ message Range {
int32 num_partitions = 1;
}
}

// Relation alias.
message SubqueryAlias {
// Required. The input relation.
Relation input = 1;
// Required. The alias.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we check whether it is present or not, given that the default value is the empty string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given our discussion on the protocol, this is a required field, so we ask clients to always set it. The server side simply reads the value of this field no matter what it is (either "" or not).

string alias = 2;
// Optional. Qualifier of the alias.
repeated string qualifier = 3;
}
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ package object dsl {
def as(alias: String): Relation = {
Relation
.newBuilder(logicalPlan)
.setCommon(RelationCommon.newBuilder().setAlias(alias))
.setSubqueryAlias(SubqueryAlias.newBuilder().setAlias(alias).setInput(logicalPlan))
.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
Expand Down Expand Up @@ -54,8 +55,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}

rel.getRelTypeCase match {
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
Expand All @@ -66,9 +67,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
transformLocalRelation(rel.getLocalRelation, common)
transformLocalRelation(rel.getLocalRelation)
case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
Expand All @@ -79,6 +82,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
session.sessionState.sqlParser.parsePlan(sql.getQuery)
}

private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
  // Translate a proto SubqueryAlias into the Catalyst SubqueryAlias node.
  // The qualifier list is optional on the wire; only pass it through when the
  // client actually set one, so the single-argument AliasIdentifier form is
  // used for the common unqualified case.
  val qualifiers = alias.getQualifierList.asScala.toSeq
  val identifier =
    if (qualifiers.isEmpty) AliasIdentifier.apply(alias.getAlias)
    else AliasIdentifier.apply(alias.getAlias, qualifiers)
  SubqueryAlias(identifier, transformRelation(alias.getInput))
}

/**
* All fields of [[proto.Sample]] are optional. However, given those are proto primitive types,
* we cannot differentiate if the field is not or set when the field's value equals to the type
Expand Down Expand Up @@ -141,35 +154,21 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}
}

private def transformLocalRelation(
rel: proto.LocalRelation,
common: Option[proto.RelationCommon]): LogicalPlan = {
private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
val relation = new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
logical.SubqueryAlias(identifier = common.get.getAlias, child = relation)
} else {
relation
}
new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
}

private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
}

private def transformReadRel(
rel: proto.Read,
common: Option[proto.RelationCommon]): LogicalPlan = {
private def transformReadRel(rel: proto.Read): LogicalPlan = {
val baseRelation = rel.getReadTypeCase match {
case proto.Read.ReadTypeCase.NAMED_TABLE =>
val multipartIdentifier =
CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
val child = UnresolvedRelation(multipartIdentifier)
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
SubqueryAlias(identifier = common.get.getAlias, child = child)
} else {
child
}
UnresolvedRelation(multipartIdentifier)
case proto.Read.ReadTypeCase.DATA_SOURCE =>
if (rel.getDataSource.getFormat == "") {
throw InvalidPlanInput("DataSource requires a format")
Expand All @@ -193,9 +192,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel)
}

private def transformProject(
rel: proto.Project,
common: Option[proto.RelationCommon]): LogicalPlan = {
private def transformProject(rel: proto.Project): LogicalPlan = {
val baseRel = transformRelation(rel.getInput)
// TODO: support the target field for *.
val projection =
Expand All @@ -204,12 +201,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
} else {
rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
}
val project = logical.Project(projectList = projection.toSeq, child = baseRel)
if (common.nonEmpty && common.get.getAlias.nonEmpty) {
logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
} else {
project
}
logical.Project(projectList = projection.toSeq, child = baseRel)
}

private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
return self.groupBy().agg(exprs)

def alias(self, alias: str) -> "DataFrame":
return DataFrame.withPlan(plan.Project(self._plan).withAlias(alias), session=self._session)
return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)

def approxQuantile(self, col: ColumnRef, probabilities: Any, relativeError: Any) -> "DataFrame":
...
Expand Down
36 changes: 28 additions & 8 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def _verify_expressions(self) -> None:
f"Only Expressions or String can be used for projections: '{c}'."
)

def withAlias(self, alias: str) -> LogicalPlan:
self.alias = alias
return self

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
assert self._child is not None
proj_exprs = []
Expand All @@ -217,14 +213,10 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
proj_exprs.append(exp)
else:
proj_exprs.append(self.unresolved_attr(c))
common = proto.RelationCommon()
if self.alias is not None:
common.alias = self.alias

plan = proto.Relation()
plan.project.input.CopyFrom(self._child.plan(session))
plan.project.expressions.extend(proj_exprs)
plan.common.CopyFrom(common)
return plan

def print(self, indent: int = 0) -> str:
Expand Down Expand Up @@ -648,6 +640,34 @@ def _repr_html_(self) -> str:
"""


class SubqueryAlias(LogicalPlan):
    """Alias for a relation.

    Wraps a child plan and assigns it a (subquery) alias, mirroring the
    `SubqueryAlias` message in the Spark Connect proto.
    """

    def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
        super().__init__(child)
        # The alias name to attach to the child relation. The proto declares
        # this field as required, so clients must always set it.
        self._alias = alias

    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
        # The proto declares `input` as required: without copying the child
        # plan in, the aliased relation would be silently dropped server-side.
        assert self._child is not None
        rel = proto.Relation()
        rel.subquery_alias.input.CopyFrom(self._child.plan(session))
        rel.subquery_alias.alias = self._alias
        return rel

    def print(self, indent: int = 0) -> str:
        # Render this node, then the child subtree one indent level deeper.
        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
        return f"{' ' * indent}<SubqueryAlias alias={self._alias}>\n{c_buf}"

    def _repr_html_(self) -> str:
        return f"""
        <ul>
           <li>
              <b>SubqueryAlias</b><br />
              Child: {self._child_repr_()}
              Alias: {self._alias}
           </li>
        </ul>
        """


class SQL(LogicalPlan):
def __init__(self, query: str) -> None:
super().__init__(None)
Expand Down
Loading