38 changes: 11 additions & 27 deletions python/pyspark/sql/connect/dataframe.py
@@ -72,19 +72,12 @@
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
SortOrder,
ColumnReference,
UnresolvedRegex,
UnresolvedStar,
)
from pyspark.sql.connect.functions.builtin import (
_to_col,
_invoke_function,
col,
lit,
udf,
struct,
expr as sql_expression,
)
from pyspark.sql.connect.functions import builtin as F
from pyspark.sql.pandas.types import from_arrow_schema
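Note on the import hunk above: the per-name imports from `pyspark.sql.connect.functions.builtin` are collapsed into a single namespace import, and the call sites below go through `F`. A minimal sketch of the before/after pattern, not taken from this diff (the `df` variable and the `age`/`name` columns are assumed to exist):

```python
# Before: each helper imported by name
from pyspark.sql.connect.functions.builtin import col, lit, expr as sql_expression

df.filter(sql_expression("age > 21")).select(col("name"), lit(1))

# After: one namespace import; helpers are referenced as F.<name>
from pyspark.sql.connect.functions import builtin as F

df.filter(F.expr("age > 21")).select(F.col("name"), F.lit(1))
```

This matches the common `functions as F` idiom in PySpark code and keeps the import list from growing every time dataframe.py needs another helper.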


@@ -199,9 +192,9 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
expr = expr[0] # type: ignore[assignment]
for element in expr:
if isinstance(element, str):
sql_expr.append(sql_expression(element))
sql_expr.append(F.expr(element))
else:
sql_expr.extend([sql_expression(e) for e in element])
sql_expr.extend([F.expr(e) for e in element])

return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
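Both accepted call shapes of `selectExpr` exercise the branches above; each string ends up wrapped in `F.expr(...)`. A hedged usage sketch (the `df` variable and column names are illustrative):

```python
# Varargs of SQL strings: each string is wrapped in F.expr(...)
df.selectExpr("age * 2 AS double_age", "upper(name) AS upper_name")

# A single list of strings is also accepted (the expr = expr[0] unpacking above)
df.selectExpr(["age * 2 AS double_age", "upper(name) AS upper_name"])
```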

@@ -215,7 +208,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
)

if len(exprs) == 1 and isinstance(exprs[0], dict):
measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
measures = [F._invoke_function(f, F.col(e)) for e, f in exprs[0].items()]
return self.groupBy().agg(*measures)
else:
# other expressions
@@ -259,7 +252,7 @@ def sparkSession(self) -> "SparkSession":
sparkSession.__doc__ = PySparkDataFrame.sparkSession.__doc__

def count(self) -> int:
table, _ = self.agg(_invoke_function("count", lit(1)))._to_table()
table, _ = self.agg(F._invoke_function("count", F.lit(1)))._to_table()
return table[0][0].as_py()

count.__doc__ = PySparkDataFrame.count.__doc__
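Two things worth noting in the hunk above: the dict form of `agg` turns each `{column: function}` pair into `F._invoke_function(func, F.col(column))`, and `count()` is implemented as `agg(count(lit(1)))` with the single result cell read out of the returned Arrow table. A rough user-level equivalent, assuming a DataFrame `df` with `age` and `salary` columns:

```python
import pyspark.sql.functions as sf

# Dict form: {column name: aggregate function name}
df.agg({"age": "max", "salary": "avg"})

# ...behaves like the explicit Column form
df.agg(sf.max("age"), sf.avg("salary"))
```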
@@ -352,8 +345,6 @@ def repartitionByRange( # type: ignore[misc]
self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
) -> "DataFrame":
def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
from pyspark.sql.connect.expressions import SortOrder, ColumnReference

if isinstance(col, Column):
if isinstance(col._expr, SortOrder):
return col
@@ -471,7 +462,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame":

def filter(self, condition: Union[Column, str]) -> "DataFrame":
if isinstance(condition, str):
expr = sql_expression(condition)
expr = F.expr(condition)
else:
expr = condition
return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
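`filter` keeps accepting either a SQL string, which is parsed with `F.expr`, or a `Column` used as-is. A brief sketch with an assumed `df` and columns:

```python
from pyspark.sql.functions import col

df.filter("age > 21 AND name IS NOT NULL")  # string -> F.expr(condition)
df.filter(col("age") > 21)                  # Column -> passed through unchanged
```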
@@ -713,7 +704,7 @@ def _sort_cols(
)
else:
_c = c # type: ignore[assignment]
_cols.append(_to_col(cast("ColumnOrName", _c)))
_cols.append(F._to_col(cast("ColumnOrName", _c)))

ascending = kwargs.get("ascending", True)
if isinstance(ascending, (bool, int)):
@@ -1652,8 +1643,6 @@ def freqItems(
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
) -> "DataFrame":
from pyspark.sql.connect.expressions import ColumnReference

if isinstance(col, str):
col = Column(ColumnReference(col))
elif not isinstance(col, Column):
@@ -1754,7 +1743,7 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Colum
elif isinstance(item, (list, tuple)):
return self.select(*item)
elif isinstance(item, int):
return col(self.columns[item])
return F.col(self.columns[item])
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
@@ -1768,11 +1757,6 @@ def __dir__(self) -> List[str]:

__dir__.__doc__ = PySparkDataFrame.__dir__.__doc__

def _print_plan(self) -> str:
if self._plan:
return self._plan.print()
return ""

def collect(self) -> List[Row]:
table, schema = self._to_table()

@@ -2084,8 +2068,8 @@ def foreach(self, f: Callable[[Row], None]) -> None:
def foreach_func(row: Any) -> None:
f(row)

self.select(struct(*self.schema.fieldNames()).alias("row")).select(
udf(foreach_func, StructType())("row") # type: ignore[arg-type]
self.select(F.struct(*self.schema.fieldNames()).alias("row")).select(
F.udf(foreach_func, StructType())("row") # type: ignore[arg-type]
).collect()

foreach.__doc__ = PySparkDataFrame.foreach.__doc__
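The `foreach` implementation above packs every column into a single struct column named `row`, wraps the user function in a UDF whose return value is discarded (an empty `StructType`), and calls `collect()` only to force execution. Because the function runs inside a UDF, side effects such as `print` happen on the executors, not the client. A small usage sketch, assuming `df` has `name` and `age` columns:

```python
def log_row(row):
    # Executed per row on the executors; the return value is ignored
    print(row.name, row.age)

df.foreach(log_row)
```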
15 changes: 4 additions & 11 deletions python/pyspark/sql/connect/group.py
@@ -40,7 +40,7 @@

import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.functions.builtin import _invoke_function, col, lit
from pyspark.sql.connect.functions import builtin as F
from pyspark.errors import PySparkNotImplementedError, PySparkTypeError

if TYPE_CHECKING:
@@ -132,7 +132,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
# Convert the dict into key value pairs
aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
aggregate_cols = [F._invoke_function(exprs[0][k], F.col(k)) for k in exprs[0]]
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -166,8 +166,6 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
field.name for field in schema.fields if isinstance(field.dataType, NumericType)
]

agg_cols: List[str] = []

if len(cols) > 0:
invalid_cols = [c for c in cols if c not in numerical_cols]
if len(invalid_cols) > 0:
@@ -185,7 +183,7 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
child=self._df._plan,
group_type=self._group_type,
grouping_cols=self._grouping_cols,
aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
aggregate_cols=[F._invoke_function(function, F.col(c)) for c in agg_cols],
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
grouping_sets=self._grouping_sets,
@@ -216,7 +214,7 @@ def avg(self, *cols: str) -> "DataFrame":
mean = avg

def count(self) -> "DataFrame":
return self.agg(_invoke_function("count", lit(1)).alias("count"))
return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))

count.__doc__ = PySparkGroupedData.count.__doc__

@@ -444,11 +442,6 @@ def applyInArrow(

applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__

@staticmethod
def _extract_cols(gd: "GroupedData") -> List[Column]:
df = gd._df
return [df[col] for col in df.columns]


PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
