Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement request and response hooks in requests library instrumentation #1717

Merged
merged 2 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix httpx resource warnings
([#1695](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1695))

### Changed

- `opentelemetry-instrumentation-requests` Replace `name_callback` and `span_callback` with standard `response_hook` and `request_hook` callbacks
([#670](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/670))

## Version 1.16.0/0.37b0 (2023-02-17)

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from typing import Callable, Collection, Iterable, Optional
from urllib.parse import urlparse

from requests.models import Response
from requests.models import PreparedRequest, Response
from requests.sessions import Session
from requests.structures import CaseInsensitiveDict

Expand Down Expand Up @@ -85,14 +85,17 @@

_excluded_urls_from_env = get_excluded_urls("REQUESTS")

_RequestHookT = Optional[Callable[[Span, PreparedRequest], None]]
_ResponseHookT = Optional[Callable[[Span, PreparedRequest], None]]


# pylint: disable=unused-argument
# pylint: disable=R0915
def _instrument(
tracer: Tracer,
duration_histogram: Histogram,
span_callback: Optional[Callable[[Span, Response], str]] = None,
name_callback: Optional[Callable[[str, str], str]] = None,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
excluded_urls: Iterable[str] = None,
):
"""Enables tracing of all requests calls that go through
Expand All @@ -106,29 +109,9 @@ def _instrument(
# before v1.0.0, Dec 17, 2012, see
# https://github.com/psf/requests/commit/4e5c4a6ab7bb0195dececdd19bb8505b872fe120)

wrapped_request = Session.request
wrapped_send = Session.send

@functools.wraps(wrapped_request)
def instrumented_request(self, method, url, *args, **kwargs):
if excluded_urls and excluded_urls.url_disabled(url):
return wrapped_request(self, method, url, *args, **kwargs)

def get_or_create_headers():
headers = kwargs.get("headers")
if headers is None:
headers = {}
kwargs["headers"] = headers

return headers

def call_wrapped():
return wrapped_request(self, method, url, *args, **kwargs)

return _instrumented_requests_call(
method, url, call_wrapped, get_or_create_headers
)

# pylint: disable-msg=too-many-locals,too-many-branches
@functools.wraps(wrapped_send)
def instrumented_send(self, request, **kwargs):
if excluded_urls and excluded_urls.url_disabled(request.url):
Expand All @@ -142,32 +125,17 @@ def get_or_create_headers():
)
return request.headers

def call_wrapped():
return wrapped_send(self, request, **kwargs)

return _instrumented_requests_call(
request.method, request.url, call_wrapped, get_or_create_headers
)

# pylint: disable-msg=too-many-locals,too-many-branches
def _instrumented_requests_call(
method: str, url: str, call_wrapped, get_or_create_headers
):
if context.get_value(
_SUPPRESS_INSTRUMENTATION_KEY
) or context.get_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY):
return call_wrapped()
return wrapped_send(self, request, **kwargs)

# See
# https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-client
method = method.upper()
span_name = ""
if name_callback is not None:
span_name = name_callback(method, url)
if not span_name or not isinstance(span_name, str):
span_name = get_default_span_name(method)
method = request.method.upper()
span_name = get_default_span_name(method)

url = remove_url_credentials(url)
url = remove_url_credentials(request.url)

span_attributes = {
SpanAttributes.HTTP_METHOD: method,
Expand Down Expand Up @@ -195,6 +163,8 @@ def _instrumented_requests_call(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span, set_ip_on_next_http_connection(span):
exception = None
if callable(request_hook):
request_hook(span, request)

headers = get_or_create_headers()
inject(headers)
Expand All @@ -206,7 +176,7 @@ def _instrumented_requests_call(
start_time = default_timer()

try:
result = call_wrapped() # *** PROCEED
result = wrapped_send(self, request, **kwargs) # *** PROCEED
except Exception as exc: # pylint: disable=W0703
exception = exc
result = getattr(exc, "response", None)
Expand Down Expand Up @@ -236,8 +206,8 @@ def _instrumented_requests_call(
"1.1" if version == 11 else "1.0"
)

if span_callback is not None:
span_callback(span, result)
if callable(response_hook):
response_hook(span, request, result)

duration_histogram.record(elapsed_time, attributes=metric_labels)

Expand All @@ -246,9 +216,6 @@ def _instrumented_requests_call(

return result

instrumented_request.opentelemetry_instrumentation_requests_applied = True
Session.request = instrumented_request

instrumented_send.opentelemetry_instrumentation_requests_applied = True
Session.send = instrumented_send

Expand Down Expand Up @@ -295,10 +262,8 @@ def _instrument(self, **kwargs):
Args:
**kwargs: Optional arguments
``tracer_provider``: a TracerProvider, defaults to global
``span_callback``: An optional callback invoked before returning the http response. Invoked with Span and requests.Response
``name_callback``: Callback which calculates a generic span name for an
outgoing HTTP request based on the method and url.
Optional: Defaults to get_default_span_name.
``request_hook``: An optional callback that is invoked right after a span is created.
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
``excluded_urls``: A string containing a comma-delimited
list of regexes used to exclude URLs from tracking
"""
Expand All @@ -319,8 +284,8 @@ def _instrument(self, **kwargs):
_instrument(
tracer,
duration_histogram,
span_callback=kwargs.get("span_callback"),
name_callback=kwargs.get("name_callback"),
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
excluded_urls=_excluded_urls_from_env
if excluded_urls is None
else parse_excluded_urls(excluded_urls),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,23 @@ def test_basic(self):
span, opentelemetry.instrumentation.requests
)

def test_name_callback(self):
def name_callback(method, url):
return "GET" + url
def test_hooks(self):
def request_hook(span, request_obj):
span.update_name("name set from hook")

def response_hook(span, request_obj, response):
span.set_attribute("response_hook_attr", "value")

RequestsInstrumentor().uninstrument()
RequestsInstrumentor().instrument(name_callback=name_callback)
RequestsInstrumentor().instrument(
request_hook=request_hook, response_hook=response_hook
)
result = self.perform_request(self.URL)
self.assertEqual(result.text, "Hello!")
span = self.assert_span()

self.assertEqual(span.name, "GET" + self.URL)
self.assertEqual(span.name, "name set from hook")
self.assertEqual(span.attributes["response_hook_attr"], "value")

def test_excluded_urls_explicit(self):
url_404 = "http://httpbin.org/status/404"
Expand Down Expand Up @@ -300,17 +306,21 @@ def test_distributed_context(self):
finally:
set_global_textmap(previous_propagator)

def test_span_callback(self):
def test_response_hook(self):
RequestsInstrumentor().uninstrument()

def span_callback(span, result: requests.Response):
def response_hook(
span,
request: requests.PreparedRequest,
response: requests.Response,
):
span.set_attribute(
"http.response.body", result.content.decode("utf-8")
"http.response.body", response.content.decode("utf-8")
)

RequestsInstrumentor().instrument(
tracer_provider=self.tracer_provider,
span_callback=span_callback,
response_hook=response_hook,
)

result = self.perform_request(self.URL)
Expand Down Expand Up @@ -449,21 +459,6 @@ def perform_request(url: str, session: requests.Session = None):
return requests.get(url)
return session.get(url)

def test_invalid_url(self):
url = "http://[::1/nope"

with self.assertRaises(ValueError):
requests.post(url)

span = self.assert_span()

self.assertEqual(span.name, "HTTP POST")
self.assertEqual(
span.attributes,
{SpanAttributes.HTTP_METHOD: "POST", SpanAttributes.HTTP_URL: url},
)
self.assertEqual(span.status.status_code, StatusCode.ERROR)

def test_credential_removal(self):
new_url = "http://username:password@httpbin.org/status/200"
self.perform_request(new_url)
Expand Down