Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hooks for aiohttp, asgi, starlette, fastAPI, urllib, urllib3 #576

Merged
merged 19 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#567](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/567))
- `opentelemetry-instrumentation-grpc` Fixed asynchonous unary call traces
([#536](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/536))
- `opentelemetry-instrumentation-asgi`, `opentelemetry-instrumentation-aiohttp-client`, `openetelemetry-instrumentation-fastapi`, `opentelemetry-instrumentation-starlette`, `opentelemetry-instrumentation-urllib`, `opentelemetry-instrumentation-urllib3` Added `request_hook` and `response_hook` callbacks ([#576](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/576))

### Added
- `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,25 @@ def strip_query_params(url: yarl.URL) -> str:
)
from opentelemetry.propagate import inject
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer
from opentelemetry.trace import Span, SpanKind, TracerProvider, get_tracer
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util.http import remove_url_credentials

_UrlFilterT = typing.Optional[typing.Callable[[str], str]]
_SpanNameT = typing.Optional[
typing.Union[typing.Callable[[aiohttp.TraceRequestStartParams], str], str]
_RequestHookT = typing.Optional[
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
]
_ResponseHookT = typing.Optional[
typing.Callable[
[
Span,
typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
],
],
None,
]
]


Expand All @@ -108,7 +120,8 @@ def url_path_span_name(params: aiohttp.TraceRequestStartParams) -> str:

def create_trace_config(
url_filter: _UrlFilterT = None,
span_name: _SpanNameT = None,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
tracer_provider: TracerProvider = None,
) -> aiohttp.TraceConfig:
"""Create an aiohttp-compatible trace configuration.
Expand All @@ -134,8 +147,8 @@ def create_trace_config(
it as a span attribute. This can be useful to remove sensitive data
such as API keys or user personal information.

:param str span_name: Override the default span name.
:param tracer_provider: optional TracerProvider from which to get a Tracer
ryokather marked this conversation as resolved.
Show resolved Hide resolved
:param request_hook: Optional callback that can modify span name and request params.
:param response_hook: Optional callback that can modify span name and response params.

:return: An object suitable for use with :py:class:`aiohttp.ClientSession`.
:rtype: :py:class:`aiohttp.TraceConfig`
Expand All @@ -161,17 +174,15 @@ async def on_request_start(
return

http_method = params.method.upper()
if trace_config_ctx.span_name is None:
request_span_name = "HTTP {}".format(http_method)
elif callable(trace_config_ctx.span_name):
request_span_name = str(trace_config_ctx.span_name(params))
else:
request_span_name = str(trace_config_ctx.span_name)
request_span_name = "HTTP {}".format(http_method)

trace_config_ctx.span = trace_config_ctx.tracer.start_span(
request_span_name, kind=SpanKind.CLIENT,
)

if callable(request_hook):
request_hook(trace_config_ctx.span, params)

if trace_config_ctx.span.is_recording():
attributes = {
SpanAttributes.HTTP_METHOD: http_method,
Expand All @@ -198,6 +209,9 @@ async def on_request_end(
if trace_config_ctx.span is None:
return

if callable(response_hook):
response_hook(trace_config_ctx.span, params)

if trace_config_ctx.span.is_recording():
trace_config_ctx.span.set_status(
Status(http_status_to_status_code(int(params.response.status)))
Expand All @@ -215,6 +229,9 @@ async def on_request_exception(
if trace_config_ctx.span is None:
return

if callable(response_hook):
response_hook(trace_config_ctx.span, params)

if trace_config_ctx.span.is_recording() and params.exception:
trace_config_ctx.span.set_status(Status(StatusCode.ERROR))
trace_config_ctx.span.record_exception(params.exception)
Expand All @@ -223,7 +240,7 @@ async def on_request_exception(
def _trace_config_ctx_factory(**kwargs):
kwargs.setdefault("trace_request_ctx", {})
return types.SimpleNamespace(
span_name=span_name, tracer=tracer, url_filter=url_filter, **kwargs
tracer=tracer, url_filter=url_filter, **kwargs
)

trace_config = aiohttp.TraceConfig(
Expand All @@ -240,7 +257,8 @@ def _trace_config_ctx_factory(**kwargs):
def _instrument(
tracer_provider: TracerProvider = None,
url_filter: _UrlFilterT = None,
span_name: _SpanNameT = None,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
):
"""Enables tracing of all ClientSessions

Expand All @@ -256,7 +274,8 @@ def instrumented_init(wrapped, instance, args, kwargs):

trace_config = create_trace_config(
url_filter=url_filter,
span_name=span_name,
request_hook=request_hook,
response_hook=response_hook,
tracer_provider=tracer_provider,
)
trace_config._is_instrumented_by_opentelemetry = True
Expand Down Expand Up @@ -304,12 +323,14 @@ def _instrument(self, **kwargs):
``url_filter``: A callback to process the requested URL prior to adding
it as a span attribute. This can be useful to remove sensitive data
such as API keys or user personal information.
``span_name``: Override the default span name.
``request_hook``: An optional callback that is invoked right after a span is created.
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
"""
_instrument(
tracer_provider=kwargs.get("tracer_provider"),
url_filter=kwargs.get("url_filter"),
span_name=kwargs.get("span_name"),
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import StatusCode
from opentelemetry.trace import Span, StatusCode


def run_with_test_server(
Expand Down Expand Up @@ -161,46 +161,51 @@ def test_not_recording(self):
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

def test_span_name_option(self):
for span_name, method, path, expected in (
("static", "POST", "/static-span-name", "static"),
(
lambda params: "{} - {}".format(
params.method, params.url.path
),
"PATCH",
"/some/path",
"PATCH - /some/path",
),
def test_hooks(self):
method = "PATCH"
path = "/some/path"
expected = "PATCH - /some/path"

def request_hook(span: Span, params: aiohttp.TraceRequestStartParams):
span.update_name("{} - {}".format(params.method, params.url.path))

def response_hook(
span: Span,
params: typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
],
):
with self.subTest(span_name=span_name, method=method, path=path):
host, port = self._http_request(
trace_config=aiohttp_client.create_trace_config(
span_name=span_name
),
method=method,
url=path,
status_code=HTTPStatus.OK,
)
span.set_attribute("response_hook_attr", "value")

self.assert_spans(
[
(
expected,
(StatusCode.UNSET, None),
{
SpanAttributes.HTTP_METHOD: method,
SpanAttributes.HTTP_URL: "http://{}:{}{}".format(
host, port, path
),
SpanAttributes.HTTP_STATUS_CODE: int(
HTTPStatus.OK
),
},
)
]
)
self.memory_exporter.clear()
host, port = self._http_request(
trace_config=aiohttp_client.create_trace_config(
request_hook=request_hook, response_hook=response_hook,
),
method=method,
url=path,
status_code=HTTPStatus.OK,
)

for span in self.memory_exporter.get_finished_spans():
self.assertEqual(span.name, expected)
self.assertEqual(
(span.status.status_code, span.status.description),
(StatusCode.UNSET, None),
)
self.assertEqual(
span.attributes[SpanAttributes.HTTP_METHOD], method
)
self.assertEqual(
span.attributes[SpanAttributes.HTTP_URL],
"http://{}:{}{}".format(host, port, path),
)
self.assertEqual(
span.attributes[SpanAttributes.HTTP_STATUS_CODE], HTTPStatus.OK
)
self.assertIn("response_hook_attr", span.attributes)
self.assertEqual(span.attributes["response_hook_attr"], "value")
self.memory_exporter.clear()

def test_url_filter_option(self):
# Strips all query params from URL before adding as a span attribute.
Expand Down Expand Up @@ -501,19 +506,32 @@ def strip_query_params(url: yarl.URL) -> str:
span.attributes[SpanAttributes.HTTP_URL],
)

def test_span_name(self):
def span_name_callback(params: aiohttp.TraceRequestStartParams) -> str:
return "{} - {}".format(params.method, params.url.path)
def test_hooks(self):
def request_hook(span: Span, params: aiohttp.TraceRequestStartParams):
span.update_name("{} - {}".format(params.method, params.url.path))

def response_hook(
span: Span,
params: typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
],
):
span.set_attribute("response_hook_attr", "value")

AioHttpClientInstrumentor().uninstrument()
AioHttpClientInstrumentor().instrument(span_name=span_name_callback)
AioHttpClientInstrumentor().instrument(
request_hook=request_hook, response_hook=response_hook
)

url = "/test-path"
run_with_test_server(
self.get_default_request(url), url, self.default_handler
)
span = self.assert_spans(1)
self.assertEqual("GET - /test-path", span.name)
self.assertIn("response_hook_attr", span.attributes)
self.assertEqual(span.attributes["response_hook_attr"], "value")


class TestLoadingAioHttpInstrumentor(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ def set_status_code(span, status_code):


def get_default_span_details(scope: dict) -> Tuple[str, dict]:
"""Default implementation for span_details_callback

"""Default implementation for get_default_span_details
Args:
scope: the asgi scope dictionary

Returns:
a tuple of the span name, and any attributes to attach to the span.
"""
Expand All @@ -164,10 +162,15 @@ class OpenTelemetryMiddleware:

Args:
app: The ASGI application callable to forward requests to.
span_details_callback: Callback which should return a string
and a tuple, representing the desired span name and a
dictionary with any additional span attributes to set.
Optional: Defaults to get_default_span_details.
default_span_details: Callback which should return a string and a tuple, representing the desired default span name and a
lzchen marked this conversation as resolved.
Show resolved Hide resolved
dictionary with any additional span attributes to set.
Optional: Defaults to get_default_span_details.
server_request_hook: Optional callback which is called with the server span and ASGI
ryokather marked this conversation as resolved.
Show resolved Hide resolved
scope object for every incoming request.
client_request_hook: Optional callback which is called with the internal span and an ASGI
scope which is sent as a dictionary for when the method recieve is called.
client_response_hook: Optional callback which is called with the internal span and an ASGI
event which is sent as a dictionary for when the method send is called.
tracer_provider: The optional tracer provider to use. If omitted
the current globally configured one is used.
"""
Expand All @@ -176,15 +179,21 @@ def __init__(
self,
app,
excluded_urls=None,
span_details_callback=None,
default_span_details=None,
server_request_hook=None,
client_request_hook=None,
client_response_hook=None,
tracer_provider=None,
):
self.app = guarantee_single_callable(app)
self.tracer = trace.get_tracer(__name__, __version__, tracer_provider)
self.span_details_callback = (
span_details_callback or get_default_span_details
)
self.excluded_urls = excluded_urls
self.default_span_details = (
default_span_details or get_default_span_details
)
self.server_request_hook = server_request_hook
self.client_request_hook = client_request_hook
self.client_response_hook = client_response_hook

async def __call__(self, scope, receive, send):
"""The ASGI application
Expand All @@ -202,7 +211,7 @@ async def __call__(self, scope, receive, send):
return await self.app(scope, receive, send)

token = context.attach(extract(scope, getter=asgi_getter))
span_name, additional_attributes = self.span_details_callback(scope)
span_name, additional_attributes = self.default_span_details(scope)

try:
with self.tracer.start_as_current_span(
Expand All @@ -214,11 +223,16 @@ async def __call__(self, scope, receive, send):
for key, value in attributes.items():
span.set_attribute(key, value)

if callable(self.server_request_hook):
self.server_request_hook(span, scope)

@wraps(receive)
async def wrapped_receive():
with self.tracer.start_as_current_span(
" ".join((span_name, scope["type"], "receive"))
) as receive_span:
if callable(self.client_request_hook):
self.client_request_hook(receive_span, scope)
message = await receive()
if receive_span.is_recording():
if message["type"] == "websocket.receive":
Expand All @@ -231,6 +245,8 @@ async def wrapped_send(message):
with self.tracer.start_as_current_span(
" ".join((span_name, scope["type"], "send"))
) as send_span:
if callable(self.client_response_hook):
self.client_response_hook(send_span, message)
if send_span.is_recording():
if message["type"] == "http.response.start":
status_code = message["status"]
Expand Down
Loading