Skip to content

Commit

Permalink
update to auto grpc instrumentor
Browse files Browse the repository at this point in the history
  • Loading branch information
cnnradams committed Jul 27, 2020
1 parent 9679190 commit bdb5225
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
38 changes: 31 additions & 7 deletions ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
SimpleExportSpanProcessor,
)
from opentelemetry.sdk.metrics.export import ConsoleMetricsExporter
try:
from .gen import helloworld_pb2, helloworld_pb2_grpc
except ImportError:
Expand All @@ -42,7 +44,9 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
instrumentor = GrpcInstrumentorClient()
# Optional - export GRPC specific metrics (latency, bytes in/out, errors) by passing an exporter
instrumentor = GrpcInstrumentorClient(exporter=ConsoleMetricsExporter(), interval=10)
instrumentor.instrument()
def run():
Expand Down Expand Up @@ -109,6 +113,7 @@ def serve():
serve()
"""
from contextlib import contextmanager
from functools import partial

import grpc
from wrapt import wrap_function_wrapper as _wrap
Expand Down Expand Up @@ -139,11 +144,21 @@ def wrapper_fn(self, original_func, instance, args, kwargs):

class GrpcInstrumentorClient(BaseInstrumentor):
def _instrument(self, **kwargs):
exporter = kwargs.get("exporter", None)
interval = kwargs.get("interval", 30)
if kwargs.get("channel_type") == "secure":
_wrap("grpc", "secure_channel", self.wrapper_fn)
_wrap(
"grpc",
"secure_channel",
partial(self.wrapper_fn, exporter, interval),
)

else:
_wrap("grpc", "insecure_channel", self.wrapper_fn)
_wrap(
"grpc",
"insecure_channel",
partial(self.wrapper_fn, exporter, interval),
)

def _uninstrument(self, **kwargs):
if kwargs.get("channel_type") == "secure":
Expand All @@ -152,10 +167,19 @@ def _uninstrument(self, **kwargs):
else:
unwrap(grpc, "insecure_channel")

@contextmanager
def wrapper_fn(self, original_func, instance, args, kwargs):
with original_func(*args, **kwargs) as channel:
yield intercept_channel(channel, client_interceptor())
def wrapper_fn(
self, exporter, interval, original_func, instance, args, kwargs
):
channel = original_func(*args, **kwargs)
tracer_provider = kwargs.get("tracer_provider")
return intercept_channel(
channel,
client_interceptor(
tracer_provider=tracer_provider,
exporter=exporter,
interval=interval,
),
)


def client_interceptor(tracer_provider=None, exporter=None, interval=30):
Expand Down
20 changes: 10 additions & 10 deletions ext/opentelemetry-ext-grpc/tests/test_client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import opentelemetry.ext.grpc
from opentelemetry import trace
from opentelemetry.ext.grpc import client_interceptor
from opentelemetry.ext.grpc.grpcext import intercept_channel
from opentelemetry.ext.grpc import GrpcInstrumentorClient
from opentelemetry.sdk.metrics.export.aggregate import (
MinMaxSumCountAggregator,
SumAggregator,
Expand All @@ -37,23 +36,23 @@
class TestClientProto(TestBase):
def setUp(self):
super().setUp()
self.server = create_test_server(25565)
self.server.start()
self.interceptor = client_interceptor(
GrpcInstrumentorClient().instrument(
exporter=self.memory_metrics_exporter
)
self.channel = intercept_channel(
grpc.insecure_channel("localhost:25565"), self.interceptor
)
self.server = create_test_server(25565)
self.server.start()
self.channel = grpc.insecure_channel("localhost:25565")
self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel)

def tearDown(self):
super().tearDown()
GrpcInstrumentorClient().uninstrument()
self.memory_metrics_exporter.clear()
self.server.stop(None)

def _verify_success_records(self, num_bytes_out, num_bytes_in, method):
self.interceptor.controller.tick()
# pylint: disable=protected-access,no-member
self.channel._interceptor.controller.tick()
records = self.memory_metrics_exporter.get_exported_metrics()
self.assertEqual(len(records), 3)

Expand Down Expand Up @@ -163,7 +162,8 @@ def test_stream_stream(self):
)

def _verify_error_records(self, method):
self.interceptor.controller.tick()
# pylint: disable=protected-access,no-member
self.channel._interceptor.controller.tick()
records = self.memory_metrics_exporter.get_exported_metrics()
self.assertEqual(len(records), 3)

Expand Down

0 comments on commit bdb5225

Please sign in to comment.