@@ -48,6 +48,7 @@ message Expression {
    CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
    CallFunction call_function = 16;
    NamedArgumentExpression named_argument_expression = 17;
    GetColumnByOrdinal get_column_by_ordinal = 18;

    // This field is used to mark extensions to the protocol. When plugins generate arbitrary
    // relations they can add them here. During the planning the correct resolution is done.
@@ -228,6 +229,15 @@ message Expression {
    optional bool is_metadata_column = 3;
  }

  // An unresolved attribute that is represented by its column index.
  message GetColumnByOrdinal {
    // (Required) 0-based column index.
    int32 ordinal = 1;

    // (Optional) The id of corresponding connect plan.
    optional int64 plan_id = 2;
  }
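For illustration, a minimal sketch of how a client could populate this message through the generated Python bindings (assuming this PR's regenerated `pyspark.sql.connect.proto` classes are available; the values are made up):

```python
from pyspark.sql.connect import proto

# Refer to the third column (0-based ordinal 2) of the plan with id 7.
# plan_id is optional; leaving it unset means the reference is resolved
# against the immediate input plan.
expr = proto.Expression()
expr.get_column_by_ordinal.ordinal = 2
expr.get_column_by_ordinal.plan_id = 7

assert expr.WhichOneof("expr_type") == "get_column_by_ordinal"
```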

  // An unresolved function is not explicitly bound to one explicit function, but the function
  // is resolved during analysis following Sparks name resolution rules.
  message UnresolvedFunction {
@@ -45,7 +45,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -1344,6 +1344,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
      case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
      case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
        transformUnresolvedAttribute(exp.getUnresolvedAttribute)
      case proto.Expression.ExprTypeCase.GET_COLUMN_BY_ORDINAL =>
        transformGetColumnByOrdinal(exp.getGetColumnByOrdinal)
      case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
        transformUnregisteredFunction(exp.getUnresolvedFunction)
          .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
@@ -1397,6 +1399,16 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
    expr
  }

  private def transformGetColumnByOrdinal(
      attr: proto.Expression.GetColumnByOrdinal): GetColumnByOrdinal = {
    // Always set the dataType field to null, since it is not used in the Analyzer.
    val expr = GetColumnByOrdinal(attr.getOrdinal, null)
    if (attr.hasPlanId) {
      expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
    }
    expr
  }

[Review comment – Contributor Author, on the GetColumnByOrdinal(attr.getOrdinal, null) line] also cc @cloud-fan for the usage of GetColumnByOrdinal

  private def transformExpressionPlugin(extension: ProtoAny): Expression = {
    SparkConnectPluginRegistry.expressionRegistry
      // Lazily traverse the collection.
21 changes: 16 additions & 5 deletions python/pyspark/sql/connect/dataframe.py
@@ -68,7 +68,10 @@
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import UnresolvedRegex
+from pyspark.sql.connect.expressions import (
+    UnresolvedRegex,
+    GetColumnByOrdinal,
+)
from pyspark.sql.connect.functions import (
    _to_col_with_plan_id,
    _to_col,
Expand Down Expand Up @@ -1654,10 +1657,10 @@ def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":
        ...

    def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
-        if isinstance(item, str):
-            if self._plan is None:
-                raise SparkConnectException("Cannot analyze on empty plan.")
+        if self._plan is None:
+            raise SparkConnectException("Cannot analyze on empty plan.")
+
+        if isinstance(item, str):
            # validate the column name
            if not hasattr(self._session, "is_mock_session"):
                self.select(item).isLocal()
@@ -1671,7 +1674,15 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
        elif isinstance(item, (list, tuple)):
            return self.select(*item)
        elif isinstance(item, int):
-            return col(self.columns[item])
+            n = len(self.columns)
+            # 1, convert bool; 2, convert negative index; 3, validate index
+            item = range(0, n)[int(item)]
+            return Column(
+                GetColumnByOrdinal(
+                    ordinal=item,
+                    plan_id=self._plan._plan_id,
+                )
+            )
        else:
            raise PySparkTypeError(
                error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
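The `range(0, n)[int(item)]` idiom above does three jobs in one expression. A plain-Python sketch (no Spark required) of the behavior it relies on:

```python
n = 3  # pretend the DataFrame has three columns

def normalize(item, n):
    # int(False) == 0 and int(True) == 1, so bools collapse to plain ints;
    # range indexing then maps negative indices (-1 -> n - 1) and raises
    # IndexError for anything outside [-n, n), just like a list would.
    return range(0, n)[int(item)]

assert normalize(True, 3) == 1
assert normalize(-1, 3) == 2
try:
    normalize(3, 3)
except IndexError:
    pass  # out-of-range ordinals are rejected eagerly on the client side
```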
34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -488,6 +488,40 @@ def __eq__(self, other: Any) -> bool:
        )


class GetColumnByOrdinal(Expression):
    """Represents a column index (0-based). There is no guarantee that this column
    actually exists: in the context of this project, we refer to the column by its
    index and treat it as an unresolved GetColumnByOrdinal."""

    def __init__(self, ordinal: int, plan_id: Optional[int] = None):
        super().__init__()

        assert isinstance(ordinal, int) and ordinal >= 0
        self._ordinal = ordinal

        assert plan_id is None or isinstance(plan_id, int)
        self._plan_id = plan_id

    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
        """Returns the Proto representation of the expression."""
        expr = proto.Expression()
        expr.get_column_by_ordinal.ordinal = self._ordinal
        if self._plan_id is not None:
            expr.get_column_by_ordinal.plan_id = self._plan_id
        return expr

    def __repr__(self) -> str:
        return f"getcolumnbyordinal({self._ordinal})"

    def __eq__(self, other: Any) -> bool:
        return (
            other is not None
            and isinstance(other, GetColumnByOrdinal)
            and other._ordinal == self._ordinal
            and other._plan_id == self._plan_id
        )
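A small usage sketch of the class above (hypothetical values; note that passing `session=None` works here only because this particular `to_plan` never touches the session):

```python
from pyspark.sql.connect.expressions import GetColumnByOrdinal

e1 = GetColumnByOrdinal(ordinal=0, plan_id=7)
e2 = GetColumnByOrdinal(ordinal=0, plan_id=7)

assert repr(e1) == "getcolumnbyordinal(0)"
assert e1 == e2  # equality compares both ordinal and plan_id

# to_plan only builds the protobuf message; it does not consult the session.
msg = e1.to_plan(session=None)
assert msg.get_column_by_ordinal.ordinal == 0
assert msg.get_column_by_ordinal.plan_id == 7
```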


class UnresolvedStar(Expression):
    def __init__(self, unparsed_target: Optional[str]):
        super().__init__()
126 changes: 64 additions & 62 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -804,6 +804,37 @@ class Expression(google.protobuf.message.Message):
        self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
    ) -> typing_extensions.Literal["plan_id"] | None: ...

    class GetColumnByOrdinal(google.protobuf.message.Message):
        """An unresolved attribute that is represented by its column index."""

        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        ORDINAL_FIELD_NUMBER: builtins.int
        PLAN_ID_FIELD_NUMBER: builtins.int
        ordinal: builtins.int
        """(Required) 0-based column index."""
        plan_id: builtins.int
        """(Optional) The id of corresponding connect plan."""
        def __init__(
            self,
            *,
            ordinal: builtins.int = ...,
            plan_id: builtins.int | None = ...,
        ) -> None: ...
        def HasField(
            self,
            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
        ) -> builtins.bool: ...
        def ClearField(
            self,
            field_name: typing_extensions.Literal[
                "_plan_id", b"_plan_id", "ordinal", b"ordinal", "plan_id", b"plan_id"
            ],
        ) -> None: ...
        def WhichOneof(
            self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
        ) -> typing_extensions.Literal["plan_id"] | None: ...
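Because `plan_id` is a proto3 `optional` field, its presence is tracked explicitly via the synthetic `_plan_id` oneof shown in the stub. A sketch of the presence-checking this allows (again assuming this PR's regenerated bindings):

```python
from pyspark.sql.connect import proto

msg = proto.Expression.GetColumnByOrdinal(ordinal=1)
assert not msg.HasField("plan_id")         # optional field starts unset
assert msg.WhichOneof("_plan_id") is None

msg.plan_id = 42
assert msg.HasField("plan_id")
assert msg.WhichOneof("_plan_id") == "plan_id"
```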

    class UnresolvedFunction(google.protobuf.message.Message):
        """An unresolved function is not explicitly bound to one explicit function, but the function
        is resolved during analysis following Sparks name resolution rules.
@@ -1126,6 +1157,7 @@ class Expression(google.protobuf.message.Message):
    COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
    CALL_FUNCTION_FIELD_NUMBER: builtins.int
    NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
    GET_COLUMN_BY_ORDINAL_FIELD_NUMBER: builtins.int
    EXTENSION_FIELD_NUMBER: builtins.int
    @property
    def literal(self) -> global___Expression.Literal: ...
@@ -1164,6 +1196,8 @@ class Expression(google.protobuf.message.Message):
    @property
    def named_argument_expression(self) -> global___NamedArgumentExpression: ...
    @property
    def get_column_by_ordinal(self) -> global___Expression.GetColumnByOrdinal: ...
    @property
    def extension(self) -> google.protobuf.any_pb2.Any:
        """This field is used to mark extensions to the protocol. When plugins generate arbitrary
        relations they can add them here. During the planning the correct resolution is done.
@@ -1189,6 +1223,7 @@ class Expression(google.protobuf.message.Message):
        common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ...,
        call_function: global___CallFunction | None = ...,
        named_argument_expression: global___NamedArgumentExpression | None = ...,
        get_column_by_ordinal: global___Expression.GetColumnByOrdinal | None = ...,
        extension: google.protobuf.any_pb2.Any | None = ...,
    ) -> None: ...
    def HasField(
@@ -1208,6 +1243,8 @@
b"expression_string",
"extension",
b"extension",
"get_column_by_ordinal",
b"get_column_by_ordinal",
"lambda_function",
b"lambda_function",
"literal",
@@ -1251,6 +1288,8 @@
b"expression_string",
"extension",
b"extension",
"get_column_by_ordinal",
b"get_column_by_ordinal",
"lambda_function",
b"lambda_function",
"literal",
@@ -1297,6 +1336,7 @@
"common_inline_user_defined_function",
"call_function",
"named_argument_expression",
"get_column_by_ordinal",
"extension",
] | None: ...

5 changes: 4 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -3455,7 +3455,10 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
        elif isinstance(item, (list, tuple)):
            return self.select(*item)
        elif isinstance(item, int):
-            jc = self._jdf.apply(self.columns[item])
+            n = len(self.columns)
+            # 1, convert bool; 2, convert negative index; 3, validate index
+            item = range(0, n)[int(item)]
+            jc = self._jdf.apply(item)
            return Column(jc)
        else:
            raise PySparkTypeError(
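One practical effect of index-based access (mirrored by `test_getitem_duplicated_column` below) is that it stays unambiguous even when column names collide; a usage sketch, assuming an active `SparkSession` named `spark`:

```python
# All three columns are named "a", so df["a"] would be ambiguous,
# but positional access resolves against the analyzed output directly.
df = spark.sql("SELECT * FROM VALUES (1, 1.1, 'a') AS TAB(a, a, a)")

assert df.select(df[0]).schema.simpleString() == "struct<a:int>"
assert df.select(df[1]).schema.simpleString() == "struct<a:decimal(2,1)>"
assert df.select(df[-1]).schema.simpleString() == "struct<a:string>"
```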
47 changes: 45 additions & 2 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,6 @@
import io
from contextlib import redirect_stdout

-from pyspark import StorageLevel

[Review comment – Member] Just a question. Is this a relevant change?

[Review comment – Contributor Author] It is not related; it is an unused import. Since we are touching this file, what about also removing it, by the way?

from pyspark.sql import SparkSession, Row, functions
from pyspark.sql.functions import col, lit, count, sum, mean, struct
from pyspark.sql.pandas.utils import pyarrow_version_less_than_minimum
@@ -63,6 +62,51 @@


class DataFrameTestsMixin:
    def test_getitem_invalid_indices(self):
        df = self.spark.sql(
            "SELECT * FROM VALUES "
            "(1, 1.1, 'a'), "
            "(2, 2.2, 'b'), "
            "(4, 4.4, 'c') "
            "AS TAB(a, b, c)"
        )

        # accepted type and values
        for index in [False, True, 0, 1, 2, -1, -2, -3]:
            df[index]
[Review comment – Contributor (@cloud-fan), Sep 27, 2023] This is really a bad API. df.col can already be ambiguous, since people may use the column reference far away from the dataframe, e.g. df1.join(df2).select...filter...select(df1.col). We recommend users use a qualified unresolved column instead, like col("t1.col"). Now df[index] is even worse, as it only makes sense when used immediately in the current df's transformation.

Why do we add such an API? To support order-by-ordinal, we can just order by integer literals. The SQL parser also parses ORDER BY 1, 2 as ordering by the integer literals 1 and 2, and the analyzer will properly resolve them.

cc @HyukjinKwon @zhengruifeng

[Review comment – Contributor (@cloud-fan)] If df[index] has already been in PySpark for a while, I think it's fine to treat it as a shortcut for df.i_th_col. We shouldn't use GetColumnByOrdinal in this case, as it was added for Dataset tuple encoding, where it is guaranteed that we want to get the column from the direct child of the current plan node. But here we can't guarantee this, as people can do df1.select..filter...groupBy...select(df1[index]).

[Review comment – Contributor Author] 6183b5e

df[index] has been supported since Spark 2.0.0.

To support df.groupBy(1, 2, 3) and df.orderBy(1, 2, 3), right now GetColumnByOrdinal is only used on the direct child internally.

> The SQL parser also parses ORDER BY 1, 2 as ordering by integer literal 1 and 2, and analyzer will properly resolve it.

Do you mean we should directly use SortOrder(UnresolvedOrdinal(index))?

[Review comment – Contributor Author] Had an offline discussion with Wenchen; will fix it by switching to SortOrder(Literal(index)). Will fix it next week.

[Review comment – Contributor Author]

scala> val df = Seq((2, 1), (1, 2)).toDF("a", "b")
val df: org.apache.spark.sql.DataFrame = [a: int, b: int]

scala> df.show()
+---+---+
|  a|  b|
+---+---+
|  2|  1|
|  1|  2|
+---+---+

scala> df.orderBy(lit(1)).show()
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  2|  1|
+---+---+

scala> df.groupBy(lit(1)).agg(first(col("a")), max(col("b"))).show()
+---+--------+------+
|  1|first(a)|max(b)|
+---+--------+------+
|  1|       2|     2|
+---+--------+------+

It seems orderBy(lit(1)) works directly, while groupBy(lit(1)) needs some investigation.

Let me revert this PR first.


        # negative cases: ordinal out of range
        for index in [-10, -4, 3, 10, 100]:
            with self.assertRaises(IndexError):
                df[index]

        # negative cases: unsupported types
        for index in [None, 1.0, Decimal(1)]:
            with self.assertRaises(PySparkTypeError):
                df[index]

    def test_getitem_duplicated_column(self):
        df = self.spark.sql(
            "SELECT * FROM VALUES "
            "(1, 1.1, 'a'), "
            "(2, 2.2, 'b'), "
            "(4, 4.4, 'c') "
            "AS TAB(a, a, a)"
        )

        self.assertEqual(
            df.select(df[0]).schema.simpleString(),
            "struct<a:int>",
        )
        self.assertEqual(
            df.select(df[1]).schema.simpleString(),
            "struct<a:decimal(2,1)>",
        )
        self.assertEqual(
            df.select(df[2]).schema.simpleString(),
            "struct<a:string>",
        )

    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
@@ -77,7 +121,6 @@ def test_duplicated_column_names(self):
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))
        # Cannot access columns
-        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())

27 changes: 27 additions & 0 deletions python/pyspark/sql/tests/test_group.py
@@ -149,6 +149,33 @@ def test_order_by_ordinal(self):
        with self.assertRaises(IndexError):
            df.orderBy(-3)

    def test_order_by_ordinal_duplicated_column(self):
        spark = self.spark
        df = spark.createDataFrame(
            [
                (1, 1),
                (1, 2),
                (2, 1),
                (2, 2),
                (3, 1),
                (3, 2),
            ],
            ["a", "a"],
        )

        with self.tempView("v"):
            df.createOrReplaceTempView("v")

            df1 = spark.sql("select * from v order by 2, 1;")
            df2 = df.orderBy(2, 1)
            assertSchemaEqual(df1.schema, df2.schema)
            assertDataFrameEqual(df1, df2)

            df1 = spark.sql("select * from v order by 1 desc, 2;")
            df2 = df.orderBy(-1, 2)
            assertSchemaEqual(df1.schema, df2.schema)
            assertDataFrameEqual(df1, df2)


class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
    pass
12 changes: 12 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1393,6 +1393,11 @@ class Dataset[T] private[sql](
   */
  def apply(colName: String): Column = col(colName)

  /**
   * Selects column based on the column index (0-based) and returns it as a [[Column]].
   */
  private[sql] def apply(index: Int): Column = col(index)

  /**
   * Specifies some hint on the current Dataset. As an example, the following code specifies
   * that one of the plan can be broadcasted:
@@ -1445,6 +1450,13 @@ class Dataset[T] private[sql](
    }
  }

  /**
   * Selects column based on the column index (0-based) and returns it as a [[Column]].
   */
  private[sql] def col(index: Int): Column = {
    Column(addDataFrameIdToCol(queryExecution.analyzed.output(index)))
  }

  /**
   * Selects a metadata column based on its logical column name, and returns it as a [[Column]].
   *