diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 277da6b2431d0..d6dac4854ef55 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -174,6 +174,8 @@ message ExecutePlanResponse {
 service SparkConnectService {
 
   // Executes a request that contains the query and returns a stream of [[Response]].
+  //
+  // It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
   rpc ExecutePlan(ExecutePlanRequest) returns (stream ExecutePlanResponse) {}
 
   // Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index deb9ef6f3be6f..24d104a0418ec 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -300,7 +300,7 @@ def register_udf(
 
         req = self._execute_plan_request_with_metadata()
         req.plan.command.create_function.CopyFrom(fun)
-        self._execute_and_fetch(req)
+        self._execute(req)
         return name
 
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
@@ -350,7 +350,7 @@ def range(
             Range(start=start, end=end, step=step, num_partitions=numPartitions), self
         )
 
-    def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
+    def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         return self._execute_and_fetch(req)
@@ -398,7 +398,7 @@ def execute_command(self, command: pb2.Command) -> None:
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        self._execute_and_fetch(req)
+        self._execute(req)
         return
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
@@ -439,13 +439,16 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeRes
         resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
         return AnalyzeResult.fromProto(resp)
 
-    def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
-        if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
-            with pa.ipc.open_stream(b.arrow_batch.data) as rd:
-                return rd.read_pandas()
-        return None
+    def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
+        with pa.ipc.open_stream(arrow_batch.data) as rd:
+            return rd.read_pandas()
 
-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
+    def _execute(self, req: pb2.ExecutePlanRequest) -> None:
+        for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+            continue
+        return
+
+    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame":
         import pandas as pd
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
@@ -454,23 +457,21 @@ def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.Dat
         for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
             if b.metrics is not None:
                 m = b.metrics
-
-            pb = self._process_batch(b)
-            if pb is not None:
+            if b.HasField("arrow_batch"):
+                pb = self._process_batch(b.arrow_batch)
                 result_dfs.append(pb)
 
-        if len(result_dfs) > 0:
-            df = pd.concat(result_dfs)
+        assert len(result_dfs) > 0
 
-            # pd.concat generates non-consecutive index like:
-            # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
-            # set it to RangeIndex to be consistent with pyspark
-            n = len(df)
-            df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+        df = pd.concat(result_dfs)
 
-            # Attach the metrics to the DataFrame attributes.
-            if m is not None:
-                df.attrs["metrics"] = self._build_metrics(m)
-            return df
-        else:
-            return None
+        # pd.concat generates non-consecutive index like:
+        # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
+        # set it to RangeIndex to be consistent with pyspark
+        n = len(df)
+        df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+
+        # Attach the metrics to the DataFrame attributes.
+        if m is not None:
+            df.attrs["metrics"] = self._build_metrics(m)
+        return df
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ff14945db0f3d..bd374dcf814e6 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -789,7 +789,7 @@ def collect(self) -> List[Row]:
         else:
             return []
 
-    def toPandas(self) -> Optional["pandas.DataFrame"]:
+    def toPandas(self) -> "pandas.DataFrame":
         if self._plan is None:
             raise Exception("Cannot collect on empty plan.")
         if self._session is None:
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index 139727e283007..aff5897f520f8 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -46,7 +46,10 @@ class SparkConnectServiceServicer(object):
     """Main interface for the SparkConnect service."""
 
     def ExecutePlan(self, request, context):
-        """Executes a request that contains the query and returns a stream of [[Response]]."""
+        """Executes a request that contains the query and returns a stream of [[Response]].
+
+        It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
+        """
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details("Method not implemented!")
         raise NotImplementedError("Method not implemented!")
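
Illustration (not part of the diff above): a minimal, standalone sketch of the pandas behavior the reworked _execute_and_fetch relies on, assuming only that pandas is installed; the toy frames below are hypothetical stand-ins for the per-batch DataFrames collected in result_dfs.

    import pandas as pd

    # Stand-ins for DataFrames read from consecutive Arrow batches.
    batch_frames = [pd.DataFrame({"id": [0, 1]}), pd.DataFrame({"id": [2, 3, 4]})]

    df = pd.concat(batch_frames)
    # pd.concat keeps each batch's own index (0, 1, 0, 1, 2), so reset it to a
    # consecutive RangeIndex, as the new code does to stay consistent with pyspark.
    df.set_index(pd.RangeIndex(start=0, stop=len(df), step=1), inplace=True)

    print(df.index)  # RangeIndex(start=0, stop=5, step=1)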