Skip to content
2 changes: 1 addition & 1 deletion ddtrace/contrib/internal/django/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def _gather_block_metadata(request, request_headers, ctx: core.ExecutionContext)
if user_agent:
metadata[http.USER_AGENT] = user_agent
except Exception as e:
log.warning("Could not gather some metadata on blocked request: %s", str(e)) # noqa: G200
log.warning("Could not gather some metadata on blocked request: %s", str(e))
core.dispatch("django.block_request_callback", (ctx, metadata, config_django, url, query))


Expand Down
20 changes: 13 additions & 7 deletions ddtrace/internal/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from time import monotonic
from typing import List
from typing import Set


@dataclasses.dataclass(frozen=True)
Expand All @@ -9,11 +9,17 @@ class HttpEndPoint:
path: str
resource_name: str = dataclasses.field(default="")
operation_name: str = dataclasses.field(default="http.request")
_hash: int = dataclasses.field(init=False, repr=False)

def __post_init__(self) -> None:
    """Normalise the endpoint once the dataclass fields have been assigned.

    The method is upper-cased, a resource name of the form
    ``"<METHOD> <path>"`` is derived when none was supplied, and the hash of
    ``(method, path)`` is computed once and stored in ``_hash``.
    """
    # The dataclass is declared frozen, so every write goes through the
    # parent class' __setattr__ rather than plain attribute assignment.
    assign = super().__setattr__
    assign("method", self.method.upper())
    if not self.resource_name:
        assign("resource_name", f"{self.method} {self.path}")
    # cache hash result
    assign("_hash", hash((self.method, self.path)))

def __hash__(self) -> int:
    """Return the hash value stored in ``_hash`` instead of recomputing it."""
    cached = self._hash
    return cached


@dataclasses.dataclass()
Expand All @@ -24,7 +30,7 @@ class HttpEndPointsCollection:
It maintains a maximum size and drops endpoints after a certain time period in case of a hot reload of the server.
"""

endpoints: List[HttpEndPoint] = dataclasses.field(default_factory=list, init=False)
endpoints: Set[HttpEndPoint] = dataclasses.field(default_factory=set, init=False)
is_first: bool = dataclasses.field(default=True, init=False)
drop_time_seconds: float = dataclasses.field(default=90.0, init=False)
last_modification_time: float = dataclasses.field(default_factory=monotonic, init=False)
Expand All @@ -45,12 +51,12 @@ def add_endpoint(
current_time = monotonic()
if current_time - self.last_modification_time > self.drop_time_seconds:
self.reset()
self.endpoints.append(
self.endpoints.add(
HttpEndPoint(method=method, path=path, resource_name=resource_name, operation_name=operation_name)
)
elif len(self.endpoints) < self.max_size_length:
self.last_modification_time = current_time
self.endpoints.append(
self.endpoints.add(
HttpEndPoint(method=method, path=path, resource_name=resource_name, operation_name=operation_name)
)

Expand All @@ -61,16 +67,16 @@ def flush(self, max_length: int) -> dict:
if max_length >= len(self.endpoints):
res = {
"is_first": self.is_first,
"endpoints": [dataclasses.asdict(ep) for ep in self.endpoints],
"endpoints": list(map(dataclasses.asdict, self.endpoints)),
}
self.reset()
return res
else:
batch = [self.endpoints.pop() for _ in range(max_length)]
res = {
"is_first": self.is_first,
"endpoints": [dataclasses.asdict(ep) for ep in self.endpoints[:max_length]],
"endpoints": [dataclasses.asdict(ep) for ep in batch],
}
self.endpoints = self.endpoints[max_length:]
self.is_first = False
self.last_modification_time = monotonic()
return res
14 changes: 14 additions & 0 deletions tests/appsec/contrib_appsec/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ def get(name):
yield get


@pytest.fixture
def find_resource(test_spans, root_span):
    """Yield a predicate reporting whether a root or web span has a given resource.

    The yielded callable takes a resource name and returns ``True`` when a span
    in ``test_spans.spans`` that has no parent, or whose ``span_type`` is
    ``"web"``, carries that name as the first element of ``_resource``.
    """

    def find(resource_name):
        # checking both root spans and web spans for the resource name
        return any(
            span._resource[0] == resource_name
            for span in test_spans.spans
            if span.parent_id is None or span.span_type == "web"
        )

    yield find


@pytest.fixture
def get_metric(root_span):
    """Yield a helper that reads a metric by name from the span returned by ``root_span()``."""

    def metric(name):
        # root_span is called on every lookup, exactly as the one-line form does.
        return root_span().get_metric(name)

    yield metric
Expand Down
4 changes: 2 additions & 2 deletions tests/appsec/contrib_appsec/django_app/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ def rasp(request, endpoint: str):
if param.startswith("cmda"):
cmd = query_params[param]
try:
res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"])}')
res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"], timeout=0.5)}')
except Exception as e:
res.append(f"Error: {e}")
elif param.startswith("cmds"):
cmd = query_params[param]
try:
res.append(f"cmd stdout: {subprocess.run(cmd)}")
res.append(f"cmd stdout: {subprocess.run(cmd, timeout=0.5)}")
except Exception as e:
res.append(f"Error: {e}")
tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint)
Expand Down
45 changes: 45 additions & 0 deletions tests/appsec/contrib_appsec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,51 @@ def test_simple_attack_timeout(self, interface: Interface, root_span, get_metric
assert len(args_list) == 1
assert ("waf_timeout", "true") in args_list[0][4]

def test_api_endpoint_discovery(self, interface: Interface, find_resource):
    """Check that API endpoint discovery works in the framework.

    Also ensure the resource name is set correctly.

    The test is skipped unless the interface is Django. Every collected
    endpoint must expose a method, a string path, a resource name and an
    operation name; endpoints whose method is GET, POST or ``*`` are also
    requested, and the matching resource name is looked up via
    ``find_resource``.
    """
    if interface.name != "django":
        pytest.skip("API endpoint discovery is only supported in Django")
    from ddtrace.settings.asm import endpoint_collection

    import re

    def parse(path: str) -> str:
        # django substitutions to make a url path from route
        anchored = re.match(r"^\^.*\$$", path)
        if anchored:
            # drop the leading "^" and trailing "$" of a regex-style route
            path = path[1:-1]
        for placeholder, value in ((r"<int:param_int>", "123"), (r"<str:param_str>", "abc")):
            path = re.sub(placeholder, value, path)
        if path.endswith("/?"):
            path = path[:-2]
        return "/" + path

    with override_global_config(dict(_asm_enabled=True)):
        self.update_tracer(interface)
        # required to load the endpoints
        interface.client.get("/")
        collection = endpoint_collection.endpoints
        assert collection
        for ep in collection:
            assert ep.method
            # path could be empty, but must be a string
            assert isinstance(ep.path, str)
            assert ep.resource_name
            assert ep.operation_name
            # only GET, POST and wildcard endpoints are actually requested
            if ep.method in ("GET", "*", "POST"):
                path = parse(ep.path)
                if ep.method == "POST":
                    response = interface.client.post(path, {"data": "content"})
                else:
                    response = interface.client.get(path)
                assert self.status(response) in (200, 401), f"ep.path failed: {ep.path} -> {path}"
                # wildcard endpoints are requested with GET, so look for the GET resource
                if ep.resource_name.startswith("* "):
                    expected_resource = "GET" + ep.resource_name[1:]
                else:
                    expected_resource = ep.resource_name
                assert find_resource(expected_resource)

@pytest.mark.parametrize("asm_enabled", [True, False])
@pytest.mark.parametrize(
("user_agent", "priority"),
Expand Down
Loading