Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capture_(request|response)_headers to instrument_httpx #671

Merged
merged 16 commits into from
Dec 17, 2024
43 changes: 32 additions & 11 deletions logfire-api/logfire_api/_internal/integrations/httpx.pyi
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
import httpx
from _typeshed import Incomplete
from logfire import Logfire as Logfire
from typing import TypedDict, Unpack
from opentelemetry.instrumentation.httpx import AsyncRequestHook, AsyncResponseHook, RequestHook, RequestInfo, ResponseHook, ResponseInfo
from opentelemetry.trace import Span
from typing import Any, Callable, Literal, ParamSpec, TypeVar, TypedDict, Unpack, overload

RequestHook: Incomplete
ResponseHook: Incomplete
AsyncRequestHook: Incomplete
AsyncResponseHook: Incomplete
class AsyncClientKwargs(TypedDict, total=False):
request_hook: RequestHook | AsyncRequestHook
response_hook: ResponseHook | AsyncResponseHook
skip_dep_check: bool

class ClientKwargs(TypedDict, total=False):
request_hook: RequestHook
response_hook: ResponseHook
skip_dep_check: bool

class HTTPXInstrumentKwargs(TypedDict, total=False):
request_hook: RequestHook
response_hook: ResponseHook
async_request_hook: AsyncRequestHook
async_response_hook: AsyncResponseHook
skip_dep_check: bool
AnyRequestHook = TypeVar('AnyRequestHook', RequestHook, AsyncRequestHook)
AnyResponseHook = TypeVar('AnyResponseHook', ResponseHook, AsyncResponseHook)
Hook = TypeVar('Hook', RequestHook, ResponseHook)
AsyncHook = TypeVar('AsyncHook', AsyncRequestHook, AsyncResponseHook)
P = ParamSpec('P')

def instrument_httpx(logfire_instance: Logfire, client: httpx.Client | httpx.AsyncClient | None, **kwargs: Unpack[HTTPXInstrumentKwargs]) -> None:
"""Instrument the `httpx` module so that spans are automatically created for each request.

See the `Logfire.instrument_httpx` method for details.
"""
@overload
def instrument_httpx(logfire_instance: Logfire, client: httpx.Client, capture_request_headers: bool, capture_response_headers: bool, **kwargs: Unpack[ClientKwargs]) -> None: ...
@overload
def instrument_httpx(logfire_instance: Logfire, client: httpx.AsyncClient, capture_request_headers: bool, capture_response_headers: bool, **kwargs: Unpack[AsyncClientKwargs]) -> None: ...
@overload
def instrument_httpx(logfire_instance: Logfire, client: None, capture_request_headers: bool, capture_response_headers: bool, **kwargs: Unpack[HTTPXInstrumentKwargs]) -> None: ...
def make_capture_response_headers_hook(hook: ResponseHook | None) -> ResponseHook: ...
def make_capture_async_response_headers_hook(hook: AsyncResponseHook | None) -> AsyncResponseHook: ...
def make_capture_request_headers_hook(hook: RequestHook | None) -> RequestHook: ...
def make_capture_async_request_headers_hook(hook: AsyncRequestHook | None) -> AsyncRequestHook: ...
async def run_async_hook(hook: Callable[P, Any] | None, *args: P.args, **kwargs: P.kwargs) -> None: ...
def run_hook(hook: Callable[P, Any] | None, *args: P.args, **kwargs: P.kwargs) -> None: ...
def capture_response_headers(span: Span, response: ResponseInfo) -> None: ...
def capture_request_headers(span: Span, request: RequestInfo) -> None: ...
def capture_headers(span: Span, headers: httpx.Headers, request_or_response: Literal['request', 'response']) -> None: ...
17 changes: 7 additions & 10 deletions logfire-api/logfire_api/_internal/main.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from .integrations.asyncpg import AsyncPGInstrumentKwargs as AsyncPGInstrumentKw
from .integrations.aws_lambda import AwsLambdaInstrumentKwargs as AwsLambdaInstrumentKwargs, LambdaHandler as LambdaHandler
from .integrations.celery import CeleryInstrumentKwargs as CeleryInstrumentKwargs
from .integrations.flask import FlaskInstrumentKwargs as FlaskInstrumentKwargs
from .integrations.httpx import HTTPXInstrumentKwargs as HTTPXInstrumentKwargs
from .integrations.httpx import AsyncClientKwargs as AsyncClientKwargs, ClientKwargs as ClientKwargs, HTTPXInstrumentKwargs as HTTPXInstrumentKwargs
from .integrations.mysql import MySQLConnection as MySQLConnection, MySQLInstrumentKwargs as MySQLInstrumentKwargs
from .integrations.psycopg import PsycopgInstrumentKwargs as PsycopgInstrumentKwargs
from .integrations.pymongo import PymongoInstrumentKwargs as PymongoInstrumentKwargs
Expand Down Expand Up @@ -550,15 +550,12 @@ class Logfire:
"""
def instrument_asyncpg(self, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
"""Instrument the `asyncpg` module so that spans are automatically created for each query."""
def instrument_httpx(self, client: httpx.Client | httpx.AsyncClient | None = None, **kwargs: Unpack[HTTPXInstrumentKwargs]) -> None:
"""Instrument the `httpx` module so that spans are automatically created for each request.

