Skip to content

Commit ca78f69

Browse files
chore(asgi): refactor websocket code to use core API (#14345)
In asgi middleware: - move tracing logic to trace handlers - refactor out `websocket.receive` and `websocket.disconnect` messages to call respective trace handlers - refactor out `websocket.send` and `websocket.close` messages to call respective trace handlers ## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [ ] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --------- Co-authored-by: Brett Langdon <brett.langdon@datadoghq.com>
1 parent b626393 commit ca78f69

8 files changed

+614
-423
lines changed

ddtrace/_trace/trace_handlers.py

Lines changed: 187 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Callable
66
from typing import Dict
77
from typing import List
8+
from typing import Mapping
89
from typing import Optional
910
from typing import Tuple
1011
from urllib import parse
@@ -24,14 +25,20 @@
2425
from ddtrace.constants import SPAN_KIND
2526
from ddtrace.contrib import trace_utils
2627
from ddtrace.contrib.internal.botocore.constants import BOTOCORE_STEPFUNCTIONS_INPUT_KEY
28+
29+
# from ddtrace.internal.utils import _copy_trace_level_tags
30+
from ddtrace.contrib.internal.trace_utils import _copy_trace_level_tags
2731
from ddtrace.contrib.internal.trace_utils import _set_url_tag
2832
from ddtrace.ext import SpanKind
33+
from ddtrace.ext import SpanLinkKind
2934
from ddtrace.ext import azure_servicebus as azure_servicebusx
3035
from ddtrace.ext import db
3136
from ddtrace.ext import http
3237
from ddtrace.ext import net
3338
from ddtrace.ext import redis as redisx
39+
from ddtrace.ext import websocket
3440
from ddtrace.internal import core
41+
from ddtrace.internal.compat import is_valid_ip
3542
from ddtrace.internal.compat import maybe_stringify
3643
from ddtrace.internal.constants import COMPONENT
3744
from ddtrace.internal.constants import FLASK_ENDPOINT
@@ -42,7 +49,9 @@
4249
from ddtrace.internal.constants import MESSAGING_MESSAGE_ID
4350
from ddtrace.internal.constants import MESSAGING_OPERATION
4451
from ddtrace.internal.constants import MESSAGING_SYSTEM
52+
from ddtrace.internal.constants import SPAN_LINK_KIND
4553
from ddtrace.internal.logger import get_logger
54+
from ddtrace.internal.sampling import _inherit_sampling_tags
4655
from ddtrace.internal.schema.span_attribute_schema import SpanDirection
4756
from ddtrace.propagation.http import HTTPPropagator
4857

@@ -131,7 +140,7 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -
131140

132141
if config._inferred_proxy_services_enabled:
133142
# dispatch event for checking headers and possibly making an inferred proxy span
134-
core.dispatch("inferred_proxy.start", (ctx, tracer, span_kwargs, call_trace, integration_config))
143+
core.dispatch("inferred_proxy.start", (ctx, tracer, span_kwargs, call_trace))
135144
# re-get span_kwargs in case an inferred span was created and we have a new span_kwargs.child_of field
136145
span_kwargs = ctx.get_item("span_kwargs", span_kwargs)
137146

@@ -237,14 +246,15 @@ def _set_inferred_proxy_tags(span, status_code):
237246
inferred_span.set_tag(ERROR_STACK, span.get_tag(ERROR_STACK))
238247

239248

240-
def _on_inferred_proxy_start(ctx, tracer, span_kwargs, call_trace, integration_config):
249+
def _on_inferred_proxy_start(ctx, tracer, span_kwargs, call_trace):
241250
# Skip creating another inferred span if one has already been created for this request
242251
if ctx.get_item("inferred_proxy_span"):
243252
return
244253

245254
# some integrations like Flask / WSGI store headers from environ in 'distributed_headers'
246255
# and normalized headers in 'headers'
247256
headers = ctx.get_item("headers", ctx.get_item("distributed_headers", None))
257+
integration_config = ctx.get_item("integration_config")
248258

249259
# Inferred Proxy Spans
250260
if integration_config and headers is not None:
@@ -956,6 +966,168 @@ def _on_router_match(route):
956966
req_span.set_tag_str(http.ROUTE, route.template)
957967

958968

969+
def _set_websocket_message_tags_on_span(websocket_span: Span, message: Mapping[str, Any]):
970+
if "text" in message:
971+
websocket_span.set_tag_str(websocket.MESSAGE_TYPE, "text")
972+
websocket_span.set_metric(websocket.MESSAGE_LENGTH, len(message["text"].encode("utf-8")))
973+
elif "binary" in message:
974+
websocket_span.set_tag_str(websocket.MESSAGE_TYPE, "binary")
975+
websocket_span.set_metric(websocket.MESSAGE_LENGTH, len(message["bytes"]))
976+
977+
978+
def _set_websocket_close_tags(span: Span, message: Mapping[str, Any]):
979+
code = message.get("code")
980+
reason = message.get("reason")
981+
if code is not None:
982+
span.set_metric(websocket.CLOSE_CODE, code)
983+
if reason:
984+
span.set_tag(websocket.CLOSE_REASON, reason)
985+
986+
987+
def _set_client_ip_tags(scope: Mapping[str, Any], span: Span):
988+
client = scope.get("client")
989+
if len(client) >= 1: # type: ignore[arg-type]
990+
client_ip = client[0] # type: ignore[index]
991+
span.set_tag_str(net.TARGET_HOST, client_ip)
992+
try:
993+
is_valid_ip(client_ip)
994+
span.set_tag_str("network.client.ip", client_ip)
995+
except ValueError as e:
996+
log.debug("Could not validate client IP address for websocket send message: %s", str(e))
997+
998+
999+
def _on_asgi_websocket_receive_message(ctx, scope, message):
1000+
"""
1001+
Handle websocket receive message events.
1002+
1003+
This handler is called when a websocket receive message event is dispatched.
1004+
It sets up the span with appropriate tags, metrics, and links.
1005+
"""
1006+
span = ctx.span
1007+
integration_config = ctx.get_item("integration_config")
1008+
1009+
span.set_tag_str(COMPONENT, integration_config.integration_name)
1010+
span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER)
1011+
span.set_tag_str(websocket.RECEIVE_DURATION_TYPE, "blocking")
1012+
1013+
_set_websocket_message_tags_on_span(span, message)
1014+
1015+
span.set_metric(websocket.MESSAGE_FRAMES, 1)
1016+
1017+
if hasattr(ctx, "parent") and ctx.parent.span:
1018+
span.set_link(
1019+
trace_id=ctx.parent.span.trace_id,
1020+
span_id=ctx.parent.span.span_id,
1021+
attributes={SPAN_LINK_KIND: SpanLinkKind.EXECUTED},
1022+
)
1023+
1024+
if getattr(integration_config, "asgi_websocket_messages_inherit_sampling", True):
1025+
_inherit_sampling_tags(span, ctx.parent.span._local_root)
1026+
1027+
_copy_trace_level_tags(span, ctx.parent.span)
1028+
1029+
1030+
def _on_asgi_websocket_send_message(ctx, scope, message):
1031+
"""
1032+
Handle websocket send message events.
1033+
1034+
This handler is called when a websocket send message event is dispatched.
1035+
It sets up the span with appropriate tags, metrics, and links.
1036+
"""
1037+
span = ctx.span
1038+
integration_config = ctx.get_item("integration_config")
1039+
1040+
span.set_tag_str(COMPONENT, integration_config.integration_name)
1041+
span.set_tag_str(SPAN_KIND, SpanKind.PRODUCER)
1042+
_set_client_ip_tags(scope, span)
1043+
_set_websocket_message_tags_on_span(span, message)
1044+
1045+
span.set_metric(websocket.MESSAGE_FRAMES, 1)
1046+
1047+
if hasattr(ctx, "parent") and ctx.parent.span:
1048+
span.set_link(
1049+
trace_id=ctx.parent.span.trace_id,
1050+
span_id=ctx.parent.span.span_id,
1051+
attributes={SPAN_LINK_KIND: SpanLinkKind.RESUMING},
1052+
)
1053+
1054+
1055+
def _on_asgi_websocket_close_message(ctx, scope, message):
1056+
"""
1057+
Handle websocket close message events.
1058+
1059+
This handler is called when a websocket close message event is dispatched.
1060+
It sets up the span with appropriate tags, metrics, and links.
1061+
"""
1062+
span = ctx.span
1063+
integration_config = ctx.get_item("integration_config")
1064+
1065+
span.set_tag_str(COMPONENT, integration_config.integration_name)
1066+
span.set_tag_str(SPAN_KIND, SpanKind.PRODUCER)
1067+
1068+
_set_client_ip_tags(scope, span)
1069+
1070+
_set_websocket_message_tags_on_span(span, message)
1071+
1072+
_set_websocket_close_tags(span, message)
1073+
1074+
if hasattr(ctx, "parent") and ctx.parent.span:
1075+
span.set_link(
1076+
trace_id=ctx.parent.span.trace_id,
1077+
span_id=ctx.parent.span.span_id,
1078+
attributes={SPAN_LINK_KIND: SpanLinkKind.RESUMING},
1079+
)
1080+
1081+
_copy_trace_level_tags(span, ctx.parent.span)
1082+
1083+
1084+
def _on_asgi_websocket_disconnect_message(ctx, scope, message):
1085+
"""
1086+
Handle websocket disconnect message events.
1087+
1088+
This handler is called when a websocket disconnect message event is dispatched.
1089+
It sets up the span with appropriate tags, metrics, and links.
1090+
"""
1091+
span = ctx.span
1092+
integration_config = ctx.get_item("integration_config")
1093+
1094+
span.set_tag_str(COMPONENT, integration_config.integration_name)
1095+
span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER)
1096+
1097+
_set_websocket_close_tags(span, message)
1098+
1099+
if hasattr(ctx, "parent") and ctx.parent.span:
1100+
span.set_link(
1101+
trace_id=ctx.parent_span.trace_id,
1102+
span_id=ctx.parent_span.span_id,
1103+
attributes={SPAN_LINK_KIND: SpanLinkKind.EXECUTED},
1104+
)
1105+
1106+
if getattr(integration_config, "asgi_websocket_messages_inherit_sampling", True):
1107+
_inherit_sampling_tags(span, ctx.parent.span._local_root)
1108+
1109+
_copy_trace_level_tags(span, ctx.parent.span)
1110+
1111+
1112+
def _on_asgi_request(ctx: core.ExecutionContext) -> None:
1113+
"""Handler for ASGI request context started event."""
1114+
scope = ctx.get_item("scope")
1115+
integration_config = ctx.get_item("integration_config")
1116+
1117+
ctx.set_item("tags", {COMPONENT: integration_config.integration_name, SPAN_KIND: SpanKind.SERVER})
1118+
1119+
span = _start_span(ctx)
1120+
ctx.set_item("req_span", span)
1121+
1122+
if scope["type"] == "websocket":
1123+
span.set_tag_str("http.upgraded", "websocket")
1124+
1125+
if "datadog" not in scope:
1126+
scope["datadog"] = {"request_spans": [span]}
1127+
else:
1128+
scope["datadog"]["request_spans"].append(span)
1129+
1130+
9591131
def listen():
9601132
core.on("wsgi.request.prepare", _on_request_prepare)
9611133
core.on("wsgi.request.prepared", _on_request_prepared)
@@ -1009,6 +1181,11 @@ def listen():
10091181
core.on("azure.functions.start_response", _on_azure_functions_start_response)
10101182
core.on("azure.functions.trigger_call_modifier", _on_azure_functions_trigger_span_modifier)
10111183
core.on("azure.functions.service_bus_trigger_modifier", _on_azure_functions_service_bus_trigger_span_modifier)
1184+
core.on("asgi.websocket.receive.message", _on_asgi_websocket_receive_message)
1185+
core.on("asgi.websocket.send.message", _on_asgi_websocket_send_message)
1186+
core.on("asgi.websocket.disconnect.message", _on_asgi_websocket_disconnect_message)
1187+
core.on("asgi.websocket.close.message", _on_asgi_websocket_close_message)
1188+
core.on("context.started.asgi.request", _on_asgi_request)
10121189
core.on("azure.servicebus.message_modifier", _on_azure_servicebus_message_modifier)
10131190

10141191
# web frameworks general handlers
@@ -1042,11 +1219,10 @@ def listen():
10421219
"flask.call",
10431220
"flask.jsonify",
10441221
"flask.render_template",
1045-
"asgi.__call__",
1046-
"asgi.websocket.close_message",
1047-
"asgi.websocket.disconnect_message",
1048-
"asgi.websocket.receive_message",
1049-
"asgi.websocket.send_message",
1222+
"asgi.websocket.close.message",
1223+
"asgi.websocket.disconnect.message",
1224+
"asgi.websocket.receive.message",
1225+
"asgi.websocket.send.message",
10501226
"wsgi.__call__",
10511227
"django.cache",
10521228
"django.middleware.__call__",
@@ -1086,6 +1262,10 @@ def listen():
10861262
core.on(f"context.started.{context_name}", _start_span)
10871263

10881264
for name in (
1265+
"asgi.websocket.close.message",
1266+
"asgi.websocket.disconnect.message",
1267+
"asgi.websocket.receive.message",
1268+
"asgi.websocket.send.message",
10891269
"django.middleware.__call__",
10901270
"django.middleware.func",
10911271
"django.middleware.process_exception",

0 commit comments

Comments
 (0)