Implicit Context #2

Status: Open. Wants to merge 10 commits into `master`.
51 changes: 51 additions & 0 deletions context/README.md
# Implicit Context Propagation

Much has been said about context propagation for the gRPC Python asyncio API.
This is an example demonstrating my proposal for the shape of an API supporting
implicit context propagation.

Very little is actually required from the gRPC library itself; the interceptor
API is essentially all that's needed. Middleware-defined server interceptors
install values into coroutine-local context using `contextvars`, and the
corresponding client interceptors read those values back out and attach them to
outgoing metadata.

The context object gives middleware a mechanism to store arbitrary key-value
pairs in a coroutine-local way without plumbing them through the application's
call stack from server to client. With no middleware installed, *nothing* is
propagated by default except timeout and cancellation, which are handled by the
gRPC library itself using `ContextVar` objects it manages internally.
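The mechanism described above can be sketched independently of gRPC. The names
and metadata shapes below are purely illustrative, not part of the proposal:

```python
import contextvars

# A coroutine-local slot that middleware owns; application code never sees it.
REQUEST_ID: contextvars.ContextVar = contextvars.ContextVar(
    "request_id", default=None)


def server_side_receive(metadata: dict) -> None:
    # What a middleware-defined server interceptor does: install a value from
    # incoming metadata into coroutine-local context.
    REQUEST_ID.set(metadata.get("request-id"))


def client_side_send() -> dict:
    # What the corresponding client interceptor does: read the value back and
    # attach it to outgoing metadata, with no plumbing through the
    # application's call stack.
    metadata = {}
    request_id = REQUEST_ID.get()
    if request_id is not None:
        metadata["request-id"] = request_id
    return metadata
```

Because each asyncio task gets its own copy of the context, values installed
while handling one request never leak into a concurrently handled request.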
> Review comment: we are on the same page!


## Interceptors

Ancillary to the central point of context, but included in this example out of
necessity, is a possible shape for `asyncio` interceptors. They are needed here
to demonstrate how tracing middleware would interface with gRPC.

## Middleware

The files in this example stand in for code from three different authors. The
first is the application author, who owns the following files:
- `user.proto`
- `server.py`
- `client.py`

The "User" application makes use of a database with a gRPC interface. The
database author also owns several files in this example:

- `db.proto`
- `db_client.py`
- `db.py`

Finally, a third party has written a tracing library called "Oxton". They have
written and made available gRPC interceptors in the following files:

- `client_interceptors.py`
- `server_interceptors.py`
- `interceptor_common.py`

The only direct interaction with the context API happens in those interceptor
files. Middleware authors are the only ones expected to touch it.
36 changes: 36 additions & 0 deletions context/client.py
import asyncio
import datetime

import grpc

from proto import user_pb2, user_pb2_grpc
import client_interceptors

SERVER_TARGET = 'users.grpc.io:50051'


async def create_user():
    tracing_interceptor = client_interceptors.probabilistic_tracing_interceptor()
    async with grpc.intercept_channel(
            grpc.aio.insecure_channel(SERVER_TARGET),
            tracing_interceptor) as channel:
        stub = user_pb2_grpc.UserServiceStub(channel)
        new_user = user_pb2.CreateUserRequest(
            name="Smitty Werbenjagermanjensen")
        # The timeout will implicitly be propagated through the following files:
        # - client.py
        # - server.py
        # - db_client.py
        # - db.py
        #
        # If a timeout occurs at any link in that chain, an exception will be
        # raised from the call to stub.CreateUser and an on_done callback will
        # be executed in the service handlers in server.py and db.py, allowing
        # the user to exit their coroutines using, e.g., asyncio.Event.
        response = await stub.CreateUser(new_user,
                                         timeout=datetime.timedelta(seconds=3))
        print(f"Successfully created user {response.user.name} "
              f"with id {response.user.id}.")


def main():
    asyncio.run(create_user())


if __name__ == "__main__":
    main()
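The implicit deadline propagation described in the comment above could be
handled inside the library roughly as follows. This is a sketch only:
`start_call`, the ContextVar name, and float-second timeouts are all
assumptions, not part of the proposed API.

```python
import contextvars
import time
from typing import Optional

# Hypothetical library-internal ContextVar holding an absolute deadline for
# the current coroutine chain.
_DEADLINE: contextvars.ContextVar = contextvars.ContextVar(
    "grpc_deadline", default=None)


def start_call(timeout: Optional[float]) -> Optional[float]:
    """Compute the effective timeout (in seconds) for an outgoing call.

    Any explicit timeout is clamped against a deadline inherited from the
    enclosing server handler, so the 3-second budget set in client.py bounds
    every downstream hop through server.py, db_client.py, and db.py.
    """
    now = time.monotonic()
    candidates = []
    if timeout is not None:
        candidates.append(now + timeout)
    inherited = _DEADLINE.get()
    if inherited is not None:
        candidates.append(inherited)
    if not candidates:
        return None  # No deadline anywhere in the chain.
    deadline = min(candidates)
    # Visible to calls made further down the same coroutine chain.
    _DEADLINE.set(deadline)
    return deadline - now
```

A second call deeper in the chain then inherits the tighter of its own timeout
and the ambient deadline, which matches the chain described in the comment.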
64 changes: 64 additions & 0 deletions context/client_interceptors.py
import random
from typing import Any, AsyncIterable, Callable, Optional, Tuple

import grpc
import grpc.aio

import interceptor_common

_DEFAULT_TRACING_PROBABILITY = 0.5

# NOTE: This file is meant as a stand-in for third-party tracing middleware,
# e.g. OpenCensus (https://opencensus.io/). Different libraries have different
# opinions on the data that should be tracked for each request and what
# percentage of requests should be traced.


def tracing_interceptor() -> grpc.aio.GenericClientInterceptor:
    return probabilistic_tracing_interceptor(1.0)


def probabilistic_tracing_interceptor(
        tracing_probability: float = _DEFAULT_TRACING_PROBABILITY
) -> grpc.aio.GenericClientInterceptor:
    """Create an interceptor that traces a supplied fraction of requests.

    This interceptor should only be used upon ingress to the system. All
    intermediary services should use an unconditional interceptor.

    Args:
        tracing_probability: A value in the range [0.0, 1.0] representing the
            fraction of requests whose transit through the system will be
            traced.
    """
    # NOTE: The return type here closely mirrors the original
    # GenericClientInterceptor interface. We may want to clean it up.
    # Review comment (from PR): async?
    def intercept_call(
            client_call_details: grpc.aio.ClientCallDetails,
            request_generator: AsyncIterable[Any],
            request_streaming: bool,
            response_streaming: bool
    ) -> Tuple[grpc.aio.ClientCallDetails, AsyncIterable[Any],
               Optional[Callable[[AsyncIterable[Any]], Any]]]:
        metadata = client_call_details.metadata or []
        if random.random() < tracing_probability:
            # Tag this request.
            trace_id = interceptor_common.TRACE_ID_CONTEXTVAR.get()
            parent_id = None
            if trace_id is None:
                # This request has no parent. Generate a new trace ID.
                trace_id = interceptor_common.generate_trace_id()
            else:
                parent_id = interceptor_common.SPAN_ID_CONTEXTVAR.get()
            for key, value in ((interceptor_common.TRACE_ID_KEY, trace_id),
                               (interceptor_common.PARENT_ID_KEY, parent_id)):
                if value is not None:
                    metadata.append((key, value))
        # TODO: We might want a better way to construct one of these from an
        # existing ClientCallDetails object. This current API very closely
        # mirrors the existing interceptor API.
        new_client_call_details = grpc.aio.ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials)
        return new_client_call_details, request_generator, None

    return grpc.aio.generic_client_interceptor(intercept_call)
> Review comment: just asking — would `generic_client_interceptor` provide a
> way to create an interceptor without needing to implement a derived
> interceptor class? Does it exist in the current API?

18 changes: 18 additions & 0 deletions context/db.proto
syntax = "proto3";

message Record {
  int64 id = 1;
  string value = 2;
}

message CreateRequest {
  string value = 1;
}

message CreateResponse {
  Record record = 1;
}

service RecordService {
  rpc Create (CreateRequest) returns (CreateResponse) {}
}
50 changes: 50 additions & 0 deletions context/db.py
import asyncio

import grpc

from proto import db_pb2, db_pb2_grpc

HOST = '[::]:7777'

IN_MEMORY_RECORDS = {
    0: 'root',
    1: 'Kevin Flynn',
    2: 'Hiro Protagonist',
}
MAX_ID = max(IN_MEMORY_RECORDS)


class RecordService(db_pb2_grpc.RecordServiceServicer):
    async def Create(self,
                     request: db_pb2.CreateRequest,
                     context: grpc.aio.ServicerContext):
        done_event = asyncio.Event()

        def on_done(*_):
            done_event.set()

        # Run on_done when the RPC terminates, e.g. on cancellation or timeout.
        context.add_done_callback(on_done)
        global MAX_ID
        new_id = None
        if request.value in IN_MEMORY_RECORDS.values():
            await context.abort(grpc.StatusCode.ALREADY_EXISTS,
                                f"Value '{request.value}' already exists.")
        MAX_ID += 1
        new_id = MAX_ID
        IN_MEMORY_RECORDS[new_id] = request.value
        # Pretend we have a lot more work to do here. This just demonstrates
        # what the application author must do to cooperate with cancellation.
        if done_event.is_set():
            return
        return db_pb2.CreateResponse(
            record=db_pb2.Record(id=new_id, value=request.value))


async def run_server():
    server = grpc.aio.server()
    server.add_insecure_port(HOST)
    db_pb2_grpc.add_RecordServiceServicer_to_server(RecordService(), server)
    await server.start()
    await server.wait_for_termination()


def main():
    asyncio.run(run_server())


if __name__ == "__main__":
    main()
12 changes: 12 additions & 0 deletions context/db_client.py
from typing import Text

import grpc

from proto import db_pb2, db_pb2_grpc


# NOTE: This file is meant to stand in for a third-party DB client library.
# Importantly, assume that this file does not lie under the same domain of
# control as client.py or server.py and, as such, changes to this file have
# a substantially higher cost.
async def create_record(value: Text, channel: grpc.Channel) -> db_pb2.Record:
    stub = db_pb2_grpc.RecordServiceStub(channel)
    response = await stub.Create(db_pb2.CreateRequest(value=value))
    return response.record
19 changes: 19 additions & 0 deletions context/interceptor_common.py
# NOTE: The library name is used in context keys to ensure we do not clash with
# any other libraries that may be populating the context object.

import contextvars
from typing import Text

TRACING_LIBRARY_NAME = 'oxton'
TRACE_ID_KEY = f'{TRACING_LIBRARY_NAME}.trace_id'
PARENT_ID_KEY = f'{TRACING_LIBRARY_NAME}.parent_id'
SPAN_ID_KEY = f'{TRACING_LIBRARY_NAME}.span_id'

TRACE_ID_CONTEXTVAR = contextvars.ContextVar(TRACE_ID_KEY)
PARENT_ID_CONTEXTVAR = contextvars.ContextVar(PARENT_ID_KEY)
SPAN_ID_CONTEXTVAR = contextvars.ContextVar(SPAN_ID_KEY)


def generate_trace_id() -> Text:
    """Generates a trace ID according to a standards-defined algorithm."""


def generate_span_id() -> Text:
    """Generates a span ID according to a standards-defined algorithm."""
43 changes: 43 additions & 0 deletions context/server.py
import asyncio

import grpc

from proto import user_pb2, user_pb2_grpc
import client_interceptors
import db_client
import server_interceptors

HOST = '[::]:50051'
DB_TARGET = 'db.grpc.io:7777'


class UserService(user_pb2_grpc.UserServiceServicer):
    def __init__(self):
        self._db_channel = grpc.intercept_channel(
            grpc.aio.insecure_channel(DB_TARGET),
            client_interceptors.tracing_interceptor())

    async def CreateUser(self,
                         request: user_pb2.CreateUserRequest,
                         context: grpc.aio.ServicerContext):
        # Review comment (from PR): It would also be nice to have an example
        # of how the server could evaluate the timeout provided by the context
        # and decide whether it needs to be modified, or add one if it does
        # not exist.
        record = await db_client.create_record(request.name, self._db_channel)
        return user_pb2.CreateUserResponse(
            user=user_pb2.User(id=record.id,
                               name=record.value))


async def run_server():
    tracing_interceptor = server_interceptors.TracingInterceptor()
    server = grpc.aio.server(interceptors=(tracing_interceptor,))
    server.add_insecure_port(HOST)
    user_pb2_grpc.add_UserServiceServicer_to_server(UserService(), server)
    await server.start()
    server_task = asyncio.create_task(server.wait_for_termination())
    await asyncio.gather(
        server_task,
        tracing_interceptor.transmit_events(server_task),
    )


def main():
    asyncio.run(run_server())


if __name__ == "__main__":
    main()
62 changes: 62 additions & 0 deletions context/server_interceptors.py
import asyncio
import logging
import time
from typing import Callable, Sequence, Text, Tuple

import aiohttp
import grpc
import grpc.aio

import interceptor_common

_BATCH_SIZE = 256
_TRACING_SERVER = 'http://tracing.grpc.io:80'


def _encode_batch(batch: Sequence[Tuple[Text, Text, Text, float]]) -> Text:
    """Encode the events into JSON or something like that."""


class TracingInterceptor(grpc.aio.ServerInterceptor):

    def __init__(self):
        self._ingress_log_queue = asyncio.Queue()
        self._tracing_server_session = aiohttp.ClientSession()

    async def _log_service_ingress(self, trace_id: Text, parent_id: Text,
                                   span_id: Text):
        ingress_time = time.time()
        # These logs will be dequeued by another coroutine and transmitted
        # across the network to a central tracing server.
        await self._ingress_log_queue.put(
            (trace_id, parent_id, span_id, ingress_time))

    async def transmit_events(self, server_task: asyncio.Task):
        while not server_task.done():
            event_batch = []
            for _ in range(_BATCH_SIZE):
                event_batch.append(await self._ingress_log_queue.get())
            encoded_batch = _encode_batch(event_batch)
            async with self._tracing_server_session.put(
                    _TRACING_SERVER, data=encoded_batch) as response:
                text = await response.text()
                if response.status != 200:
                    logging.warning(f"Failed to send batch to server: {text}")

    async def intercept_service(
            self,
            continuation: Callable[[grpc.aio.HandlerCallDetails],
                                   grpc.aio.RpcMethodHandler],
            handler_call_details: grpc.aio.HandlerCallDetails
    ) -> grpc.aio.RpcMethodHandler:
        trace_id = None
        parent_id = None
        span_id = interceptor_common.generate_span_id()
        # Pull the appropriate values out of the request metadata.
        for key, value in handler_call_details.invocation_metadata:
            if key == interceptor_common.TRACE_ID_KEY:
                trace_id = value
            elif key == interceptor_common.PARENT_ID_KEY:
                parent_id = value
        if trace_id is None or parent_id is None:
            logging.warning("Received malformed tracing metadata.")
            # Continue the RPC nonetheless.
        else:
            # Inject the tracing data into the coroutine-local context.
            interceptor_common.TRACE_ID_CONTEXTVAR.set(trace_id)
            interceptor_common.PARENT_ID_CONTEXTVAR.set(parent_id)
            interceptor_common.SPAN_ID_CONTEXTVAR.set(span_id)
            await self._log_service_ingress(trace_id, parent_id, span_id)
        return await continuation(handler_call_details)
18 changes: 18 additions & 0 deletions context/user.proto
syntax = "proto3";

message User {
  int64 id = 1;
  string name = 2;
}

message CreateUserRequest {
  string name = 1;
}

message CreateUserResponse {
  User user = 1;
}

service UserService {
  rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}