Implicit Context #2
# Implicit Context Propagation

Much has been said about context propagation for the gRPC Python asyncio API.
This is an example demonstrating my proposal for the shape of an API supporting
implicit context propagation.

Very little is actually expected from the gRPC library itself. The interceptor
API is essentially all that's needed. Middleware-defined server interceptors
install values into coroutine-local context using `contextvars`, and the
corresponding client interceptors read those values back out of `contextvars`
and add them to outgoing metadata.

The context object gives middleware a mechanism to store arbitrary key-value
pairs in a coroutine-local way without plumbing them through the application's
stack from server to client. Without middleware installed, *nothing* is
propagated by default except timeout and cancellation, which are handled by the
gRPC library itself using `ContextVar` objects that it manages internally.
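For concreteness, here is a minimal sketch of what "handled by the gRPC library
itself" could look like for timeouts. All names in it (`_DEADLINE`,
`_on_rpc_ingress`, `_remaining_timeout`) are hypothetical and are not part of
this example's files:

```python
import contextvars
import time
from typing import Optional

# Hypothetical ContextVar the gRPC library could use to record the deadline of
# the RPC currently being served by this chain of coroutines.
_DEADLINE: contextvars.ContextVar = contextvars.ContextVar(
    'grpc_deadline', default=None)


def _on_rpc_ingress(timeout: Optional[float]) -> None:
    """Sketch: called by the server machinery before invoking a handler."""
    if timeout is not None:
        _DEADLINE.set(time.monotonic() + timeout)


def _remaining_timeout() -> Optional[float]:
    """Sketch: called by the client machinery for an outgoing RPC.

    When the application supplies no explicit timeout, the remaining portion
    of the inbound deadline is what gets propagated.
    """
    deadline = _DEADLINE.get()
    if deadline is None:
        return None
    return max(0.0, deadline - time.monotonic())
```

Because `contextvars` values are local to the chain of coroutines handling a
single request, concurrent RPCs served by the same event loop do not see each
other's deadlines.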
## Interceptors

Ancillary to the central point of context, but included in this example out of
necessity, is a possible shape for `asyncio` interceptors. They are needed to
demonstrate how tracing middleware would interface with gRPC.
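For orientation, the wiring used throughout the files below reduces to roughly
the following. This only restates how the interceptors from
`client_interceptors.py` and `server_interceptors.py` get installed in
`server.py` and `client.py`; it is not additional API surface:

```python
import grpc
import grpc.aio

import client_interceptors
import server_interceptors

# Server side: the middleware's interceptor is handed to the server at
# construction time, so it can populate contextvars before each handler runs.
tracing_interceptor = server_interceptors.TracingInterceptor()
server = grpc.aio.server(interceptors=(tracing_interceptor,))

# Client side: outgoing channels are wrapped so that the matching client
# interceptor can read those contextvars and attach them to request metadata.
channel = grpc.intercept_channel(
    grpc.aio.insecure_channel('db.grpc.io:7777'),
    client_interceptors.tracing_interceptor())
```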
## Middleware

The files in this example are supposed to stand in for code from three
different authors. The first is the application author, who owns the following
files:

- `user.proto`
- `server.py`
- `client.py`

The "User" application makes use of a database with a gRPC interface. The
database author also owns several files in this example:

- `db.proto`
- `db_client.py`
- `db.py`

Finally, a third party has written a tracing library called "Oxton". They have
written and made available gRPC interceptors in the following files:

- `client_interceptors.py`
- `server_interceptors.py`
- `interceptor_common.py`

The only actual interaction with the context API happens in
`client_interceptors.py` and `server_interceptors.py`. Middleware authors are
the only ones expected to touch it.
`client.py`:
import asyncio
import datetime
import grpc
from proto import user_pb2, user_pb2_grpc
import client_interceptors

SERVER_TARGET = 'users.grpc.io:50051'


async def create_user():
    tracing_interceptor = client_interceptors.probabilistic_tracing_interceptor()
    async with grpc.intercept_channel(
            grpc.aio.insecure_channel(SERVER_TARGET),
            tracing_interceptor) as channel:
        stub = user_pb2_grpc.UserServiceStub(channel)
        new_user = user_pb2.CreateUserRequest(
            name="Smitty Werbenjagermanjensen")
        # The timeout will implicitly be propagated through the following files:
        # - client.py
        # - server.py
        # - db_client.py
        # - db.py
        #
        # If a timeout occurs at any link in that chain, an exception will be
        # raised from the call to stub.CreateUser and an on_done callback will
        # be executed in the service handlers in server.py and db.py, allowing
        # the user to exit their coroutines using, e.g., asyncio.Event.
        response = await stub.CreateUser(new_user,
                                         timeout=datetime.timedelta(seconds=3))
        print(f"Successfully created user {response.user.name} "
              f"with id {response.user.id}.")


def main():
    asyncio.run(create_user())


if __name__ == "__main__":
    main()
`client_interceptors.py`:
import grpc
import grpc.aio

import random

from typing import Any, AsyncGenerator, Callable, Optional, Tuple

import interceptor_common

_DEFAULT_TRACING_PROBABILITY = 0.5

# NOTE: This file is meant as a stand-in for third party tracing middleware,
# e.g. OpenCensus (https://opencensus.io/). Different libraries have different
# opinions on the data that should be tracked by each request and what
# percentage of requests should be traced.


def tracing_interceptor() -> grpc.aio.GenericClientInterceptor:
    return probabilistic_tracing_interceptor(1.0)


def probabilistic_tracing_interceptor(
        tracing_probability: float = _DEFAULT_TRACING_PROBABILITY
) -> grpc.aio.GenericClientInterceptor:
    """Create an interceptor that traces a supplied fraction of requests.

    This interceptor should only be used upon ingress to the system. All
    intermediary services should use an unconditional interceptor.

    Args:
        tracing_probability: A value in the range [0.0, 1.0] representing the
          fraction of requests whose transit through the system will be traced.
    """

    # NOTE: The return type here closely mirrors the original
    # GenericClientInterceptor interface. We may want to clean it up.
    def intercept_call(
            client_call_details: grpc.aio.ClientCallDetails,
            request_generator: AsyncGenerator[Any, Any],
            request_streaming: bool,
            response_streaming: bool
    ) -> Tuple[grpc.aio.ClientCallDetails,
               AsyncGenerator[Any, Any],
               Optional[Callable[..., Any]]]:
        metadata = client_call_details.metadata or []
        if random.random() < tracing_probability:
            # Tag this request.
            trace_id = interceptor_common.TRACE_ID_CONTEXTVAR.get()
            parent_id = None
            if trace_id is None:
                # This request has no parent. Generate a new trace ID.
                trace_id = interceptor_common.generate_trace_id()
            else:
                parent_id = interceptor_common.SPAN_ID_CONTEXTVAR.get()
            for key, value in ((interceptor_common.TRACE_ID_KEY, trace_id),
                               (interceptor_common.PARENT_ID_KEY, parent_id)):
                if value is not None:
                    metadata.append((key, value))
        # TODO: We might want a better way to construct one of these from an
        # existing ClientCallDetails object. This current API very closely
        # mirrors the existing interceptor API.
        new_client_call_details = grpc.aio.ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials)
        return new_client_call_details, request_generator, None

    return grpc.aio.generic_client_interceptor(intercept_call)
`db.proto`:
syntax = "proto3";

message Record {
  int64 id = 1;
  string value = 2;
}

message CreateRequest {
  string value = 1;
}

message CreateResponse {
  Record record = 1;
}

service RecordService {
  rpc Create (CreateRequest) returns (CreateResponse) {}
}
`db.py`:
import asyncio
import grpc
from proto import db_pb2, db_pb2_grpc

HOST = '[::]:7777'

IN_MEMORY_RECORDS = {
    0: 'root',
    1: 'Kevin Flynn',
    2: 'Hiro Protagonist',
}
MAX_ID = max(IN_MEMORY_RECORDS)


class RecordService(db_pb2_grpc.RecordServiceServicer):
    async def Create(self,
                     request: db_pb2.CreateRequest,
                     context: grpc.aio.ServicerContext):
        global MAX_ID
        done_event = asyncio.Event()

        def on_done(*_):
            done_event.set()

        # Run on_done when the RPC terminates, e.g. due to cancellation or an
        # expired deadline propagated from upstream.
        context.add_done_callback(on_done)
        if request.value in IN_MEMORY_RECORDS.values():
            await context.abort(grpc.StatusCode.ALREADY_EXISTS,
                                f"Value '{request.value}' already exists.")
        MAX_ID += 1
        new_id = MAX_ID
        IN_MEMORY_RECORDS[new_id] = request.value
        # Pretend we have a lot more work to do here. This just demonstrates
        # what the application author must do to cooperate with cancellation.
        if done_event.is_set():
            return
        return db_pb2.CreateResponse(
            record=db_pb2.Record(id=new_id, value=request.value))


async def run_server():
    server = grpc.aio.server()
    server.add_insecure_port(HOST)
    db_pb2_grpc.add_RecordServiceServicer_to_server(RecordService(), server)
    await server.start()
    await server.wait_for_termination()


def main():
    asyncio.run(run_server())


if __name__ == "__main__":
    main()
`db_client.py`:
import grpc
from proto import db_pb2, db_pb2_grpc

from typing import Text


# NOTE: This file is meant to stand in for a third-party DB client library.
# Importantly, assume that this file does not lie under the same domain of
# control as client.py or server.py and, as such, changes to this file have
# a substantially higher cost.
async def create_record(value: Text, channel: grpc.Channel) -> db_pb2.Record:
    stub = db_pb2_grpc.RecordServiceStub(channel)
    response = await stub.Create(db_pb2.CreateRequest(value=value))
    return response.record
`interceptor_common.py`:
# NOTE: The library name is used in context keys to ensure we do not clash with
# any other libraries that may be populating the context object.

import contextvars
from typing import Text

TRACING_LIBRARY_NAME = 'oxton'
TRACE_ID_KEY = f'{TRACING_LIBRARY_NAME}.trace_id'
PARENT_ID_KEY = f'{TRACING_LIBRARY_NAME}.parent_id'
SPAN_ID_KEY = f'{TRACING_LIBRARY_NAME}.span_id'

TRACE_ID_CONTEXTVAR = contextvars.ContextVar(TRACE_ID_KEY, default=None)
PARENT_ID_CONTEXTVAR = contextvars.ContextVar(PARENT_ID_KEY, default=None)
SPAN_ID_CONTEXTVAR = contextvars.ContextVar(SPAN_ID_KEY, default=None)


def generate_trace_id() -> Text:
    """Generates a trace ID according to a standards-defined algorithm."""


def generate_span_id() -> Text:
    """Generates a span ID according to a standards-defined algorithm."""
`server.py`:
import asyncio
import grpc
from proto import user_pb2, user_pb2_grpc
import db_client
import client_interceptors
import server_interceptors

HOST = '[::]:50051'
DB_TARGET = 'db.grpc.io:7777'


class UserService(user_pb2_grpc.UserServiceServicer):
    def __init__(self):
        self._db_channel = grpc.intercept_channel(
            grpc.aio.insecure_channel(DB_TARGET),
            client_interceptors.tracing_interceptor())

    async def CreateUser(self,
                         request: user_pb2.CreateUserRequest,
                         context: grpc.aio.ServicerContext):
        record = await db_client.create_record(request.name, self._db_channel)
        return user_pb2.CreateUserResponse(
            user=user_pb2.User(id=record.id,
                               name=record.value))


async def run_server():
    tracing_interceptor = server_interceptors.TracingInterceptor()
    server = grpc.aio.server(interceptors=(tracing_interceptor,))
    server.add_insecure_port(HOST)
    user_pb2_grpc.add_UserServiceServicer_to_server(UserService(), server)
    await server.start()
    server_task = asyncio.create_task(server.wait_for_termination())
    await asyncio.gather(
        server_task,
        tracing_interceptor.transmit_events(server_task),
    )


def main():
    asyncio.run(run_server())


if __name__ == "__main__":
    main()
`server_interceptors.py`:
import asyncio
import grpc
import grpc.aio
import logging
import time
import aiohttp
from typing import Callable, Sequence, Text, Tuple

import interceptor_common

_BATCH_SIZE = 256
_TRACING_SERVER = 'http://tracing.grpc.io:80'


def _encode_batch(batch: Sequence[Tuple[Text, Text, Text, float]]) -> Text:
    """Encode the events into json or something like that."""


class TracingInterceptor(grpc.aio.ServerInterceptor):

    def __init__(self):
        self._ingress_log_queue = asyncio.Queue()
        self._tracing_server_session = aiohttp.ClientSession()

    async def _log_service_ingress(self, trace_id: Text, parent_id: Text,
                                   span_id: Text):
        ingress_time = time.time()
        # These logs will be dequeued by another coroutine and transmitted
        # across the network to a central tracing server.
        await self._ingress_log_queue.put(
            (trace_id, parent_id, span_id, ingress_time))

    async def transmit_events(self, server_task: asyncio.Task):
        while not server_task.done():
            event_batch = []
            for _ in range(_BATCH_SIZE):
                event_batch.append(await self._ingress_log_queue.get())
            encoded_batch = _encode_batch(event_batch)
            async with self._tracing_server_session.put(
                    _TRACING_SERVER, data=encoded_batch) as response:
                text = await response.text()
                if response.status != 200:
                    logging.warning(f"Failed to send batch to server: {text}")

    async def intercept_service(
            self,
            continuation: Callable[[grpc.aio.HandlerCallDetails],
                                   grpc.aio.RpcMethodHandler],
            handler_call_details: grpc.aio.HandlerCallDetails
    ) -> grpc.aio.RpcMethodHandler:
        trace_id = None
        parent_id = None
        span_id = interceptor_common.generate_span_id()
        # Pull the appropriate values out of the request metadata.
        for key, value in handler_call_details.invocation_metadata:
            if key == interceptor_common.TRACE_ID_KEY:
                trace_id = value
            elif key == interceptor_common.PARENT_ID_KEY:
                parent_id = value
        if trace_id is None:
            logging.warning("Received request with no tracing metadata.")
            # Continue the RPC nonetheless.
        else:
            # Inject the tracing data into the coroutine-local context. Note
            # that parent_id is legitimately absent for the root of a trace.
            interceptor_common.TRACE_ID_CONTEXTVAR.set(trace_id)
            if parent_id is not None:
                interceptor_common.PARENT_ID_CONTEXTVAR.set(parent_id)
            interceptor_common.SPAN_ID_CONTEXTVAR.set(span_id)
            await self._log_service_ingress(trace_id, parent_id, span_id)
        return await continuation(handler_call_details)
`user.proto`:
syntax = "proto3";

message User {
  int64 id = 1;
  string name = 2;
}

message CreateUserRequest {
  string name = 1;
}

message CreateUserResponse {
  User user = 1;
}

service UserService {
  rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}