@@ -197,6 +197,9 @@ message Expression {
// (Required) An identifier that will be parsed by Catalyst parser. This should follow the
// Spark SQL identifier syntax.
string unparsed_identifier = 1;
+
+// (Optional) The id of the corresponding connect plan.
+optional int64 plan_id = 2;
}

// An unresolved function is not explicitly bound to one explicit function, but the function
@@ -93,6 +93,9 @@ message Unknown {}
message RelationCommon {
// (Required) Shared relation metadata.
string source_info = 1;
+
+// (Optional) A per-client globally unique id for a given connect plan.
+optional int64 plan_id = 2;
}

// Relation that uses a SQL query to generate the output.
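Taken together, the two proto additions let a client stamp each relation with a per-client unique plan id and let an unresolved column reference name the plan it was taken from. A minimal sketch of what ends up on the wire, using the Python-generated proto classes (the import path and the id value 7 are illustrative assumptions; field names follow the definitions above):

import pyspark.sql.connect.proto as proto

rel = proto.Relation()
rel.common.plan_id = 7  # per-client unique id stamped on this relation

attr = proto.Expression()
attr.unresolved_attribute.unparsed_identifier = "value"
attr.unresolved_attribute.plan_id = 7  # points back at the relation tagged above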
@@ -66,7 +66,7 @@ class SparkConnectPlanner(val session: SparkSession) {

// The root of the query plan is a relation and we apply the transformations to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
-rel.getRelTypeCase match {
+val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
@@ -124,6 +124,11 @@ class SparkConnectPlanner(val session: SparkSession) {
transformRelationPlugin(rel.getExtension)
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
}
+
+if (rel.hasCommon && rel.getCommon.hasPlanId) {
+plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
+}
+plan
}

private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
@@ -702,10 +707,6 @@ class SparkConnectPlanner(val session: SparkSession) {
logical.Project(projectList = projection, child = baseRel)
}

-private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
-UnresolvedAttribute.quotedString(exp.getUnresolvedAttribute.getUnparsedIdentifier)
-}
-
/**
* Transforms an input protobuf expression into the Catalyst expression. This is usually not
* called directly. Typically the planner will traverse the expressions automatically, only
@@ -720,7 +721,7 @@ class SparkConnectPlanner(val session: SparkSession) {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
-transformUnresolvedExpression(exp)
+transformUnresolvedAttribute(exp.getUnresolvedAttribute)
case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
transformUnregisteredFunction(exp.getUnresolvedFunction)
.getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
@@ -758,6 +759,15 @@ class SparkConnectPlanner(val session: SparkSession) {
case expr => UnresolvedAlias(expr)
}

+private def transformUnresolvedAttribute(
+attr: proto.Expression.UnresolvedAttribute): UnresolvedAttribute = {
+val expr = UnresolvedAttribute.quotedString(attr.getUnparsedIdentifier)
+if (attr.hasPlanId) {
+expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
+}
+expr
+}
+
private def transformExpressionPlugin(extension: ProtoAny): Expression = {
SparkConnectPluginRegistry.expressionRegistry
// Lazily traverse the collection.
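On the server side, transformRelation now tags the transformed LogicalPlan with LogicalPlan.PLAN_ID_TAG, and transformUnresolvedAttribute does the same for an UnresolvedAttribute that carries a plan_id, so a column reference can later be resolved against the specific plan it was taken from rather than by name alone. A toy Python illustration of the shape of that lookup (not Spark code; the names and ids are made up):

# plan_id -> logical plan, as tagged by transformRelation above
tagged_plans = {1: "<plan for df1>", 2: "<plan for df2>"}

# an unresolved attribute arriving with a plan_id
attr = {"unparsed_identifier": "value", "plan_id": 2}

# resolve "value" against the plan that produced it,
# even when both plans expose a column named "value"
target_plan = tagged_plans[attr["plan_id"]]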
5 changes: 2 additions & 3 deletions python/pyspark/sql/column.py
@@ -279,7 +279,6 @@ def __ne__( # type: ignore[override]
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")

-# TODO(SPARK-41812): DataFrame.join: ambiguous column
_eqNullSafe_doc = """
Equality test that is safe for null values.

@@ -315,9 +314,9 @@ def __ne__( # type: ignore[override]
... Row(value = 'bar'),
... Row(value = None)
... ])
->>> df1.join(df2, df1["value"] == df2["value"]).count() # doctest: +SKIP
+>>> df1.join(df2, df1["value"] == df2["value"]).count()
0
->>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count() # doctest: +SKIP
+>>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
1
>>> df2 = spark.createDataFrame([
... Row(id=1, value=float('NaN')),
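With plan ids attached to column references, the two doctests above no longer need # doctest: +SKIP under Spark Connect: df1["value"] and df2["value"] resolve to different plans even though the column names collide. A condensed usage sketch (assumes an active spark session; df1's contents are assumed to mirror the surrounding doctest, one non-null and one null value):

from pyspark.sql import Row

df1 = spark.createDataFrame([Row(value="foo"), Row(value=None)])
df2 = spark.createDataFrame([Row(value="bar"), Row(value=None)])

df1.join(df2, df1["value"] == df2["value"]).count()           # 0: NULL == NULL is unknown
df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()  # 1: NULL <=> NULL is true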
15 changes: 9 additions & 6 deletions python/pyspark/sql/connect/dataframe.py
@@ -50,12 +50,14 @@
)

from pyspark.errors import PySparkTypeError
+from pyspark.errors.exceptions.connect import SparkConnectException
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
from pyspark.sql.connect.functions import (
+_to_col_with_plan_id,
_to_col,
_invoke_function,
col,
@@ -1284,10 +1286,12 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Colum
if isinstance(item, str):
# Check for alias
alias = self._get_alias()
-if alias is not None:
-return col(alias)
-else:
-return col(item)
+if self._plan is None:
+raise SparkConnectException("Cannot analyze on empty plan.")
+return _to_col_with_plan_id(
+col=alias if alias is not None else item,
+plan_id=self._plan._plan_id,
+)
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
@@ -1694,9 +1698,8 @@ def _test() -> None:
del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
del pyspark.sql.connect.dataframe.DataFrame.repartitionByRange.__doc__

-# TODO(SPARK-41823): ambiguous column names
+# TODO(SPARK-42367): DataFrame.drop should handle duplicated columns
del pyspark.sql.connect.dataframe.DataFrame.drop.__doc__
-del pyspark.sql.connect.dataframe.DataFrame.join.__doc__

# TODO(SPARK-41625): Support Structured Streaming
del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
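The __getitem__ change is what makes df["name"] remember its origin: selecting a column by name now goes through _to_col_with_plan_id with the DataFrame's own plan id, while a bare col("name") from functions still carries none. A rough sketch of the resulting behaviour (internal attribute names as in the diff above, not public API):

df = spark.range(10)
c = df["id"]
# roughly equivalent to:
#   _to_col_with_plan_id(col="id", plan_id=df._plan._plan_id)
# whereas functions.col("id") builds a ColumnReference with plan_id=None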
7 changes: 6 additions & 1 deletion python/pyspark/sql/connect/expressions.py
@@ -339,11 +339,14 @@ class ColumnReference(Expression):
treat it as an unresolved attribute. Attributes that have the same fully
qualified name are identical"""

-def __init__(self, unparsed_identifier: str) -> None:
+def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(unparsed_identifier, str)
self._unparsed_identifier = unparsed_identifier
+
+assert plan_id is None or isinstance(plan_id, int)
+self._plan_id = plan_id

def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier
@@ -352,6 +355,8 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
+if self._plan_id is not None:
+expr.unresolved_attribute.plan_id = self._plan_id
return expr

def __repr__(self) -> str:
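The expressions.py change threads the optional plan id through to the proto. A minimal sketch of the extended constructor (class and argument names as defined above; the id value is illustrative):

from pyspark.sql.connect.expressions import ColumnReference

ref = ColumnReference(unparsed_identifier="value", plan_id=7)
# ref.to_plan(session) now also fills unresolved_attribute.plan_id = 7;
# with plan_id=None the field is left unset, preserving the previous behaviour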
19 changes: 10 additions & 9 deletions python/pyspark/sql/connect/functions.py
@@ -63,6 +63,15 @@
from pyspark.sql.connect.dataframe import DataFrame


+def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
+if col == "*":
+return Column(UnresolvedStar(unparsed_target=None))
+elif col.endswith(".*"):
+return Column(UnresolvedStar(unparsed_target=col))
+else:
+return Column(ColumnReference(unparsed_identifier=col, plan_id=plan_id))
+
+
def _to_col(col: "ColumnOrName") -> Column:
assert isinstance(col, (Column, str))
return col if isinstance(col, Column) else column(col)
@@ -202,12 +211,7 @@ def _options_to_col(options: Dict[str, Any]) -> Column:


def col(col: str) -> Column:
-if col == "*":
-return Column(UnresolvedStar(unparsed_target=None))
-elif col.endswith(".*"):
-return Column(UnresolvedStar(unparsed_target=col))
-else:
-return Column(ColumnReference(unparsed_identifier=col))
+return _to_col_with_plan_id(col=col, plan_id=None)


col.__doc__ = pysparkfuncs.col.__doc__
@@ -2470,9 +2474,6 @@ def _test() -> None:
del pyspark.sql.connect.functions.timestamp_seconds.__doc__
del pyspark.sql.connect.functions.unix_timestamp.__doc__

-# TODO(SPARK-41812): Proper column names after join
-del pyspark.sql.connect.functions.count_distinct.__doc__
-
# TODO(SPARK-41843): Implement SparkSession.udf
del pyspark.sql.connect.functions.call_udf.__doc__

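For reference, the new helper keeps exactly the dispatch col() used before, only attaching the plan id to plain column references; the public col() itself always passes plan_id=None (a summary of the code above, not new behaviour):

from pyspark.sql.connect.functions import col

col("*")     # Column wrapping UnresolvedStar(unparsed_target=None)
col("a.*")   # Column wrapping UnresolvedStar(unparsed_target="a.*")
col("name")  # Column wrapping ColumnReference("name", plan_id=None)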