Optionally, pass an `httpx.Client` instance to instrument only that client.

Uses the
[OpenTelemetry HTTPX Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/httpx/httpx.html)
library, specifically `HTTPXClientInstrumentor().instrument()`, to which it passes `**kwargs`.
"""
@overload
def instrument_httpx(self, client: httpx.Client, capture_request_headers: bool = False, capture_response_headers: bool = False, **kwargs: Unpack[ClientKwargs]) -> None: ...
@overload
def instrument_httpx(self, client: httpx.AsyncClient, capture_request_headers: bool = False, capture_response_headers: bool = False, **kwargs: Unpack[AsyncClientKwargs]) -> None: ...
@overload
def instrument_httpx(self, client: None = None, capture_request_headers: bool = False, capture_response_headers: bool = False, **kwargs: Unpack[HTTPXInstrumentKwargs]) -> None: ...
def instrument_celery(self, **kwargs: Unpack[CeleryInstrumentKwargs]) -> None:
"""Instrument `celery` so that spans are automatically created for each task.

Expand Down
185 changes: 168 additions & 17 deletions logfire/_internal/integrations/httpx.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload

import httpx

try:
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.httpx import (
AsyncRequestHook,
AsyncResponseHook,
HTTPXClientInstrumentor,
RequestHook,
RequestInfo,
ResponseHook,
ResponseInfo,
)
except ImportError:
raise RuntimeError(
'`logfire.instrument_httpx()` requires the `opentelemetry-instrumentation-httpx` package.\n'
Expand All @@ -14,15 +26,19 @@
from logfire import Logfire

if TYPE_CHECKING:
from typing import Awaitable, Callable, TypedDict, Unpack
from typing import ParamSpec, TypedDict, TypeVar, Unpack

import httpx
from opentelemetry.trace import Span

RequestHook = Callable[[Span, httpx.Request], None]
ResponseHook = Callable[[Span, httpx.Request, httpx.Response], None]
AsyncRequestHook = Callable[[Span, httpx.Request], Awaitable[None]]
AsyncResponseHook = Callable[[Span, httpx.Request, httpx.Response], Awaitable[None]]
class AsyncClientKwargs(TypedDict, total=False):
request_hook: RequestHook | AsyncRequestHook
response_hook: ResponseHook | AsyncResponseHook
skip_dep_check: bool

class ClientKwargs(TypedDict, total=False):
request_hook: RequestHook
response_hook: ResponseHook
skip_dep_check: bool

class HTTPXInstrumentKwargs(TypedDict, total=False):
request_hook: RequestHook
Expand All @@ -31,9 +47,47 @@ class HTTPXInstrumentKwargs(TypedDict, total=False):
async_response_hook: AsyncResponseHook
skip_dep_check: bool

AnyRequestHook = TypeVar('AnyRequestHook', RequestHook, AsyncRequestHook)
AnyResponseHook = TypeVar('AnyResponseHook', ResponseHook, AsyncResponseHook)
Hook = TypeVar('Hook', RequestHook, ResponseHook)
AsyncHook = TypeVar('AsyncHook', AsyncRequestHook, AsyncResponseHook)

P = ParamSpec('P')

@overload
def instrument_httpx(
logfire_instance: Logfire,
client: httpx.Client,
capture_request_headers: bool,
capture_response_headers: bool,
**kwargs: Unpack[ClientKwargs],
) -> None: ...

@overload
def instrument_httpx(
logfire_instance: Logfire,
client: httpx.AsyncClient,
capture_request_headers: bool,
capture_response_headers: bool,
**kwargs: Unpack[AsyncClientKwargs],
) -> None: ...

@overload
def instrument_httpx(
logfire_instance: Logfire,
client: None,
capture_request_headers: bool,
capture_response_headers: bool,
**kwargs: Unpack[HTTPXInstrumentKwargs],
) -> None: ...


def instrument_httpx(
logfire_instance: Logfire, client: httpx.Client | httpx.AsyncClient | None, **kwargs: Unpack[HTTPXInstrumentKwargs]
logfire_instance: Logfire,
client: httpx.Client | httpx.AsyncClient | None,
capture_request_headers: bool,
capture_response_headers: bool,
**kwargs: Any,
) -> None:
"""Instrument the `httpx` module so that spans are automatically created for each request.

Expand All @@ -45,13 +99,110 @@ def instrument_httpx(
**kwargs,
}
del kwargs # make sure only final_kwargs is used
Kludex marked this conversation as resolved.
Show resolved Hide resolved

instrumentor = HTTPXClientInstrumentor()
if client:
instrumentor.instrument_client(
client,
tracer_provider=final_kwargs['tracer_provider'],
request_hook=final_kwargs.get('request_hook'),
response_hook=final_kwargs.get('response_hook'),
)
else:

if client is None:
request_hook = cast('RequestHook | None', final_kwargs.get('request_hook'))
response_hook = cast('ResponseHook | None', final_kwargs.get('response_hook'))
async_request_hook = cast('AsyncRequestHook | None', final_kwargs.get('async_request_hook'))
async_response_hook = cast('AsyncResponseHook | None', final_kwargs.get('async_response_hook'))

if capture_request_headers: # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking, but things here are getting complicated enough that this is worth testing. it could either make a real request or a mini local server could be set up in another thread.

final_kwargs['request_hook'] = make_capture_request_headers_hook(request_hook)
final_kwargs['async_request_hook'] = make_capture_async_request_headers_hook(async_request_hook)

if capture_response_headers: # pragma: no cover
final_kwargs['response_hook'] = make_capture_response_headers_hook(response_hook)
final_kwargs['async_response_hook'] = make_capture_async_response_headers_hook(async_response_hook)

instrumentor.instrument(**final_kwargs)
else:
request_hook = cast('RequestHook | AsyncRequestHook | None', final_kwargs.get('request_hook'))
response_hook = cast('ResponseHook | AsyncResponseHook | None', final_kwargs.get('response_hook'))

if capture_request_headers:
if isinstance(client, httpx.AsyncClient):
request_hook = cast('AsyncRequestHook | None', request_hook)
request_hook = make_capture_async_request_headers_hook(request_hook)
else:
request_hook = cast('RequestHook | None', request_hook)
request_hook = make_capture_request_headers_hook(request_hook)
else:
if isinstance(client, httpx.AsyncClient):
request_hook = functools.partial(run_async_hook, request_hook)

if capture_response_headers:
if isinstance(client, httpx.AsyncClient):
response_hook = cast('AsyncResponseHook | None', response_hook)
response_hook = make_capture_async_response_headers_hook(response_hook)
else:
response_hook = cast('ResponseHook | None', response_hook)
response_hook = make_capture_response_headers_hook(response_hook)
else:
if isinstance(client, httpx.AsyncClient):
response_hook = functools.partial(run_async_hook, response_hook)

tracer_provider = final_kwargs['tracer_provider']
instrumentor.instrument_client(client, tracer_provider, request_hook, response_hook)


def make_capture_response_headers_hook(hook: ResponseHook | None) -> ResponseHook:
def capture_response_headers_hook(span: Span, request: RequestInfo, response: ResponseInfo) -> None:
capture_response_headers(span, response)
run_hook(hook, span, request, response)

return capture_response_headers_hook


def make_capture_async_response_headers_hook(hook: AsyncResponseHook | None) -> AsyncResponseHook:
async def capture_response_headers_hook(span: Span, request: RequestInfo, response: ResponseInfo) -> None:
capture_response_headers(span, response)
await run_async_hook(hook, span, request, response)

return capture_response_headers_hook


def make_capture_request_headers_hook(hook: RequestHook | None) -> RequestHook:
def capture_request_headers_hook(span: Span, request: RequestInfo) -> None:
capture_request_headers(span, request)
run_hook(hook, span, request)

return capture_request_headers_hook


def make_capture_async_request_headers_hook(hook: AsyncRequestHook | None) -> AsyncRequestHook:
async def capture_request_headers_hook(span: Span, request: RequestInfo) -> None:
capture_request_headers(span, request)
await run_async_hook(hook, span, request)

return capture_request_headers_hook


async def run_async_hook(hook: Callable[P, Any] | None, *args: P.args, **kwargs: P.kwargs) -> None:
if hook:
result = hook(*args, **kwargs)
while inspect.isawaitable(result):
result = await result


def run_hook(hook: Callable[P, Any] | None, *args: P.args, **kwargs: P.kwargs) -> None:
if hook:
hook(*args, **kwargs)


def capture_response_headers(span: Span, response: ResponseInfo) -> None:
capture_headers(span, cast('httpx.Headers', response.headers), 'response')


def capture_request_headers(span: Span, request: RequestInfo) -> None:
capture_headers(span, cast('httpx.Headers', request.headers), 'request')


def capture_headers(span: Span, headers: httpx.Headers, request_or_response: Literal['request', 'response']) -> None:
span.set_attributes(
{
f'http.{request_or_response}.header.{header_name}': headers.get_list(header_name)
alexmojaki marked this conversation as resolved.
Show resolved Hide resolved
for header_name in headers.keys()
}
)
37 changes: 34 additions & 3 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from .integrations.aws_lambda import AwsLambdaInstrumentKwargs, LambdaHandler
from .integrations.celery import CeleryInstrumentKwargs
from .integrations.flask import FlaskInstrumentKwargs
from .integrations.httpx import HTTPXInstrumentKwargs
from .integrations.httpx import AsyncClientKwargs, ClientKwargs, HTTPXInstrumentKwargs
from .integrations.mysql import MySQLConnection, MySQLInstrumentKwargs
from .integrations.psycopg import PsycopgInstrumentKwargs
from .integrations.pymongo import PymongoInstrumentKwargs
Expand Down Expand Up @@ -1158,8 +1158,39 @@ def instrument_asyncpg(self, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
self._warn_if_not_initialized_for_instrumentation()
return instrument_asyncpg(self, **kwargs)

@overload
def instrument_httpx(
self,
client: httpx.Client,
capture_request_headers: bool = False,
capture_response_headers: bool = False,
**kwargs: Unpack[ClientKwargs],
) -> None: ...

@overload
def instrument_httpx(
self,
client: httpx.AsyncClient,
capture_request_headers: bool = False,
capture_response_headers: bool = False,
**kwargs: Unpack[AsyncClientKwargs],
) -> None: ...

@overload
def instrument_httpx(
self, client: httpx.Client | httpx.AsyncClient | None = None, **kwargs: Unpack[HTTPXInstrumentKwargs]
self,
client: None = None,
capture_request_headers: bool = False,
capture_response_headers: bool = False,
**kwargs: Unpack[HTTPXInstrumentKwargs],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still want to get rid of these Unpacks and kwargs eventually.

) -> None: ...

def instrument_httpx(
self,
client: httpx.Client | httpx.AsyncClient | None = None,
capture_request_headers: bool = False,
capture_response_headers: bool = False,
**kwargs: Any,
) -> None:
"""Instrument the `httpx` module so that spans are automatically created for each request.

Expand All @@ -1172,7 +1203,7 @@ def instrument_httpx(
from .integrations.httpx import instrument_httpx

self._warn_if_not_initialized_for_instrumentation()
return instrument_httpx(self, client, **kwargs)
return instrument_httpx(self, client, capture_request_headers, capture_response_headers, **kwargs)

def instrument_celery(self, **kwargs: Unpack[CeleryInstrumentKwargs]) -> None:
"""Instrument `celery` so that spans are automatically created for each task.
Expand Down
Loading
Loading