2 changes: 2 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -174,6 +174,8 @@ message ExecutePlanResponse {
 service SparkConnectService {

   // Executes a request that contains the query and returns a stream of [[Response]].
+  //
+  // It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
   rpc ExecutePlan(ExecutePlanRequest) returns (stream ExecutePlanResponse) {}

   // Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.
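
A minimal sketch (plain pyarrow, outside Spark Connect) of why at least one Arrow batch is worth sending even for an empty result: the IPC stream still carries the schema, so the client can build a correctly typed zero-row pandas DataFrame instead of returning None. The schema below is illustrative only.

import pyarrow as pa

# An empty table that still has a schema.
schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
empty_table = pa.Table.from_arrays(
    [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], schema=schema
)

# Round-trip it through an Arrow IPC stream, the same format carried by arrow_batch.data.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, schema) as writer:
    writer.write_table(empty_table)

with pa.ipc.open_stream(sink.getvalue()) as reader:
    pdf = reader.read_pandas()

# A zero-row DataFrame with the expected columns, not None.
assert pdf.empty and list(pdf.columns) == ["id", "name"]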
51 changes: 26 additions & 25 deletions python/pyspark/sql/connect/client.py
@@ -300,7 +300,7 @@ def register_udf(
         req = self._execute_plan_request_with_metadata()
         req.plan.command.create_function.CopyFrom(fun)

-        self._execute_and_fetch(req)
+        self._execute(req)
         return name

     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
@@ -350,7 +350,7 @@ def range(
             Range(start=start, end=end, step=step, num_partitions=numPartitions), self
         )

-    def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
+    def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         return self._execute_and_fetch(req)
@@ -398,7 +398,7 @@ def execute_command(self, command: pb2.Command) -> None:
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        self._execute_and_fetch(req)
+        self._execute(req)
         return

     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
@@ -439,13 +439,16 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeResult:
         resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
         return AnalyzeResult.fromProto(resp)

-    def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
-        if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
-            with pa.ipc.open_stream(b.arrow_batch.data) as rd:
-                return rd.read_pandas()
-        return None
+    def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
+        with pa.ipc.open_stream(arrow_batch.data) as rd:
+            return rd.read_pandas()

-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
+    def _execute(self, req: pb2.ExecutePlanRequest) -> None:
+        for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+            continue
+        return
+
+    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame":
         import pandas as pd

         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
@@ -454,23 +454,21 @@ def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
         for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
             if b.metrics is not None:
                 m = b.metrics
-
-            pb = self._process_batch(b)
-            if pb is not None:
+            if b.HasField("arrow_batch"):
+                pb = self._process_batch(b.arrow_batch)
                 result_dfs.append(pb)

-        if len(result_dfs) > 0:
-            df = pd.concat(result_dfs)
+        assert len(result_dfs) > 0

-            # pd.concat generates non-consecutive index like:
-            # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
-            # set it to RangeIndex to be consistent with pyspark
-            n = len(df)
-            df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+        df = pd.concat(result_dfs)

-            # Attach the metrics to the DataFrame attributes.
-            if m is not None:
-                df.attrs["metrics"] = self._build_metrics(m)
-            return df
-        else:
-            return None
+        # pd.concat generates non-consecutive index like:
+        # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
+        # set it to RangeIndex to be consistent with pyspark
+        n = len(df)
+        df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+
+        # Attach the metrics to the DataFrame attributes.
+        if m is not None:
+            df.attrs["metrics"] = self._build_metrics(m)
+        return df
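
A standalone pandas sketch of the index behaviour the comment above describes: pd.concat keeps each chunk's own index, so the combined frame has a non-consecutive index until it is reset to a RangeIndex. The values are illustrative only.

import pandas as pd

chunks = [pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"x": [3, 4, 5]})]
df = pd.concat(chunks)
print(list(df.index))  # [0, 1, 0, 1, 2]  (one index per chunk)

# Reset to a consecutive RangeIndex, as _execute_and_fetch does above.
df.set_index(pd.RangeIndex(start=0, stop=len(df), step=1), inplace=True)
print(list(df.index))  # [0, 1, 2, 3, 4]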
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -789,7 +789,7 @@ def collect(self) -> List[Row]:
         else:
             return []

-    def toPandas(self) -> Optional["pandas.DataFrame"]:
+    def toPandas(self) -> "pandas.DataFrame":
         if self._plan is None:
             raise Exception("Cannot collect on empty plan.")
         if self._session is None:
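
With Optional removed from the return type, callers no longer need a None check. A hedged usage sketch, assuming a connected Spark Connect session object named spark that exposes sql() (hypothetical setup): an empty result set now comes back as a zero-row DataFrame carrying the schema's columns.

# `spark` is an assumed, already-connected Spark Connect session (setup not shown).
pdf = spark.sql("SELECT 1 AS id WHERE 1 = 0").toPandas()
assert pdf.empty and list(pdf.columns) == ["id"]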
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -46,7 +46,10 @@ class SparkConnectServiceServicer(object):
     """Main interface for the SparkConnect service."""

     def ExecutePlan(self, request, context):
-        """Executes a request that contains the query and returns a stream of [[Response]]."""
+        """Executes a request that contains the query and returns a stream of [[Response]].
+
+        It is guaranteed that there is at least one ARROW batch returned even if the result set is empty.
+        """
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details("Method not implemented!")
         raise NotImplementedError("Method not implemented!")