diff --git a/ddtrace/contrib/internal/django/patch.py b/ddtrace/contrib/internal/django/patch.py index 087186cb18a..0886c3d9bbf 100644 --- a/ddtrace/contrib/internal/django/patch.py +++ b/ddtrace/contrib/internal/django/patch.py @@ -456,7 +456,7 @@ def _gather_block_metadata(request, request_headers, ctx: core.ExecutionContext) if user_agent: metadata[http.USER_AGENT] = user_agent except Exception as e: - log.warning("Could not gather some metadata on blocked request: %s", str(e)) # noqa: G200 + log.warning("Could not gather some metadata on blocked request: %s", str(e)) core.dispatch("django.block_request_callback", (ctx, metadata, config_django, url, query)) diff --git a/ddtrace/internal/endpoints.py b/ddtrace/internal/endpoints.py index f21236eec5f..90725956988 100644 --- a/ddtrace/internal/endpoints.py +++ b/ddtrace/internal/endpoints.py @@ -1,6 +1,6 @@ import dataclasses from time import monotonic -from typing import List +from typing import Set @dataclasses.dataclass(frozen=True) @@ -9,11 +9,17 @@ class HttpEndPoint: path: str resource_name: str = dataclasses.field(default="") operation_name: str = dataclasses.field(default="http.request") + _hash: int = dataclasses.field(init=False, repr=False) def __post_init__(self) -> None: super().__setattr__("method", self.method.upper()) if not self.resource_name: super().__setattr__("resource_name", f"{self.method} {self.path}") + # cache hash result + super().__setattr__("_hash", hash((self.method, self.path))) + + def __hash__(self) -> int: + return self._hash @dataclasses.dataclass() @@ -24,7 +30,7 @@ class HttpEndPointsCollection: It maintains a maximum size and drops endpoints after a certain time period in case of a hot reload of the server. """ - endpoints: List[HttpEndPoint] = dataclasses.field(default_factory=list, init=False) + endpoints: Set[HttpEndPoint] = dataclasses.field(default_factory=set, init=False) is_first: bool = dataclasses.field(default=True, init=False) drop_time_seconds: float = dataclasses.field(default=90.0, init=False) last_modification_time: float = dataclasses.field(default_factory=monotonic, init=False) @@ -45,12 +51,12 @@ def add_endpoint( current_time = monotonic() if current_time - self.last_modification_time > self.drop_time_seconds: self.reset() - self.endpoints.append( + self.endpoints.add( HttpEndPoint(method=method, path=path, resource_name=resource_name, operation_name=operation_name) ) elif len(self.endpoints) < self.max_size_length: self.last_modification_time = current_time - self.endpoints.append( + self.endpoints.add( HttpEndPoint(method=method, path=path, resource_name=resource_name, operation_name=operation_name) ) @@ -61,16 +67,16 @@ def flush(self, max_length: int) -> dict: if max_length >= len(self.endpoints): res = { "is_first": self.is_first, - "endpoints": [dataclasses.asdict(ep) for ep in self.endpoints], + "endpoints": list(map(dataclasses.asdict, self.endpoints)), } self.reset() return res else: + batch = [self.endpoints.pop() for _ in range(max_length)] res = { "is_first": self.is_first, - "endpoints": [dataclasses.asdict(ep) for ep in self.endpoints[:max_length]], + "endpoints": [dataclasses.asdict(ep) for ep in batch], } - self.endpoints = self.endpoints[max_length:] self.is_first = False self.last_modification_time = monotonic() return res diff --git a/tests/appsec/contrib_appsec/conftest.py b/tests/appsec/contrib_appsec/conftest.py index 2df68072b9d..4fe2c8f5e62 100644 --- a/tests/appsec/contrib_appsec/conftest.py +++ b/tests/appsec/contrib_appsec/conftest.py @@ -64,6 +64,20 @@ def get(name): yield get +@pytest.fixture +def find_resource(test_spans, root_span): + # checking both root spans and web spans for the tag + def find(resource_name): + for span in test_spans.spans: + if span.parent_id is None or span.span_type == "web": + res = span._resource[0] + if res == resource_name: + return True + return False + + yield find + + @pytest.fixture def get_metric(root_span): yield lambda name: root_span().get_metric(name) diff --git a/tests/appsec/contrib_appsec/django_app/urls.py b/tests/appsec/contrib_appsec/django_app/urls.py index 1ac7f0e03fa..2e4de06b7a0 100644 --- a/tests/appsec/contrib_appsec/django_app/urls.py +++ b/tests/appsec/contrib_appsec/django_app/urls.py @@ -154,13 +154,13 @@ def rasp(request, endpoint: str): if param.startswith("cmda"): cmd = query_params[param] try: - res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"])}') + res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"], timeout=0.5)}') except Exception as e: res.append(f"Error: {e}") elif param.startswith("cmds"): cmd = query_params[param] try: - res.append(f"cmd stdout: {subprocess.run(cmd)}") + res.append(f"cmd stdout: {subprocess.run(cmd, timeout=0.5)}") except Exception as e: res.append(f"Error: {e}") tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index a254d34fcea..ff95d4fde73 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -181,6 +181,51 @@ def test_simple_attack_timeout(self, interface: Interface, root_span, get_metric assert len(args_list) == 1 assert ("waf_timeout", "true") in args_list[0][4] + def test_api_endpoint_discovery(self, interface: Interface, find_resource): + """Check that API endpoint discovery works in the framework. + + Also ensure the resource name is set correctly. + """ + if interface.name != "django": + pytest.skip("API endpoint discovery is only supported in Django") + from ddtrace.settings.asm import endpoint_collection + + def parse(path: str) -> str: + import re + + # django substitutions to make a url path from route + if re.match(r"^\^.*\$$", path): + path = path[1:-1] + path = re.sub(r"", "123", path) + path = re.sub(r"", "abc", path) + if path.endswith("/?"): + path = path[:-2] + return "/" + path + + with override_global_config(dict(_asm_enabled=True)): + self.update_tracer(interface) + # required to load the endpoints + interface.client.get("/") + collection = endpoint_collection.endpoints + assert collection + for ep in collection: + assert ep.method + # path could be empty, but must be a string + assert isinstance(ep.path, str) + assert ep.resource_name + assert ep.operation_name + if ep.method not in ("GET", "*", "POST"): + continue + path = parse(ep.path) + response = ( + interface.client.post(path, {"data": "content"}) + if ep.method == "POST" + else interface.client.get(path) + ) + assert self.status(response) in (200, 401), f"ep.path failed: {ep.path} -> {path}" + resource = "GET" + ep.resource_name[1:] if ep.resource_name.startswith("* ") else ep.resource_name + assert find_resource(resource) + @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize( ("user_agent", "priority"),