diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 349e2455be31a..f7feae0e2f012 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -225,9 +225,12 @@ message Expression { // UnresolvedStar is used to expand all the fields of a relation or struct. message UnresolvedStar { - // (Optional) The target of the expansion, either be a table name or struct name, this - // is a list of identifiers that is the path of the expansion. - repeated string target = 1; + + // (Optional) The target of the expansion. + // + // If set, it should end with '.*' and will be parsed by 'parseAttributeName' + // in the server side. + optional string unparsed_target = 1; } // Represents all of the input attributes to a given relational operator, for example in diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 56956e7fff19e..faebd438138fa 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1002,11 +1002,19 @@ class SparkConnectPlanner(session: SparkSession) { session.sessionState.sqlParser.parseExpression(expr.getExpression) } - private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = { - if (regex.getTargetList.isEmpty) { - UnresolvedStar(Option.empty) + private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = { + if (star.hasUnparsedTarget) { + val target = star.getUnparsedTarget + if (!target.endsWith(".*")) { + throw InvalidPlanInput( + 
s"UnresolvedStar requires an unparsed target ending with '.*', " + + s"but got $target.") + } + + UnresolvedStar( + Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2)))) } else { - UnresolvedStar(Some(regex.getTargetList.asScala.toSeq)) + UnresolvedStar(None) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 63e5415b44f9c..d8baa182e5ab5 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -552,7 +552,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { .addExpressions( proto.Expression .newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build()) + .setUnresolvedStar(UnresolvedStar.newBuilder().setUnparsedTarget("a.b.*").build()) .build()) .build() diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 11c0ef6fc06f3..d82862a870b85 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -249,7 +249,7 @@ def _convert_col(col: "ColumnOrName") -> "ColumnOrName": else: return Column(SortOrder(col._expr)) else: - return Column(SortOrder(ColumnReference(name=col))) + return Column(SortOrder(ColumnReference(col))) if isinstance(numPartitions, int): if not numPartitions > 0: @@ -1176,7 +1176,7 @@ def sampleBy( from pyspark.sql.connect.expressions import ColumnReference if isinstance(col, str): - col = Column(ColumnReference(name=col)) + col = Column(ColumnReference(col)) elif not isinstance(col, Column): raise TypeError("col must be a string or a column, but got %r" % type(col)) if not isinstance(fractions, 
dict): diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 6469c1917ec53..c8d361af2a554 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -336,10 +336,10 @@ class ColumnReference(Expression): treat it as an unresolved attribute. Attributes that have the same fully qualified name are identical""" - def __init__(self, name: str) -> None: + def __init__(self, unparsed_identifier: str) -> None: super().__init__() - assert isinstance(name, str) - self._unparsed_identifier = name + assert isinstance(unparsed_identifier, str) + self._unparsed_identifier = unparsed_identifier def name(self) -> str: """Returns the qualified name of the column reference.""" @@ -354,6 +354,43 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: def __repr__(self) -> str: return f"{self._unparsed_identifier}" + def __eq__(self, other: Any) -> bool: + return ( + other is not None + and isinstance(other, ColumnReference) + and other._unparsed_identifier == self._unparsed_identifier + ) + + +class UnresolvedStar(Expression): + def __init__(self, unparsed_target: Optional[str]): + super().__init__() + + if unparsed_target is not None: + assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*") + + self._unparsed_target = unparsed_target + + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": + expr = proto.Expression() + expr.unresolved_star.SetInParent() + if self._unparsed_target is not None: + expr.unresolved_star.unparsed_target = self._unparsed_target + return expr + + def __repr__(self) -> str: + if self._unparsed_target is not None: + return f"unresolvedstar({self._unparsed_target})" + else: + return "unresolvedstar()" + + def __eq__(self, other: Any) -> bool: + return ( + other is not None + and isinstance(other, UnresolvedStar) + and other._unparsed_target == self._unparsed_target + ) + class SQLExpression(Expression): 
"""Returns Expression which contains a string which is a SQL expression @@ -370,6 +407,9 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: expr.expression_string.expression = self._expr return expr + def __eq__(self, other: Any) -> bool: + return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr + class SortOrder(Expression): def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None: diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 5f1eb9c06d786..c73e6ec1ee4b5 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -40,6 +40,7 @@ LiteralExpression, ColumnReference, UnresolvedFunction, + UnresolvedStar, SQLExpression, LambdaFunction, UnresolvedNamedLambdaVariable, @@ -186,7 +187,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column: def col(col: str) -> Column: - return Column(ColumnReference(col)) + if col == "*": + return Column(UnresolvedStar(unparsed_target=None)) + elif col.endswith(".*"): + return Column(UnresolvedStar(unparsed_target=col)) + else: + return Column(ColumnReference(unparsed_identifier=col)) col.__doc__ = pysparkfuncs.col.__doc__ @@ -2389,9 +2395,6 @@ def _test() -> None: # TODO(SPARK-41843): Implement SparkSession.udf del pyspark.sql.connect.functions.call_udf.__doc__ - # TODO(SPARK-41845): Fix count bug - del pyspark.sql.connect.functions.count.__doc__ - globs["spark"] = ( PySparkSession.builder.appName("sql.connect.functions tests") .remote("local[4]") diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 3aa070ff8b6c9..cc728808d3a3b 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -80,14 +80,8 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame": assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - # 
There is a special case for count(*) which is rewritten into count(1). # Convert the dict into key value pairs - aggregate_cols = [ - _invoke_function( - exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k) - ) - for k in exprs[0] - ] + aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]] else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 462384999bb18..87c169641029c 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe8#\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n 
\x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 
\x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xec\x06\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 
\x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicrosecondsB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1a,\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e 
\x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 
\x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xec\x06\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 
\x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicrosecondsB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1aR\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x42\x12\n\x10_unparsed_target\x1a,\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 
\x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -262,7 +262,7 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 4689 + _EXPRESSION._serialized_end = 4731 _EXPRESSION_WINDOW._serialized_start = 1347 _EXPRESSION_WINDOW._serialized_end = 2130 _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637 @@ -292,17 +292,17 @@ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866 _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916 _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3958 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3960 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4004 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4007 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4139 - _EXPRESSION_UPDATEFIELDS._serialized_start = 4142 - _EXPRESSION_UPDATEFIELDS._serialized_end = 4329 - _EXPRESSION_ALIAS._serialized_start = 4331 - _EXPRESSION_ALIAS._serialized_end = 4451 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4454 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4612 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4614 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4676 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181 + _EXPRESSION_UPDATEFIELDS._serialized_start = 4184 + _EXPRESSION_UPDATEFIELDS._serialized_end = 4371 + _EXPRESSION_ALIAS._serialized_start = 4373 + _EXPRESSION_ALIAS._serialized_end = 4493 + 
_EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 5f64159b85409..45889c1518fd2 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -699,22 +699,33 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - TARGET_FIELD_NUMBER: builtins.int - @property - def target( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """(Optional) The target of the expansion, either be a table name or struct name, this - is a list of identifiers that is the path of the expansion. - """ + UNPARSED_TARGET_FIELD_NUMBER: builtins.int + unparsed_target: builtins.str + """(Optional) The target of the expansion. + + If set, it should end with '.*' and will be parsed by 'parseAttributeName' + in the server side. + """ def __init__( self, *, - target: collections.abc.Iterable[builtins.str] | None = ..., + unparsed_target: builtins.str | None = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target" + ], + ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["target", b"target"] + self, + field_name: typing_extensions.Literal[ + "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target" + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_unparsed_target", b"_unparsed_target"] + ) -> typing_extensions.Literal["unparsed_target"] | None: ... 
class UnresolvedRegex(google.protobuf.message.Message): """Represents all of the input attributes to a given relational operator, for example in diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 199fd6eb9a96c..e1792b03a44d1 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -71,6 +71,64 @@ def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20): self.assertEqual(str1, str2) + def test_count_star(self): + # SPARK-42099: test count(*), count(col(*)) and count(expr(*)) + + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")] + + cdf = self.connect.createDataFrame(data, schema=["age", "name"]) + sdf = self.spark.createDataFrame(data, schema=["age", "name"]) + + self.assertEqual( + cdf.select(CF.count(CF.expr("*")), CF.count(cdf.age)).collect(), + sdf.select(SF.count(SF.expr("*")), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.select(CF.count(CF.col("*")), CF.count(cdf.age)).collect(), + sdf.select(SF.count(SF.col("*")), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.select(CF.count("*"), CF.count(cdf.age)).collect(), + sdf.select(SF.count("*"), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.groupby("name").agg({"*": "count"}).sort("name").collect(), + sdf.groupby("name").agg({"*": "count"}).sort("name").collect(), + ) + + self.assertEqual( + cdf.groupby("name") + .agg(CF.count(CF.expr("*")), CF.count(cdf.age)) + .sort("name") + .collect(), + sdf.groupby("name") + .agg(SF.count(SF.expr("*")), SF.count(sdf.age)) + .sort("name") + .collect(), + ) + + self.assertEqual( + cdf.groupby("name") + .agg(CF.count(CF.col("*")), CF.count(cdf.age)) + .sort("name") + .collect(), + sdf.groupby("name") + .agg(SF.count(SF.col("*")), SF.count(sdf.age)) + 
.sort("name") + .collect(), + ) + + self.assertEqual( + cdf.groupby("name").agg(CF.count("*"), CF.count(cdf.age)).sort("name").collect(), + sdf.groupby("name").agg(SF.count("*"), SF.count(sdf.age)).sort("name").collect(), + ) + def test_broadcast(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF