diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index bf4fb1f031ac6..e519a564d5c2e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -45,6 +45,7 @@ message Relation {
     Offset offset = 13;
     Deduplicate deduplicate = 14;
     Range range = 15;
+    SubqueryAlias subquery_alias = 16;
 
     Unknown unknown = 999;
   }
@@ -56,7 +57,6 @@ message Unknown {}
 // Common metadata of all relations.
 message RelationCommon {
   string source_info = 1;
-  string alias = 2;
 }
 
 // Relation that uses a SQL query to generate the output.
@@ -223,6 +223,7 @@ message Sample {
 message Range {
   // Optional. Default value = 0
   int32 start = 1;
+  // Required.
   int32 end = 2;
   // Optional. Default value = 1
   Step step = 3;
@@ -238,3 +239,13 @@ message Range {
     int32 num_partitions = 1;
   }
 }
+
+// Relation alias.
+message SubqueryAlias {
+  // Required. The input relation.
+  Relation input = 1;
+  // Required. The alias.
+  string alias = 2;
+  // Optional. Qualifier of the alias.
+  repeated string qualifier = 3;
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 4d3df49dc72d2..ae88486400609 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -308,7 +308,7 @@ package object dsl {
     def as(alias: String): Relation = {
       Relation
         .newBuilder(logicalPlan)
-        .setCommon(RelationCommon.newBuilder().setAlias(alias))
+        .setSubqueryAlias(SubqueryAlias.newBuilder().setAlias(alias).setInput(logicalPlan))
         .build()
     }
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 481dbff10ee90..cb04d6eaf2906 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
@@ -54,8 +55,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
 
     rel.getRelTypeCase match {
-      case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common)
-      case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common)
+      case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
+      case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
       case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
       case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
       case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
@@ -66,9 +67,11 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
      case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
      case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
      case proto.Relation.RelTypeCase.LOCAL_RELATION =>
-        transformLocalRelation(rel.getLocalRelation, common)
+        transformLocalRelation(rel.getLocalRelation)
      case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
      case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
+      case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
+        transformSubqueryAlias(rel.getSubqueryAlias)
      case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
        throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
      case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -79,6 +82,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     session.sessionState.sqlParser.parsePlan(sql.getQuery)
   }
 
+  private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
+    val aliasIdentifier =
+      if (alias.getQualifierCount > 0) {
+        AliasIdentifier.apply(alias.getAlias, alias.getQualifierList.asScala.toSeq)
+      } else {
+        AliasIdentifier.apply(alias.getAlias)
+      }
+    SubqueryAlias(aliasIdentifier, transformRelation(alias.getInput))
+  }
+
   /**
    * All fields of [[proto.Sample]] are optional. However, given those are proto primitive types,
    * we cannot differentiate if the field is not or set when the field's value equals to the type
@@ -141,35 +154,21 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
-  private def transformLocalRelation(
-      rel: proto.LocalRelation,
-      common: Option[proto.RelationCommon]): LogicalPlan = {
+  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
     val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
-    val relation = new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
-    if (common.nonEmpty && common.get.getAlias.nonEmpty) {
-      logical.SubqueryAlias(identifier = common.get.getAlias, child = relation)
-    } else {
-      relation
-    }
+    new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
   }
 
   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
     AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
   }
 
-  private def transformReadRel(
-      rel: proto.Read,
-      common: Option[proto.RelationCommon]): LogicalPlan = {
+  private def transformReadRel(rel: proto.Read): LogicalPlan = {
     val baseRelation = rel.getReadTypeCase match {
       case proto.Read.ReadTypeCase.NAMED_TABLE =>
         val multipartIdentifier =
           CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
-        val child = UnresolvedRelation(multipartIdentifier)
-        if (common.nonEmpty && common.get.getAlias.nonEmpty) {
-          SubqueryAlias(identifier = common.get.getAlias, child = child)
-        } else {
-          child
-        }
+        UnresolvedRelation(multipartIdentifier)
       case proto.Read.ReadTypeCase.DATA_SOURCE =>
         if (rel.getDataSource.getFormat == "") {
           throw InvalidPlanInput("DataSource requires a format")
@@ -193,9 +192,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel)
   }
 
-  private def transformProject(
-      rel: proto.Project,
-      common: Option[proto.RelationCommon]): LogicalPlan = {
+  private def transformProject(rel: proto.Project): LogicalPlan = {
     val baseRel = transformRelation(rel.getInput)
     // TODO: support the target field for *.
     val projection =
@@ -204,12 +201,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       } else {
         rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
       }
-    val project = logical.Project(projectList = projection.toSeq, child = baseRel)
-    if (common.nonEmpty && common.get.getAlias.nonEmpty) {
-      logical.SubqueryAlias(identifier = common.get.getAlias, child = project)
-    } else {
-      project
-    }
+    logical.Project(projectList = projection.toSeq, child = baseRel)
   }
 
   private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 1ec105f5afd1f..ea355e150bd2d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -121,7 +121,7 @@ def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
         return self.groupBy().agg(exprs)
 
     def alias(self, alias: str) -> "DataFrame":
-        return DataFrame.withPlan(plan.Project(self._plan).withAlias(alias), session=self._session)
+        return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)
 
     def approxQuantile(self, col: ColumnRef, probabilities: Any, relativeError: Any) -> "DataFrame":
         ...
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 07fecfb47f34d..f919b5156cf0e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -201,10 +201,6 @@ def _verify_expressions(self) -> None:
                     f"Only Expressions or String can be used for projections: '{c}'."
                 )
 
-    def withAlias(self, alias: str) -> LogicalPlan:
-        self.alias = alias
-        return self
-
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         assert self._child is not None
         proj_exprs = []
@@ -217,14 +213,10 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
                 proj_exprs.append(exp)
             else:
                 proj_exprs.append(self.unresolved_attr(c))
-        common = proto.RelationCommon()
-        if self.alias is not None:
-            common.alias = self.alias
 
         plan = proto.Relation()
         plan.project.input.CopyFrom(self._child.plan(session))
         plan.project.expressions.extend(proj_exprs)
-        plan.common.CopyFrom(common)
         return plan
 
     def print(self, indent: int = 0) -> str:
@@ -648,6 +640,36 @@ def _repr_html_(self) -> str:
         """
 
 
+class SubqueryAlias(LogicalPlan):
+    """Alias for a relation."""
+
+    def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
+        super().__init__(child)
+        self._alias = alias
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        assert self._child is not None
+        rel = proto.Relation()
+        # Attach the aliased input; without it the server has no child plan to alias.
+        rel.subquery_alias.input.CopyFrom(self._child.plan(session))
+        rel.subquery_alias.alias = self._alias
+        return rel
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+        return f"{' ' * indent}<SubqueryAlias alias={self._alias}>\n{c_buf}"
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>SubqueryAlias</b><br />
+              Child: {self._child_repr_()}
+           </li>
+        </ul>
+        """
+
+
 class SQL(LogicalPlan):
     def __init__(self, query: str) -> None:
         super().__init__(None)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index bf242a5d7cafc..17b4d6fa8b63d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xc5\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xeb\x02\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 
\x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x15\n\x06is_all\x18\x04 \x01(\x08R\x05isAll\x12\x17\n\x07\x62y_name\x18\x05 \x01(\x08R\x06\x62yName"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xf0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12.\n\x04seed\x18\x05 \x01(\x0b\x32\x1a.spark.connect.Sample.SeedR\x04seed\x1a\x1a\n\x04Seed\x12\x12\n\x04seed\x18\x01 \x01(\x03R\x04seed"\xfd\x01\n\x05Range\x12\x14\n\x05start\x18\x01 \x01(\x05R\x05start\x12\x10\n\x03\x65nd\x18\x02 \x01(\x05R\x03\x65nd\x12-\n\x04step\x18\x03 \x01(\x0b\x32\x19.spark.connect.Range.StepR\x04step\x12I\n\x0enum_partitions\x18\x04 \x01(\x0b\x32".spark.connect.Range.NumPartitionsR\rnumPartitions\x1a\x1a\n\x04Step\x12\x12\n\x04step\x18\x01 \x01(\x05R\x04step\x1a\x36\n\rNumPartitions\x12%\n\x0enum_partitions\x18\x01 \x01(\x05R\rnumPartitionsB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 
\x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xeb\x02\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x15\n\x06is_all\x18\x04 \x01(\x08R\x05isAll\x12\x17\n\x07\x62y_name\x18\x05 
\x01(\x08R\x06\x62yName"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xf0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12.\n\x04seed\x18\x05 \x01(\x0b\x32\x1a.spark.connect.Sample.SeedR\x04seed\x1a\x1a\n\x04Seed\x12\x12\n\x04seed\x18\x01 \x01(\x03R\x04seed"\xfd\x01\n\x05Range\x12\x14\n\x05start\x18\x01 \x01(\x05R\x05start\x12\x10\n\x03\x65nd\x18\x02 \x01(\x05R\x03\x65nd\x12-\n\x04step\x18\x03 \x01(\x0b\x32\x19.spark.connect.Range.StepR\x04step\x12I\n\x0enum_partitions\x18\x04 \x01(\x0b\x32".spark.connect.Range.NumPartitionsR\rnumPartitions\x1a\x1a\n\x04Step\x12\x12\n\x04step\x18\x01 \x01(\x05R\x04step\x1a\x36\n\rNumPartitions\x12%\n\x0enum_partitions\x18\x01 \x01(\x05R\rnumPartitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifierB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,61 +44,63 @@ _READ_DATASOURCE_OPTIONSENTRY._options = None _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 919 - _UNKNOWN._serialized_start = 921 - _UNKNOWN._serialized_end = 930 - _RELATIONCOMMON._serialized_start = 932 - 
_RELATIONCOMMON._serialized_end = 1003
-    _SQL._serialized_start = 1005
-    _SQL._serialized_end = 1032
-    _READ._serialized_start = 1035
-    _READ._serialized_end = 1445
-    _READ_NAMEDTABLE._serialized_start = 1177
-    _READ_NAMEDTABLE._serialized_end = 1238
-    _READ_DATASOURCE._serialized_start = 1241
-    _READ_DATASOURCE._serialized_end = 1432
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1374
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1432
-    _PROJECT._serialized_start = 1447
-    _PROJECT._serialized_end = 1564
-    _FILTER._serialized_start = 1566
-    _FILTER._serialized_end = 1678
-    _JOIN._serialized_start = 1681
-    _JOIN._serialized_end = 2131
-    _JOIN_JOINTYPE._serialized_start = 1944
-    _JOIN_JOINTYPE._serialized_end = 2131
-    _SETOPERATION._serialized_start = 2134
-    _SETOPERATION._serialized_end = 2497
-    _SETOPERATION_SETOPTYPE._serialized_start = 2383
-    _SETOPERATION_SETOPTYPE._serialized_end = 2497
-    _LIMIT._serialized_start = 2499
-    _LIMIT._serialized_end = 2575
-    _OFFSET._serialized_start = 2577
-    _OFFSET._serialized_end = 2656
-    _AGGREGATE._serialized_start = 2659
-    _AGGREGATE._serialized_end = 2984
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2888
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2984
-    _SORT._serialized_start = 2987
-    _SORT._serialized_end = 3489
-    _SORT_SORTFIELD._serialized_start = 3107
-    _SORT_SORTFIELD._serialized_end = 3295
-    _SORT_SORTDIRECTION._serialized_start = 3297
-    _SORT_SORTDIRECTION._serialized_end = 3405
-    _SORT_SORTNULLS._serialized_start = 3407
-    _SORT_SORTNULLS._serialized_end = 3489
-    _DEDUPLICATE._serialized_start = 3492
-    _DEDUPLICATE._serialized_end = 3634
-    _LOCALRELATION._serialized_start = 3636
-    _LOCALRELATION._serialized_end = 3729
-    _SAMPLE._serialized_start = 3732
-    _SAMPLE._serialized_end = 3972
-    _SAMPLE_SEED._serialized_start = 3946
-    _SAMPLE_SEED._serialized_end = 3972
-    _RANGE._serialized_start = 3975
-    _RANGE._serialized_end = 4228
-    _RANGE_STEP._serialized_start = 4146
-    _RANGE_STEP._serialized_end = 4172
-    _RANGE_NUMPARTITIONS._serialized_start = 4174
-    _RANGE_NUMPARTITIONS._serialized_end = 4228
+    _RELATION._serialized_end = 990
+    _UNKNOWN._serialized_start = 992
+    _UNKNOWN._serialized_end = 1001
+    _RELATIONCOMMON._serialized_start = 1003
+    _RELATIONCOMMON._serialized_end = 1052
+    _SQL._serialized_start = 1054
+    _SQL._serialized_end = 1081
+    _READ._serialized_start = 1084
+    _READ._serialized_end = 1494
+    _READ_NAMEDTABLE._serialized_start = 1226
+    _READ_NAMEDTABLE._serialized_end = 1287
+    _READ_DATASOURCE._serialized_start = 1290
+    _READ_DATASOURCE._serialized_end = 1481
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1423
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1481
+    _PROJECT._serialized_start = 1496
+    _PROJECT._serialized_end = 1613
+    _FILTER._serialized_start = 1615
+    _FILTER._serialized_end = 1727
+    _JOIN._serialized_start = 1730
+    _JOIN._serialized_end = 2180
+    _JOIN_JOINTYPE._serialized_start = 1993
+    _JOIN_JOINTYPE._serialized_end = 2180
+    _SETOPERATION._serialized_start = 2183
+    _SETOPERATION._serialized_end = 2546
+    _SETOPERATION_SETOPTYPE._serialized_start = 2432
+    _SETOPERATION_SETOPTYPE._serialized_end = 2546
+    _LIMIT._serialized_start = 2548
+    _LIMIT._serialized_end = 2624
+    _OFFSET._serialized_start = 2626
+    _OFFSET._serialized_end = 2705
+    _AGGREGATE._serialized_start = 2708
+    _AGGREGATE._serialized_end = 3033
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2937
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 3033
+    _SORT._serialized_start = 3036
+    _SORT._serialized_end = 3538
+    _SORT_SORTFIELD._serialized_start = 3156
+    _SORT_SORTFIELD._serialized_end = 3344
+    _SORT_SORTDIRECTION._serialized_start = 3346
+    _SORT_SORTDIRECTION._serialized_end = 3454
+    _SORT_SORTNULLS._serialized_start = 3456
+    _SORT_SORTNULLS._serialized_end = 3538
+    _DEDUPLICATE._serialized_start = 3541
+    _DEDUPLICATE._serialized_end = 3683
+    _LOCALRELATION._serialized_start = 3685
+    _LOCALRELATION._serialized_end = 3778
+    _SAMPLE._serialized_start = 3781
+    _SAMPLE._serialized_end = 4021
+    _SAMPLE_SEED._serialized_start = 3995
+    _SAMPLE_SEED._serialized_end = 4021
+    _RANGE._serialized_start = 4024
+    _RANGE._serialized_end = 4277
+    _RANGE_STEP._serialized_start = 4195
+    _RANGE_STEP._serialized_end = 4221
+    _RANGE_NUMPARTITIONS._serialized_start = 4223
+    _RANGE_NUMPARTITIONS._serialized_end = 4277
+    _SUBQUERYALIAS._serialized_start = 4279
+    _SUBQUERYALIAS._serialized_end = 4393
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7618ed230e750..720a3bc99befb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -74,6 +74,7 @@ class Relation(google.protobuf.message.Message):
     OFFSET_FIELD_NUMBER: builtins.int
     DEDUPLICATE_FIELD_NUMBER: builtins.int
     RANGE_FIELD_NUMBER: builtins.int
+    SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -106,6 +107,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def range(self) -> global___Range: ...
     @property
+    def subquery_alias(self) -> global___SubqueryAlias: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -125,6 +128,7 @@ class Relation(google.protobuf.message.Message):
         offset: global___Offset | None = ...,
         deduplicate: global___Deduplicate | None = ...,
         range: global___Range | None = ...,
+        subquery_alias: global___SubqueryAlias | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -162,6 +166,8 @@ class Relation(google.protobuf.message.Message):
             b"sort",
             "sql",
             b"sql",
+            "subquery_alias",
+            b"subquery_alias",
             "unknown",
             b"unknown",
         ],
@@ -201,6 +207,8 @@ class Relation(google.protobuf.message.Message):
             b"sort",
             "sql",
             b"sql",
+            "subquery_alias",
+            b"subquery_alias",
             "unknown",
             b"unknown",
         ],
@@ -222,6 +230,7 @@ class Relation(google.protobuf.message.Message):
         "offset",
         "deduplicate",
         "range",
+        "subquery_alias",
         "unknown",
     ]
     | None: ...
@@ -244,18 +253,14 @@ class RelationCommon(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     SOURCE_INFO_FIELD_NUMBER: builtins.int
-    ALIAS_FIELD_NUMBER: builtins.int
     source_info: builtins.str
-    alias: builtins.str
     def __init__(
         self,
         *,
         source_info: builtins.str = ...,
-        alias: builtins.str = ...,
     ) -> None: ...
     def ClearField(
-        self,
-        field_name: typing_extensions.Literal["alias", b"alias", "source_info", b"source_info"],
+        self, field_name: typing_extensions.Literal["source_info", b"source_info"]
     ) -> None: ...
 
 global___RelationCommon = RelationCommon
@@ -983,6 +988,7 @@ class Range(google.protobuf.message.Message):
     start: builtins.int
     """Optional. Default value = 0"""
    end: builtins.int
+    """Required."""
    @property
    def step(self) -> global___Range.Step:
        """Optional. Default value = 1"""
@@ -1011,3 +1017,40 @@ class Range(google.protobuf.message.Message):
     ) -> None: ...
 
 global___Range = Range
+
+class SubqueryAlias(google.protobuf.message.Message):
+    """Relation alias."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    ALIAS_FIELD_NUMBER: builtins.int
+    QUALIFIER_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """Required. The input relation."""
+    alias: builtins.str
+    """Required. The alias."""
+    @property
+    def qualifier(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """Optional. Qualifier of the alias."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        alias: builtins.str = ...,
+        qualifier: collections.abc.Iterable[builtins.str] | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "alias", b"alias", "input", b"input", "qualifier", b"qualifier"
+        ],
+    ) -> None: ...
+
+global___SubqueryAlias = SubqueryAlias
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 622340a3ef11b..05fcedd5c1417 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -116,7 +116,7 @@ def test_deduplicate(self):
     def test_relation_alias(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.alias("table_alias")._plan.to_proto(self.connect)
-        self.assertEqual(plan.root.common.alias, "table_alias")
+        self.assertEqual(plan.root.subquery_alias.alias, "table_alias")
 
     def test_datasource_read(self):
         reader = DataFrameReader(self.connect)
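
---

Taken together, the client-side effect of this patch is that DataFrame.alias now emits a dedicated SubqueryAlias relation wrapping its input, instead of stamping an alias string onto RelationCommon; the server then maps that node to Catalyst's SubqueryAlias via an AliasIdentifier. A minimal sketch of the resulting wire shape, following the pattern of test_relation_alias above (hypothetical: `connect` stands in for the RemoteSparkSession-backed session the test harness provides, and assumes a pyspark build shipping pyspark.sql.connect):

    # Hypothetical usage sketch; not part of the patch.
    df = connect.readTable(table_name="my_table")
    plan = df.alias("t")._plan.to_proto(connect)

    # The alias now travels as its own node in the rel_type oneof ...
    assert plan.root.WhichOneof("rel_type") == "subquery_alias"
    assert plan.root.subquery_alias.alias == "t"
    # ... with the aliased input relation nested underneath it.
    assert plan.root.subquery_alias.input.HasField("read")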
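The new message can also be exercised directly against the regenerated Python stubs, which shows how the optional repeated `qualifier` field composes with `input` and `alias`. A self-contained sketch, assuming only that the regenerated protos from this patch are importable:

    from pyspark.sql.connect.proto import relations_pb2 as proto

    rel = proto.Relation()
    # `input` is itself a Relation; here, a named-table read.
    rel.subquery_alias.input.read.named_table.unparsed_identifier = "db.tbl"
    rel.subquery_alias.alias = "t"
    # `qualifier` is repeated, so multi-part qualifiers round-trip as a list.
    rel.subquery_alias.qualifier.extend(["catalog", "db"])

    assert rel.WhichOneof("rel_type") == "subquery_alias"
    assert list(rel.subquery_alias.qualifier) == ["catalog", "db"]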