diff --git a/instrumentation/opentelemetry-instrumentation-grpc/setup.cfg b/instrumentation/opentelemetry-instrumentation-grpc/setup.cfg index 6fd95bbf09..c070e243ff 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/setup.cfg +++ b/instrumentation/opentelemetry-instrumentation-grpc/setup.cfg @@ -59,3 +59,4 @@ where = src opentelemetry_instrumentor = grpc_client = opentelemetry.instrumentation.grpc:GrpcInstrumentorClient grpc_server = opentelemetry.instrumentation.grpc:GrpcInstrumentorServer + grpc_aio_server = opentelemetry.instrumentation.grpc:GrpcAioInstrumentorServer diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py index 177bfe67b5..edfd3f640e 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py @@ -108,7 +108,7 @@ def serve(): logging.basicConfig() serve() -You can also add the instrumentor manually, rather than using +You can also add the interceptor manually, rather than using :py:class:`~opentelemetry.instrumentation.grpc.GrpcInstrumentorServer`: .. code-block:: python @@ -118,6 +118,64 @@ def serve(): server = grpc.server(futures.ThreadPoolExecutor(), interceptors = [server_interceptor()]) +Usage Aio Server +------------ +.. code-block:: python + + import logging + import asyncio + + import grpc + + from opentelemetry import trace + from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorServer + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import ( + ConsoleSpanExporter, + SimpleSpanProcessor, + ) + + try: + from .gen import helloworld_pb2, helloworld_pb2_grpc + except ImportError: + from gen import helloworld_pb2, helloworld_pb2_grpc + + trace.set_tracer_provider(TracerProvider()) + trace.get_tracer_provider().add_span_processor( + SimpleSpanProcessor(ConsoleSpanExporter()) + ) + + grpc_server_instrumentor = GrpcAioInstrumentorServer() + grpc_server_instrumentor.instrument() + + class Greeter(helloworld_pb2_grpc.GreeterServicer): + async def SayHello(self, request, context): + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) + + + async def serve(): + + server = grpc.aio.server() + + helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) + server.add_insecure_port("[::]:50051") + await server.start() + await server.wait_for_termination() + + + if __name__ == "__main__": + logging.basicConfig() + asyncio.run(serve()) + +You can also add the interceptor manually, rather than using +:py:class:`~opentelemetry.instrumentation.grpc.GrpcAioInstrumentorServer`: + +.. code-block:: python + + from opentelemetry.instrumentation.grpc import aio_server_interceptor + + server = grpc.aio.server(interceptors = [aio_server_interceptor()]) + """ from typing import Collection @@ -174,6 +232,44 @@ def _uninstrument(self, **kwargs): grpc.server = self._original_func +class GrpcAioInstrumentorServer(BaseInstrumentor): + """ + Globally instrument the grpc.aio server. + + Usage:: + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + grpc_aio_server_instrumentor.instrument() + + """ + + # pylint:disable=attribute-defined-outside-init, redefined-outer-name + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs): + self._original_func = grpc.aio.server + tracer_provider = kwargs.get("tracer_provider") + + def server(*args, **kwargs): + if "interceptors" in kwargs: + # add our interceptor as the first + kwargs["interceptors"].insert( + 0, aio_server_interceptor(tracer_provider=tracer_provider) + ) + else: + kwargs["interceptors"] = [ + aio_server_interceptor(tracer_provider=tracer_provider) + ] + return self._original_func(*args, **kwargs) + + grpc.aio.server = server + + def _uninstrument(self, **kwargs): + grpc.aio.server = self._original_func + + class GrpcInstrumentorClient(BaseInstrumentor): """ Globally instrument the grpc client @@ -255,3 +351,19 @@ def server_interceptor(tracer_provider=None): tracer = trace.get_tracer(__name__, __version__, tracer_provider) return _server.OpenTelemetryServerInterceptor(tracer) + + +def aio_server_interceptor(tracer_provider=None): + """Create a gRPC aio server interceptor. + + Args: + tracer: The tracer to use to create server-side spans. + + Returns: + A service-side interceptor object. + """ + from . import _aio_server + + tracer = trace.get_tracer(__name__, __version__, tracer_provider) + + return _aio_server.OpenTelemetryAioServerInterceptor(tracer) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py new file mode 100644 index 0000000000..0909d623db --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py @@ -0,0 +1,105 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc.aio + +from ._server import ( + OpenTelemetryServerInterceptor, + _wrap_rpc_behavior, + _OpenTelemetryServicerContext, +) + + +class OpenTelemetryAioServerInterceptor( + grpc.aio.ServerInterceptor, OpenTelemetryServerInterceptor +): + """ + An AsyncIO gRPC server interceptor, to add OpenTelemetry. + Usage:: + tracer = some OpenTelemetry tracer + interceptors = [ + AsyncOpenTelemetryServerInterceptor(tracer), + ] + server = aio.server( + futures.ThreadPoolExecutor(max_workers=concurrency), + interceptors = (interceptors,)) + """ + + async def intercept_service(self, continuation, handler_call_details): + def telemetry_wrapper(behavior, request_streaming, response_streaming): + # handle streaming responses specially + if response_streaming: + return self._intercept_server_stream( + behavior, + handler_call_details, + ) + else: + return self._intercept_server_unary( + behavior, + handler_call_details, + ) + + next_handler = await continuation(handler_call_details) + + return _wrap_rpc_behavior(next_handler, telemetry_wrapper) + + def _intercept_server_unary(self, behavior, handler_call_details): + async def _unary_interceptor(request_or_iterator, context): + with self._set_remote_context(context): + with self._start_span( + handler_call_details, + context, + set_status_on_exception=False, + ) as span: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # And now we run the actual RPC. + try: + return await behavior(request_or_iterator, context) + + except Exception as error: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(error) != Exception: + span.record_exception(error) + raise error + + return _unary_interceptor + + def _intercept_server_stream(self, behavior, handler_call_details): + async def _stream_interceptor(request_or_iterator, context): + with self._set_remote_context(context): + with self._start_span( + handler_call_details, + context, + set_status_on_exception=False, + ) as span: + context = _OpenTelemetryServicerContext(context, span) + + try: + async for response in behavior( + request_or_iterator, context + ): + yield response + + except Exception as error: + # pylint:disable=unidiomatic-typecheck + if type(error) != Exception: + span.record_exception(error) + raise error + + return _stream_interceptor diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py new file mode 100644 index 0000000000..0f6ecd8747 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py @@ -0,0 +1,554 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=unused-argument +# pylint:disable=no-self-use +import asyncio +import grpc +import grpc.aio +from concurrent.futures.thread import ThreadPoolExecutor + +from time import sleep +from opentelemetry.test.test_base import TestBase +from opentelemetry import trace +import opentelemetry.instrumentation.grpc +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.trace import StatusCode + +from .protobuf.test_server_pb2 import Request, Response +from .protobuf.test_server_pb2_grpc import ( + GRPCTestServerServicer, + add_GRPCTestServerServicer_to_server, +) +from opentelemetry.instrumentation.grpc import ( + GrpcAioInstrumentorServer, + aio_server_interceptor, +) + + +class Servicer(GRPCTestServerServicer): + """Our test servicer""" + + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + # pylint:disable=C0103 + async def ServerStreamingMethod(self, request, context): + for data in ("one", "two", "three"): + yield Response( + server_id=request.client_id, + response_data=data, + ) + + +def run_with_test_server(runnable, servicer=Servicer(), add_interceptor=True): + if add_interceptor: + interceptors = [aio_server_interceptor()] + server = grpc.aio.server(interceptors=interceptors) + else: + server = grpc.aio.server() + + add_GRPCTestServerServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + channel = grpc.aio.insecure_channel(f"localhost:{port:d}") + + async def do_request(): + await server.start() + resp = await runnable(channel) + await server.stop(1000) + return resp + + loop = asyncio.get_event_loop_policy().get_event_loop() + return loop.run_until_complete(do_request()) + + +class TestOpenTelemetryAioServerInterceptor(TestBase): + def test_instrumentor(self): + """Check that automatic instrumentation configures the interceptor""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + grpc_aio_server_instrumentor.instrument() + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + run_with_test_server(request, add_interceptor=False) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + grpc_aio_server_instrumentor.uninstrument() + + + def test_uninstrument(self): + """Check that uninstrument removes the interceptor""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + grpc_aio_server_instrumentor.instrument() + grpc_aio_server_instrumentor.uninstrument() + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + run_with_test_server(request, add_interceptor=False) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 0) + + + def test_create_span(self): + """Check that the interceptor wraps calls with spans server-side.""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + run_with_test_server(request) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + def test_create_two_spans(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + class TwoSpanServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + + # create another span + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span("child") as child: + child.add_event("child event") + + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + run_with_test_server(request, servicer=TwoSpanServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + child_span = spans_list[0] + parent_span = spans_list[1] + + self.assertEqual(parent_span.name, rpc_call) + self.assertIs(parent_span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + parent_span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + parent_span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + # Check the child span + self.assertEqual(child_span.name, "child") + self.assertEqual( + parent_span.context.trace_id, child_span.context.trace_id + ) + + def test_create_span_streaming(self): + """Check that the interceptor wraps calls with spans server-side, on a + streaming call.""" + rpc_call = "/GRPCTestServer/ServerStreamingMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + async for response in channel.unary_stream(rpc_call)(msg): + print(response) + + run_with_test_server(request) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + def test_create_two_spans_streaming(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/ServerStreamingMethod" + + class TwoSpanServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def ServerStreamingMethod(self, request, context): + # create another span + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span("child") as child: + child.add_event("child event") + + for data in ("one", "two", "three"): + yield Response( + server_id=request.client_id, + response_data=data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + async for response in channel.unary_stream(rpc_call)(msg): + print(response) + + run_with_test_server(request, servicer=TwoSpanServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + child_span = spans_list[0] + parent_span = spans_list[1] + + self.assertEqual(parent_span.name, rpc_call) + self.assertIs(parent_span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + parent_span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + parent_span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + # Check the child span + self.assertEqual(child_span.name, "child") + self.assertEqual( + parent_span.context.trace_id, child_span.context.trace_id + ) + + def test_span_lifetime(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + class SpanLifetimeServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + self.span = trace.get_current_span() + + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + lifetime_servicer = SpanLifetimeServicer() + active_span_before_call = trace.get_current_span() + + run_with_test_server(request, servicer=lifetime_servicer) + + active_span_in_handler = lifetime_servicer.span + active_span_after_call = trace.get_current_span() + + self.assertEqual(active_span_before_call, trace.INVALID_SPAN) + self.assertEqual(active_span_after_call, trace.INVALID_SPAN) + self.assertIsInstance(active_span_in_handler, trace_sdk.Span) + self.assertIsNone(active_span_in_handler.parent) + + def test_sequential_server_spans(self): + """Check that sequential RPCs get separate server spans.""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + async def sequential_requests(channel): + await request(channel) + await request(channel) + + run_with_test_server(sequential_requests) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + + span1, span2 = spans_list + # Spans should belong to separate traces + self.assertNotEqual(span1.context.span_id, span2.context.span_id) + self.assertNotEqual(span1.context.trace_id, span2.context.trace_id) + + for span in (span1, span2): + # each should be a root span + self.assertIsNone(span2.parent) + + # check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + def test_concurrent_server_spans(self): + """Check that concurrent RPC calls don't interfere with each other. + + This is the same check as test_sequential_server_spans except that the + RPCs are concurrent. Two handlers are invoked at the same time on two + separate threads. Each one should see a different active span and + context. + """ + rpc_call = "/GRPCTestServer/SimpleMethod" + latch = get_latch(2) + + class LatchedServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + await latch() + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + async def concurrent_requests(channel): + await asyncio.gather(request(channel), request(channel)) + + run_with_test_server(concurrent_requests, servicer=LatchedServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + + span1, span2 = spans_list + # Spans should belong to separate traces + self.assertNotEqual(span1.context.span_id, span2.context.span_id) + self.assertNotEqual(span1.context.trace_id, span2.context.trace_id) + + for span in (span1, span2): + # each should be a root span + self.assertIsNone(span2.parent) + + # check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + def test_abort(self): + """Check that we can catch an abort properly""" + rpc_call = "/GRPCTestServer/SimpleMethod" + failure_message = "failure message" + + class AbortServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + await context.abort( + grpc.StatusCode.FAILED_PRECONDITION, failure_message + ) + + testcase = self + + async def request(channel): + request = Request(client_id=1, request_data=failure_message) + msg = request.SerializeToString() + + with testcase.assertRaises(Exception): + await channel.unary_unary(rpc_call)(msg) + + run_with_test_server(request, servicer=AbortServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + child_span = spans_list[0] + + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # make sure this span errored, with the right status and detail + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual( + span.status.description, + f"{grpc.StatusCode.FAILED_PRECONDITION}:{failure_message}", + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.FAILED_PRECONDITION.value[ + 0 + ], + }, + ) + + +def get_latch(num): + """Get a countdown latch function for use in n threads.""" + cv = asyncio.Condition() + count = 0 + + async def countdown_latch(): + """Block until n-1 other threads have called.""" + nonlocal count + async with cv: + count += 1 + cv.notify() + + async with cv: + while count < num: + await cv.wait() + + return countdown_latch