@@ -225,9 +225,12 @@ message Expression {

   // UnresolvedStar is used to expand all the fields of a relation or struct.
   message UnresolvedStar {
-    // (Optional) The target of the expansion, either be a table name or struct name, this
-    // is a list of identifiers that is the path of the expansion.
-    repeated string target = 1;
+
+    // (Optional) The target of the expansion.
+    //
+    // If set, it should end with '.*' and will be parsed by 'parseAttributeName'
+    // on the server side.
+    optional string unparsed_target = 1;
   }

   // Represents all of the input attributes to a given relational operator, for example in
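In client terms, the new field is simple to populate; a minimal sketch (mirroring the to_plan logic added to expressions.py below, with the proto package imported the way the connect client imports it):

from pyspark.sql.connect import proto

# No target: expand all input attributes (the plain '*' case).
star_all = proto.Expression()
star_all.unresolved_star.SetInParent()

# Targeted expansion: the unparsed target must end with '.*'.
star_struct = proto.Expression()
star_struct.unresolved_star.unparsed_target = "a.b.*"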
@@ -1002,11 +1002,19 @@ class SparkConnectPlanner(session: SparkSession) {
     session.sessionState.sqlParser.parseExpression(expr.getExpression)
   }

-  private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = {
-    if (regex.getTargetList.isEmpty) {
-      UnresolvedStar(Option.empty)
+  private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = {
+    if (star.hasUnparsedTarget) {
+      val target = star.getUnparsedTarget
+      if (!target.endsWith(".*")) {
+        throw InvalidPlanInput(
+          s"UnresolvedStar requires an unparsed target ending with '.*', " +
+            s"but got $target.")
+      }
+
+      UnresolvedStar(
+        Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2))))
     } else {
-      UnresolvedStar(Some(regex.getTargetList.asScala.toSeq))
+      UnresolvedStar(None)
     }
   }
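The contract enforced here: a set target must end in '.*', and the prefix before that suffix is handed to UnresolvedAttribute.parseAttributeName, which splits a (possibly backtick-quoted) multipart name into its name parts. A simplified Python restatement of that behavior, for illustration only (the real Catalyst parser also handles backtick quoting, which this sketch skips):

def parse_star_target(target: str) -> list:
    # Reject targets that do not use the '.*' suffix, as the planner does.
    if not target.endswith(".*"):
        raise ValueError(
            f"UnresolvedStar requires an unparsed target ending with '.*', but got {target}."
        )
    # Drop the trailing '.*' and split the remaining path on dots.
    return target[:-2].split(".")

assert parse_star_target("a.b.*") == ["a", "b"]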

@@ -552,7 +552,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
       .addExpressions(
         proto.Expression
           .newBuilder()
-          .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build())
+          .setUnresolvedStar(UnresolvedStar.newBuilder().setUnparsedTarget("a.b.*").build())
           .build())
       .build()

4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -249,7 +249,7 @@ def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
             else:
                 return Column(SortOrder(col._expr))
         else:
-            return Column(SortOrder(ColumnReference(name=col)))
+            return Column(SortOrder(ColumnReference(col)))

         if isinstance(numPartitions, int):
             if not numPartitions > 0:
@@ -1176,7 +1176,7 @@ def sampleBy(
         from pyspark.sql.connect.expressions import ColumnReference

         if isinstance(col, str):
-            col = Column(ColumnReference(name=col))
+            col = Column(ColumnReference(col))
         elif not isinstance(col, Column):
             raise TypeError("col must be a string or a column, but got %r" % type(col))
         if not isinstance(fractions, dict):
46 changes: 43 additions & 3 deletions python/pyspark/sql/connect/expressions.py
@@ -336,10 +336,10 @@ class ColumnReference(Expression):
     treat it as an unresolved attribute. Attributes that have the same fully
     qualified name are identical"""

-    def __init__(self, name: str) -> None:
+    def __init__(self, unparsed_identifier: str) -> None:
         super().__init__()
-        assert isinstance(name, str)
-        self._unparsed_identifier = name
+        assert isinstance(unparsed_identifier, str)
+        self._unparsed_identifier = unparsed_identifier

     def name(self) -> str:
         """Returns the qualified name of the column reference."""
@@ -354,6 +354,43 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
     def __repr__(self) -> str:
         return f"{self._unparsed_identifier}"

+    def __eq__(self, other: Any) -> bool:
+        return (
+            other is not None
+            and isinstance(other, ColumnReference)
+            and other._unparsed_identifier == self._unparsed_identifier
+        )
+
+
+class UnresolvedStar(Expression):
+    def __init__(self, unparsed_target: Optional[str]):
+        super().__init__()
+
+        if unparsed_target is not None:
+            assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")
+
+        self._unparsed_target = unparsed_target
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.unresolved_star.SetInParent()
+        if self._unparsed_target is not None:
+            expr.unresolved_star.unparsed_target = self._unparsed_target
+        return expr
+
+    def __repr__(self) -> str:
+        if self._unparsed_target is not None:
+            return f"unresolvedstar({self._unparsed_target})"
+        else:
+            return "unresolvedstar()"
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            other is not None
+            and isinstance(other, UnresolvedStar)
+            and other._unparsed_target == self._unparsed_target
+        )
+

 class SQLExpression(Expression):
     """Returns Expression which contains a string which is a SQL expression
@@ -370,6 +407,9 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         expr.expression_string.expression = self._expr
         return expr

+    def __eq__(self, other: Any) -> bool:
+        return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr
+

 class SortOrder(Expression):
     def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None:
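With the new class in place, star expressions can be constructed and compared on the client; a small illustrative check (to_plan ignores its session argument in the code above, so None stands in for a real SparkConnectClient in this sketch):

from pyspark.sql.connect.expressions import UnresolvedStar

star = UnresolvedStar("a.b.*")
assert star == UnresolvedStar("a.b.*")  # the new __eq__ compares unparsed targets
assert repr(star) == "unresolvedstar(a.b.*)"
assert star.to_plan(None).unresolved_star.unparsed_target == "a.b.*"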
11 changes: 7 additions & 4 deletions python/pyspark/sql/connect/functions.py
@@ -40,6 +40,7 @@
     LiteralExpression,
     ColumnReference,
     UnresolvedFunction,
+    UnresolvedStar,
     SQLExpression,
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
@@ -186,7 +187,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:


 def col(col: str) -> Column:
-    return Column(ColumnReference(col))
+    if col == "*":
+        return Column(UnresolvedStar(unparsed_target=None))
+    elif col.endswith(".*"):
+        return Column(UnresolvedStar(unparsed_target=col))
+    else:
+        return Column(ColumnReference(unparsed_identifier=col))


 col.__doc__ = pysparkfuncs.col.__doc__
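The user-visible effect is that col() now dispatches on star syntax instead of always building a ColumnReference; summarized (column names here are illustrative):

from pyspark.sql.connect.functions import col

col("*")      # Column(UnresolvedStar(unparsed_target=None)): expand all input attributes
col("s.*")    # Column(UnresolvedStar("s.*")): expand the fields of struct or relation `s`
col("name")   # Column(ColumnReference("name")): an ordinary unresolved attribute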
@@ -2389,9 +2395,6 @@ def _test() -> None:
     # TODO(SPARK-41843): Implement SparkSession.udf
     del pyspark.sql.connect.functions.call_udf.__doc__

-    # TODO(SPARK-41845): Fix count bug
-    del pyspark.sql.connect.functions.count.__doc__
-
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.functions tests")
         .remote("local[4]")
8 changes: 1 addition & 7 deletions python/pyspark/sql/connect/group.py
@@ -80,14 +80,8 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":

         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            # There is a special case for count(*) which is rewritten into count(1).
             # Convert the dict into key value pairs
-            aggregate_cols = [
-                _invoke_function(
-                    exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
-                )
-                for k in exprs[0]
-            ]
+            aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
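The rewrite can go away because col("*") now yields an UnresolvedStar, so count("*") is sent to the server as count(*) and resolved there rather than being rewritten into count(1) on the client. For instance, with the dict form of agg (df and its columns are hypothetical):

# Both entries take the same path now; col(k) handles "*" natively.
df.groupBy("dept").agg({"*": "count", "salary": "max"})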