Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 187 additions & 7 deletions ddtrace/_trace/trace_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from urllib import parse
Expand All @@ -24,14 +25,20 @@
from ddtrace.constants import SPAN_KIND
from ddtrace.contrib import trace_utils
from ddtrace.contrib.internal.botocore.constants import BOTOCORE_STEPFUNCTIONS_INPUT_KEY

# from ddtrace.internal.utils import _copy_trace_level_tags
from ddtrace.contrib.internal.trace_utils import _copy_trace_level_tags
from ddtrace.contrib.internal.trace_utils import _set_url_tag
from ddtrace.ext import SpanKind
from ddtrace.ext import SpanLinkKind
from ddtrace.ext import azure_servicebus as azure_servicebusx
from ddtrace.ext import db
from ddtrace.ext import http
from ddtrace.ext import net
from ddtrace.ext import redis as redisx
from ddtrace.ext import websocket
from ddtrace.internal import core
from ddtrace.internal.compat import is_valid_ip
from ddtrace.internal.compat import maybe_stringify
from ddtrace.internal.constants import COMPONENT
from ddtrace.internal.constants import FLASK_ENDPOINT
Expand All @@ -42,7 +49,9 @@
from ddtrace.internal.constants import MESSAGING_MESSAGE_ID
from ddtrace.internal.constants import MESSAGING_OPERATION
from ddtrace.internal.constants import MESSAGING_SYSTEM
from ddtrace.internal.constants import SPAN_LINK_KIND
from ddtrace.internal.logger import get_logger
from ddtrace.internal.sampling import _inherit_sampling_tags
from ddtrace.internal.schema.span_attribute_schema import SpanDirection
from ddtrace.propagation.http import HTTPPropagator

Expand Down Expand Up @@ -131,7 +140,7 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -

if config._inferred_proxy_services_enabled:
# dispatch event for checking headers and possibly making an inferred proxy span
core.dispatch("inferred_proxy.start", (ctx, tracer, span_kwargs, call_trace, integration_config))
core.dispatch("inferred_proxy.start", (ctx, tracer, span_kwargs, call_trace))
# re-get span_kwargs in case an inferred span was created and we have a new span_kwargs.child_of field
span_kwargs = ctx.get_item("span_kwargs", span_kwargs)

Expand Down Expand Up @@ -237,14 +246,15 @@ def _set_inferred_proxy_tags(span, status_code):
inferred_span.set_tag(ERROR_STACK, span.get_tag(ERROR_STACK))


def _on_inferred_proxy_start(ctx, tracer, span_kwargs, call_trace, integration_config):
def _on_inferred_proxy_start(ctx, tracer, span_kwargs, call_trace):
# Skip creating another inferred span if one has already been created for this request
if ctx.get_item("inferred_proxy_span"):
return

# some integrations like Flask / WSGI store headers from environ in 'distributed_headers'
# and normalized headers in 'headers'
headers = ctx.get_item("headers", ctx.get_item("distributed_headers", None))
integration_config = ctx.get_item("integration_config")

# Inferred Proxy Spans
if integration_config and headers is not None:
Expand Down Expand Up @@ -956,6 +966,168 @@ def _on_router_match(route):
req_span.set_tag_str(http.ROUTE, route.template)


def _set_websocket_message_tags_on_span(websocket_span: Span, message: Mapping[str, Any]):
if "text" in message:
websocket_span.set_tag_str(websocket.MESSAGE_TYPE, "text")
websocket_span.set_metric(websocket.MESSAGE_LENGTH, len(message["text"].encode("utf-8")))
elif "binary" in message:
websocket_span.set_tag_str(websocket.MESSAGE_TYPE, "binary")
websocket_span.set_metric(websocket.MESSAGE_LENGTH, len(message["bytes"]))


def _set_websocket_close_tags(span: Span, message: Mapping[str, Any]):
code = message.get("code")
reason = message.get("reason")
if code is not None:
span.set_metric(websocket.CLOSE_CODE, code)
if reason:
span.set_tag(websocket.CLOSE_REASON, reason)


def _set_client_ip_tags(scope: Mapping[str, Any], span: Span):
client = scope.get("client")
if len(client) >= 1: # type: ignore[arg-type]
client_ip = client[0] # type: ignore[index]
span.set_tag_str(net.TARGET_HOST, client_ip)
try:
is_valid_ip(client_ip)
span.set_tag_str("network.client.ip", client_ip)
except ValueError as e:
log.debug("Could not validate client IP address for websocket send message: %s", str(e))


def _on_asgi_websocket_receive_message(ctx, scope, message):
"""
Handle websocket receive message events.

This handler is called when a websocket receive message event is dispatched.
It sets up the span with appropriate tags, metrics, and links.
"""
span = ctx.span
integration_config = ctx.get_item("integration_config")

span.set_tag_str(COMPONENT, integration_config.integration_name)
span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER)
span.set_tag_str(websocket.RECEIVE_DURATION_TYPE, "blocking")

_set_websocket_message_tags_on_span(span, message)

span.set_metric(websocket.MESSAGE_FRAMES, 1)

if hasattr(ctx, "parent") and ctx.parent.span:
span.set_link(
trace_id=ctx.parent.span.trace_id,
span_id=ctx.parent.span.span_id,
attributes={SPAN_LINK_KIND: SpanLinkKind.EXECUTED},
)

