5 changes: 5 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -48,6 +48,11 @@ message Request {
   // The logical plan to be executed / analyzed.
   Plan plan = 3;
 
+  // Provides optional information about the client sending the request. This field
+  // can be used for language or version specific information and is only intended for
+  // logging purposes and will not be interpreted by the server.
+  optional string client_type = 4;
+
   // User Context is used to refer to one particular user session that is executing
   // queries in the backend.
   message UserContext {

Review thread on `optional string client_type = 4;`:

Contributor: does this do the same thing as user_context.extensions?

Contributor: No, this is purely a string and not an extension type based on Any.
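To make the reviewer's distinction concrete: `client_type` is an opaque string, while `user_context.extensions` carries whole protobuf messages wrapped in `google.protobuf.Any`, which a server can unpack by type URL. A minimal sketch, assuming the generated `base_pb2` module is importable and using `Plan` purely as a stand-in payload:

```python
import pyspark.sql.connect.proto.base_pb2 as pb2

req = pb2.Request()

# client_type is a plain string: the server may log it but never interprets it.
req.client_type = "_SPARK_CONNECT_PYTHON"

# user_context.extensions is a repeated google.protobuf.Any: each entry wraps
# an arbitrary message together with its type URL, so a server that knows the
# type can unpack and act on it. Plan serves only as an illustrative payload.
ext = req.user_context.extensions.add()
ext.Pack(pb2.Plan())

print(req.client_type)  # _SPARK_CONNECT_PYTHON
print(ext.type_url)     # type.googleapis.com/spark.connect.Plan
```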
16 changes: 8 additions & 8 deletions python/pyspark/sql/connect/client.py
@@ -301,9 +301,7 @@ def register_udf(
         fun.parts.append(name)
         fun.serialized_function = cloudpickle.dumps((function, return_type))
 
-        req = pb2.Request()
-        if self._user_id is not None:
-            req.user_context.user_id = self._user_id
+        req = self._request_with_metadata()
         req.plan.command.create_function.CopyFrom(fun)
 
         self._execute_and_fetch(req)
Expand Down Expand Up @@ -357,9 +355,7 @@ def range(
)

def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
req = pb2.Request()
if self._user_id is not None:
req.user_context.user_id = self._user_id
req = self._request_with_metadata()
req.plan.CopyFrom(plan)
return self._execute_and_fetch(req)

@@ -407,12 +403,16 @@ def execute_command(self, command: pb2.Command) -> None:
         req.plan.command.CopyFrom(command)
         self._execute_and_fetch(req)
 
-    def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+    def _request_with_metadata(self) -> pb2.Request:
         req = pb2.Request()
+        req.client_type = "_SPARK_CONNECT_PYTHON"
         if self._user_id:
             req.user_context.user_id = self._user_id
-        req.plan.CopyFrom(plan)
+        return req
+
+    def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+        req = self._request_with_metadata()
         req.plan.CopyFrom(plan)
         resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
         return AnalyzeResult.fromProto(resp)

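The net effect of the refactor is that every RPC the Python client issues now stamps the same metadata in one place. A standalone sketch of the resulting behavior (`request_with_metadata` here is a hypothetical free function mirroring the `_request_with_metadata` helper in the diff, which really lives on the client class):

```python
from typing import Optional

import pyspark.sql.connect.proto.base_pb2 as pb2


def request_with_metadata(user_id: Optional[str]) -> pb2.Request:
    # Mirrors the helper above: tag the request for server-side logging and
    # attach the user context only when a user id is configured.
    req = pb2.Request()
    req.client_type = "_SPARK_CONNECT_PYTHON"
    if user_id:
        req.user_context.user_id = user_id
    return req


req = request_with_metadata("alice")
assert req.client_type == "_SPARK_CONNECT_PYTHON"
assert req.user_context.user_id == "alice"
```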
44 changes: 22 additions & 22 deletions python/pyspark/sql/connect/proto/base_pb2.py
@@ -35,7 +35,7 @@


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xc8\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensionsB\x0e\n\x0c_client_type"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -49,25 +49,25 @@
 _PLAN._serialized_start = 158
 _PLAN._serialized_end = 274
 _REQUEST._serialized_start = 277
-_REQUEST._serialized_end = 551
-_REQUEST_USERCONTEXT._serialized_start = 429
-_REQUEST_USERCONTEXT._serialized_end = 551
-_RESPONSE._serialized_start = 554
-_RESPONSE._serialized_end = 1418
-_RESPONSE_ARROWBATCH._serialized_start = 793
-_RESPONSE_ARROWBATCH._serialized_end = 854
-_RESPONSE_JSONBATCH._serialized_start = 856
-_RESPONSE_JSONBATCH._serialized_end = 916
-_RESPONSE_METRICS._serialized_start = 919
-_RESPONSE_METRICS._serialized_end = 1403
-_RESPONSE_METRICS_METRICOBJECT._serialized_start = 1003
-_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1313
-_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1201
-_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1313
-_RESPONSE_METRICS_METRICVALUE._serialized_start = 1315
-_RESPONSE_METRICS_METRICVALUE._serialized_end = 1403
-_ANALYZERESPONSE._serialized_start = 1421
-_ANALYZERESPONSE._serialized_end = 1555
-_SPARKCONNECTSERVICE._serialized_start = 1558
-_SPARKCONNECTSERVICE._serialized_end = 1720
+_REQUEST._serialized_end = 605
+_REQUEST_USERCONTEXT._serialized_start = 467
+_REQUEST_USERCONTEXT._serialized_end = 589
+_RESPONSE._serialized_start = 608
+_RESPONSE._serialized_end = 1472
+_RESPONSE_ARROWBATCH._serialized_start = 847
+_RESPONSE_ARROWBATCH._serialized_end = 908
+_RESPONSE_JSONBATCH._serialized_start = 910
+_RESPONSE_JSONBATCH._serialized_end = 970
+_RESPONSE_METRICS._serialized_start = 973
+_RESPONSE_METRICS._serialized_end = 1457
+_RESPONSE_METRICS_METRICOBJECT._serialized_start = 1057
+_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1367
+_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1255
+_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1367
+_RESPONSE_METRICS_METRICVALUE._serialized_start = 1369
+_RESPONSE_METRICS_METRICVALUE._serialized_end = 1457
+_ANALYZERESPONSE._serialized_start = 1475
+_ANALYZERESPONSE._serialized_end = 1609
+_SPARKCONNECTSERVICE._serialized_start = 1612
+_SPARKCONNECTSERVICE._serialized_end = 1774
 # @@protoc_insertion_point(module_scope)
32 changes: 30 additions & 2 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -135,6 +135,7 @@ class Request(google.protobuf.message.Message):
     CLIENT_ID_FIELD_NUMBER: builtins.int
     USER_CONTEXT_FIELD_NUMBER: builtins.int
     PLAN_FIELD_NUMBER: builtins.int
+    CLIENT_TYPE_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     """The client_id is set by the client to be able to collate streaming responses from
     different queries.
@@ -145,23 +146,50 @@ class Request(google.protobuf.message.Message):
     @property
     def plan(self) -> global___Plan:
         """The logical plan to be executed / analyzed."""
+    client_type: builtins.str
+    """Provides optional information about the client sending the request. This field
+    can be used for language or version specific information and is only intended for
+    logging purposes and will not be interpreted by the server.
+    """
     def __init__(
         self,
         *,
         client_id: builtins.str = ...,
         user_context: global___Request.UserContext | None = ...,
         plan: global___Plan | None = ...,
+        client_type: builtins.str | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal["plan", b"plan", "user_context", b"user_context"],
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "plan",
+            b"plan",
+            "user_context",
+            b"user_context",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "client_id", b"client_id", "plan", b"plan", "user_context", b"user_context"
+            "_client_type",
+            b"_client_type",
+            "client_id",
+            b"client_id",
+            "client_type",
+            b"client_type",
+            "plan",
+            b"plan",
+            "user_context",
+            b"user_context",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
+    ) -> typing_extensions.Literal["client_type"] | None: ...
 
 global___Request = Request

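The `_client_type` entries in `HasField`/`ClearField`/`WhichOneof` above come from proto3's explicit presence tracking: an `optional` scalar field is compiled as a synthetic single-field oneof named `_client_type`. A quick sketch of the resulting presence semantics, assuming the generated module is importable:

```python
import pyspark.sql.connect.proto.base_pb2 as pb2

req = pb2.Request()
assert not req.HasField("client_type")  # unset: presence is tracked explicitly

req.client_type = "_SPARK_CONNECT_PYTHON"
assert req.HasField("client_type")
assert req.WhichOneof("_client_type") == "client_type"  # the synthetic oneof

req.ClearField("client_type")
assert req.WhichOneof("_client_type") is None
```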