diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 1f7e789818fb2..5ca747fdd6a7c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -199,6 +199,9 @@ def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> "DataF
     def limit(self, n: int) -> "DataFrame":
         return DataFrame.withPlan(plan.Limit(child=self._plan, limit=n), session=self._session)
 
+    def offset(self, n: int) -> "DataFrame":
+        return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session)
+
     def sort(self, *cols: "ColumnOrString") -> "DataFrame":
         """Sort by a specific column"""
         return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c564b71cdba6c..5b8b7c71866ed 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -272,10 +272,9 @@ def _repr_html_(self) -> str:
 
 
 class Limit(LogicalPlan):
-    def __init__(self, child: Optional["LogicalPlan"], limit: int, offset: int = 0) -> None:
+    def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
         super().__init__(child)
         self.limit = limit
-        self.offset = offset
 
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         assert self._child is not None
@@ -286,7 +285,7 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
 
     def print(self, indent: int = 0) -> str:
         c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
-        return f"{' ' * indent}<Limit limit={self.limit} offset={self.offset}>\n{c_buf}"
+        return f"{' ' * indent}<Limit limit={self.limit}>\n{c_buf}"
 
     def _repr_html_(self) -> str:
         return f"""
@@ -294,6 +293,33 @@ def _repr_html_(self) -> str:
            <li>
               <b>Limit</b><br />
               Limit: {self.limit} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
+
+
+class Offset(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None:
+        super().__init__(child)
+        self.offset = offset
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.offset.input.CopyFrom(self._child.plan(session))
+        plan.offset.offset = self.offset
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+        return f"{' ' * indent}<Offset offset={self.offset}>\n{c_buf}"
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>Offset</b><br />
               Offset: {self.offset} <br />
               {self._child_repr_()}
            </li>
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index de300946932f1..f6988a1d1200d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -106,6 +106,13 @@ def test_simple_binary_expressions(self):
         res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
         self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
 
+    def test_limit_offset(self):
+        df = self.connect.read.table(self.tbl_name)
+        pd = df.limit(10).offset(1).toPandas()
+        self.assertEqual(9, len(pd.index))
+        pd2 = df.offset(98).limit(10).toPandas()
+        self.assertEqual(2, len(pd2.index))
+
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
         tmpPath = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 96bbb8aa83472..739c24ca96ea4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -44,6 +44,16 @@ def test_filter(self):
         self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
         self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
 
+    def test_limit(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        limit_plan = df.limit(10)._plan.to_proto(self.connect)
+        self.assertEqual(limit_plan.root.limit.limit, 10)
+
+    def test_offset(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        offset_plan = df.offset(10)._plan.to_proto(self.connect)
+        self.assertEqual(offset_plan.root.offset.offset, 10)
+
     def test_relation_alias(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.alias("table_alias")._plan.to_proto(self.connect)
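For illustration only, not part of the patch: a minimal plan-only sketch written against the same test harness as test_connect_plan_only.py above (it assumes the existing self.connect and self.tbl_name fixtures, and the hypothetical method name test_limit_then_offset), showing how chained limit() and offset() calls nest in the generated proto plan. The last transformation applied becomes the root relation, and each inner relation is carried in its parent's input field, as set by Offset.plan() in plan.py.

    def test_limit_then_offset(self):
        # Hypothetical companion test, not part of this change: offset() applied
        # after limit() should wrap the Limit relation inside an Offset relation.
        df = self.connect.readTable(table_name=self.tbl_name)
        nested_plan = df.limit(10).offset(1)._plan.to_proto(self.connect)
        # The outermost transformation (offset) becomes the root of the plan ...
        self.assertEqual(nested_plan.root.offset.offset, 1)
        # ... and its `input` field carries the Limit relation built first.
        self.assertEqual(nested_plan.root.offset.input.limit.limit, 10)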