Skip to content

Commit

Permalink
Support HTTP trailers
Browse files Browse the repository at this point in the history
  • Loading branch information
siminn-arnorgj committed Nov 8, 2023
1 parent 597d091 commit 816e8a3
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,11 @@ async def otel_receive():
def _get_otel_send(
self, server_span, server_span_name, scope, send, duration_attrs
):
expecting_trailers = False

@wraps(send)
async def otel_send(message):
nonlocal expecting_trailers
with self.tracer.start_as_current_span(
" ".join((server_span_name, scope["type"], "send"))
) as send_span:
Expand All @@ -670,6 +673,8 @@ async def otel_send(message):
] = status_code
set_status_code(server_span, status_code)
set_status_code(send_span, status_code)

expecting_trailers = message.get("trailers", False)
elif message["type"] == "websocket.send":
set_status_code(server_span, 200)
set_status_code(send_span, 200)
Expand Down Expand Up @@ -705,8 +710,15 @@ async def otel_send(message):
pass

await send(message)
if message["type"] == "http.response.body":
if not message.get("more_body", False):
server_span.end()
if (
not expecting_trailers
and message["type"] == "http.response.body"
and not message.get("more_body", False)
) or (
expecting_trailers
and message["type"] == "http.response.trailers"
and not message.get("more_trailers", False)
):
server_span.end()

return otel_send
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,51 @@ async def background_execution_asgi(scope, receive, send):
time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S)


async def background_execution_trailers_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
message = await receive()
scope["headers"] = [(b"content-length", b"128")]
assert scope["type"] == "http"
if message.get("type") == "http.request":
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"Content-Type", b"text/plain"],
[b"content-length", b"1024"],
],
"trailers": True,
}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": True}
)
await send(
{"type": "http.response.body", "body": b"*", "more_body": False}
)
await send(
{
"type": "http.response.trailers",
"headers": [
[b"trailer", b"test-trailer"],
],
"more_trailers": True,
}
)
await send(
{
"type": "http.response.trailers",
"headers": [
[b"trailer", b"second-test-trailer"],
],
"more_trailers": False,
}
)
time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S)


async def error_asgi(scope, receive, send):
assert isinstance(scope, dict)
assert scope["type"] == "http"
Expand Down Expand Up @@ -188,7 +233,12 @@ def validate_outputs(self, outputs, error=None, modifiers=None):
modifiers = modifiers or []
# Check for expected outputs
response_start = outputs[0]
response_final_body = outputs[-1]
response_final_body = [
output
for output in outputs
if output["type"] == "http.response.body"
][-1]

self.assertEqual(response_start["type"], "http.response.start")
self.assertEqual(response_final_body["type"], "http.response.body")
self.assertEqual(response_final_body.get("more_body", False), False)
Expand Down Expand Up @@ -331,6 +381,41 @@ def test_background_execution(self):
_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9,
)

def test_trailers(self):
"""Test that trailers are emitted as expected and that the server span is ended
BEFORE the background task is finished."""
app = otel_asgi.OpenTelemetryMiddleware(
background_execution_trailers_asgi
)
self.seed_app(app)
self.send_default_request()
outputs = self.get_all_output()

def add_body_and_trailer_span(expected: list):
body_span = {
"name": "GET / http send",
"kind": trace_api.SpanKind.INTERNAL,
"attributes": {"type": "http.response.body"},
}
trailer_span = {
"name": "GET / http send",
"kind": trace_api.SpanKind.INTERNAL,
"attributes": {"type": "http.response.trailers"},
}
expected[2:2] = [body_span]
expected[4:4] = [trailer_span] * 2
return expected

self.validate_outputs(outputs, modifiers=[add_body_and_trailer_span])
span_list = self.memory_exporter.get_finished_spans()
server_span = span_list[-1]
assert server_span.kind == SpanKind.SERVER
span_duration_nanos = server_span.end_time - server_span.start_time
self.assertLessEqual(
span_duration_nanos,
_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9,
)

def test_override_span_name(self):
"""Test that default span_names can be overwritten by our callback function."""
span_name = "Dymaxion"
Expand Down

0 comments on commit 816e8a3

Please sign in to comment.