if getattr(integration_config, "asgi_websocket_messages_inherit_sampling", True):
_inherit_sampling_tags(span, ctx.parent.span._local_root)

_copy_trace_level_tags(span, ctx.parent.span)


def _on_asgi_websocket_send_message(ctx, scope, message):
"""
Handle websocket send message events.

This handler is called when a websocket send message event is dispatched.
It sets up the span with appropriate tags, metrics, and links.
"""
span = ctx.span
integration_config = ctx.get_item("integration_config")

span.set_tag_str(COMPONENT, integration_config.integration_name)
span.set_tag_str(SPAN_KIND, SpanKind.PRODUCER)
_set_client_ip_tags(scope, span)
_set_websocket_message_tags_on_span(span, message)

span.set_metric(websocket.MESSAGE_FRAMES, 1)

if hasattr(ctx, "parent") and ctx.parent.span:
span.set_link(
trace_id=ctx.parent.span.trace_id,
span_id=ctx.parent.span.span_id,
attributes={SPAN_LINK_KIND: SpanLinkKind.RESUMING},
)


def _on_asgi_websocket_close_message(ctx, scope, message):
"""
Handle websocket close message events.

This handler is called when a websocket close message event is dispatched.
It sets up the span with appropriate tags, metrics, and links.
"""
span = ctx.span
integration_config = ctx.get_item("integration_config")

span.set_tag_str(COMPONENT, integration_config.integration_name)
span.set_tag_str(SPAN_KIND, SpanKind.PRODUCER)

_set_client_ip_tags(scope, span)

_set_websocket_message_tags_on_span(span, message)

_set_websocket_close_tags(span, message)

if hasattr(ctx, "parent") and ctx.parent.span:
span.set_link(
trace_id=ctx.parent.span.trace_id,
span_id=ctx.parent.span.span_id,
attributes={SPAN_LINK_KIND: SpanLinkKind.RESUMING},
)

_copy_trace_level_tags(span, ctx.parent.span)


def _on_asgi_websocket_disconnect_message(ctx, scope, message):
"""
Handle websocket disconnect message events.

This handler is called when a websocket disconnect message event is dispatched.
It sets up the span with appropriate tags, metrics, and links.
"""
span = ctx.span
integration_config = ctx.get_item("integration_config")

span.set_tag_str(COMPONENT, integration_config.integration_name)
span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER)

_set_websocket_close_tags(span, message)

if hasattr(ctx, "parent") and ctx.parent.span:
span.set_link(
trace_id=ctx.parent_span.trace_id,
span_id=ctx.parent_span.span_id,
attributes={SPAN_LINK_KIND: SpanLinkKind.EXECUTED},
)

if getattr(integration_config, "asgi_websocket_messages_inherit_sampling", True):
_inherit_sampling_tags(span, ctx.parent.span._local_root)

_copy_trace_level_tags(span, ctx.parent.span)


def _on_asgi_request(ctx: core.ExecutionContext) -> None:
"""Handler for ASGI request context started event."""
scope = ctx.get_item("scope")
integration_config = ctx.get_item("integration_config")

ctx.set_item("tags", {COMPONENT: integration_config.integration_name, SPAN_KIND: SpanKind.SERVER})

span = _start_span(ctx)
ctx.set_item("req_span", span)

if scope["type"] == "websocket":
span.set_tag_str("http.upgraded", "websocket")

if "datadog" not in scope:
scope["datadog"] = {"request_spans": [span]}
else:
scope["datadog"]["request_spans"].append(span)


def listen():
core.on("wsgi.request.prepare", _on_request_prepare)
core.on("wsgi.request.prepared", _on_request_prepared)
Expand Down Expand Up @@ -1009,6 +1181,11 @@ def listen():
core.on("azure.functions.start_response", _on_azure_functions_start_response)
core.on("azure.functions.trigger_call_modifier", _on_azure_functions_trigger_span_modifier)
core.on("azure.functions.service_bus_trigger_modifier", _on_azure_functions_service_bus_trigger_span_modifier)
core.on("asgi.websocket.receive.message", _on_asgi_websocket_receive_message)
core.on("asgi.websocket.send.message", _on_asgi_websocket_send_message)
core.on("asgi.websocket.disconnect.message", _on_asgi_websocket_disconnect_message)
core.on("asgi.websocket.close.message", _on_asgi_websocket_close_message)
core.on("context.started.asgi.request", _on_asgi_request)
core.on("azure.servicebus.message_modifier", _on_azure_servicebus_message_modifier)

# web frameworks general handlers
Expand Down Expand Up @@ -1042,11 +1219,10 @@ def listen():
"flask.call",
"flask.jsonify",
"flask.render_template",
"asgi.__call__",
"asgi.websocket.close_message",
"asgi.websocket.disconnect_message",
"asgi.websocket.receive_message",
"asgi.websocket.send_message",
"asgi.websocket.close.message",
"asgi.websocket.disconnect.message",
"asgi.websocket.receive.message",
"asgi.websocket.send.message",
"wsgi.__call__",
"django.cache",
"django.middleware.__call__",
Expand Down Expand Up @@ -1086,6 +1262,10 @@ def listen():
core.on(f"context.started.{context_name}", _start_span)

for name in (
"asgi.websocket.close.message",
"asgi.websocket.disconnect.message",
"asgi.websocket.receive.message",
"asgi.websocket.send.message",
"django.middleware.__call__",
"django.middleware.func",
"django.middleware.process_exception",
Expand Down
Loading
Loading