Skip to content

Commit

Permalink
feat: add traceID and pass it to ES
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonThordal committed Jul 31, 2024
1 parent 144f640 commit f272834
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 23 deletions.
2 changes: 0 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ def test_healthz():
res = client.get("/healthz")
assert res.status_code == 200, res
assert res.json().get("status") == "ok", res
assert "x-trace-id" in res.headers


def test_readyz():
res = client.get("/readyz")
assert res.status_code == 200, res
assert res.json().get("status") == "ok", res
assert "x-trace-id" in res.headers


def test_manifest():
Expand Down
37 changes: 37 additions & 0 deletions tests/test_trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from .conftest import client
import re


def test_trace_context() -> None:
    """Check that search responses always carry W3C trace context headers.

    Three cases are covered: a valid incoming context is continued, a missing
    context results in a freshly generated one, and an invalid context is
    replaced by a freshly generated one.
    """
    url = "/search/default?q=vladimir putin"
    fresh_parent_pat = re.compile(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}")
    fresh_state_pat = re.compile(r"yente=[0-9a-f]{16}")

    # Works when receiving a valid trace context
    headers = {
        "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
        "tracestate": "rojo=00f067aa0ba902b7",
    }
    parent_pat = re.compile(r"00-0af7651916cd43dd8448eb211c80319c-[0-9a-f]{16}-01")
    state_pat = re.compile(r"yente=[0-9a-f]{16},\s?rojo=00f067aa0ba902b7")
    res = client.get(url, headers=headers)
    assert "traceparent" in res.headers
    assert "tracestate" in res.headers
    assert parent_pat.match(res.headers["traceparent"])
    assert state_pat.match(res.headers["tracestate"])

    # Works when not receiving a trace context
    res = client.get(url)
    assert "traceparent" in res.headers
    assert "tracestate" in res.headers
    assert fresh_parent_pat.match(res.headers["traceparent"])
    assert fresh_state_pat.match(res.headers["tracestate"])

    # Works with a broken trace context (unsupported version, all-zero parent)
    headers = {
        "traceparent": "ff-0af7651916cd43dd8448eb211c80319c-0000000000000000-01",
        "tracestate": "rojo=00f067aa0ba902b7",
    }
    # The broken headers must actually be sent, otherwise this section only
    # repeats the "no trace context" case above.
    res = client.get(url, headers=headers)
    assert "traceparent" in res.headers
    assert "tracestate" in res.headers
    assert fresh_parent_pat.match(res.headers["traceparent"])
    # A rejected traceparent starts a new trace instead of continuing the old one.
    assert "0af7651916cd43dd8448eb211c80319c" not in res.headers["traceparent"]
    assert fresh_state_pat.match(res.headers["tracestate"])
7 changes: 2 additions & 5 deletions yente/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from yente.data import refresh_catalog
from yente.search.indexer import update_index_threaded
from yente.provider import close_provider
from yente.middleware import TraceContextMiddleware

log = get_logger("yente")
ExceptionHandler = Callable[[Request, Any], Coroutine[Any, Any, Response]]
Expand Down Expand Up @@ -48,12 +49,8 @@ async def request_middleware(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
start_time = time.time()
trace_id = request.headers.get("x-trace-id")
if trace_id is None:
trace_id = uuid4().hex
client_ip = request.client.host if request.client else "127.0.0.1"
bind_contextvars(
trace_id=trace_id,
client_ip=client_ip,
)
try:
Expand All @@ -62,7 +59,6 @@ async def request_middleware(
log.exception("Exception during request: %s" % type(exc))
response = JSONResponse(status_code=500, content={"status": "error"})
time_delta = time.time() - start_time
response.headers["x-trace-id"] = trace_id
log.info(
str(request.url.path),
action="request",
Expand Down Expand Up @@ -108,6 +104,7 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
app.middleware("http")(request_middleware)
app.add_middleware(TraceContextMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
3 changes: 3 additions & 0 deletions yente/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public API of the middleware package. Every name listed in ``__all__`` has
# to be imported here, otherwise ``from yente.middleware import *`` fails.
from .trace_context import TraceContextMiddleware, get_trace_context

__all__ = ["TraceContextMiddleware", "get_trace_context"]
123 changes: 123 additions & 0 deletions yente/middleware/trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from typing import Any, Tuple, List
import secrets
from structlog.contextvars import get_contextvars, bind_contextvars

VENDOR_CODE = (
"yente" # It's available! https://w3c.github.io/tracestate-ids-registry/#registry
)


class TraceParent:
__slots__ = ["version", "trace_id", "parent_id", "trace_flags"]

def __init__(self, version: str, trace_id: str, parent_id: str, trace_flags: str):
self.version = version
self.trace_id = trace_id
self.parent_id = parent_id
self.trace_flags = trace_flags

def __str__(self) -> str:
return f"{self.version}-{self.trace_id}-{self.parent_id}-{self.trace_flags}"

@classmethod
def create(cls) -> "TraceParent":
return cls("00", secrets.token_hex(16), secrets.token_hex(8), "00")

@classmethod
def from_str(cls, traceparent: str | None) -> "TraceParent":
"""
Parse a traceparent header string into a TraceParent object created with a new parent_id.
"""
if traceparent is None:
return cls.create()
parts = traceparent.split("-")
try:
version, trace_id, parent_id, trace_flags = parts[:4]
except Exception:
raise ValueError(f"Invalid traceparent: {traceparent}")
if int(version, 16) == 255:
raise ValueError(f"Unsupported version: {version}")
for i in trace_id:
if i != "0":
break
else:
raise ValueError(f"Invalid trace_id: {trace_id}")
for i in parent_id:
if i != "0":
break
else:
raise ValueError(f"Invalid parent_id: {parent_id}")

return cls(version, trace_id, secrets.token_hex(8), trace_flags)


class TraceState:
__slots__ = ["tracestate"]

def __init__(self, tracestate: List[Tuple[str, str]] = []):
self.tracestate = tracestate

@classmethod
def create(cls, parent: TraceParent, prev_state: str = "") -> "TraceState":
spans_out: List[Tuple[str, str]] = []
for span in prev_state.split(","):
parts = span.split("=")
if len(parts) != 2:
# We are allowed to discard invalid states
continue
vendor, value = parts
if vendor == VENDOR_CODE:
continue
spans_out.append((vendor.lower().strip(), value.lower().strip()))
spans_out.insert(0, (VENDOR_CODE, f"{parent.parent_id}"))
return cls(spans_out)

def __str__(self) -> str:
return ",".join([f"{k}={v}" for k, v in self.tracestate])


class TraceContext:
    """Bundle of the traceparent and tracestate belonging to one request."""

    __slots__ = ["traceparent", "tracestate"]

    def __init__(self, traceparent: TraceParent, tracestate: TraceState):
        self.traceparent, self.tracestate = traceparent, tracestate

    def __repr__(self) -> str:
        # Rendered as the string form of a dict holding both header values.
        header_values = {
            "traceparent": str(self.traceparent),
            "tracestate": str(self.tracestate),
        }
        return str(header_values)


def get_trace_context() -> TraceContext | None:
    """Return the TraceContext bound in the structlog context variables.

    Returns None when nothing is bound under ``trace_context`` or when the
    bound value is not a TraceContext instance.
    """
    candidate = get_contextvars().get("trace_context")
    return candidate if isinstance(candidate, TraceContext) else None


class TraceContextMiddleware(BaseHTTPMiddleware):
    """Attach a trace context to each request and echo it on the response."""

    async def dispatch(self, request: Request, call_next: Any) -> Any:
        """Derive the trace context, bind it, and set the response headers.

        The incoming ``traceparent`` and ``tracestate`` headers are parsed;
        if parsing fails, a new trace parent or an empty previous state is
        used instead. The resulting context is bound as the ``trace_context``
        context variable before the request is passed on.
        """
        incoming = request.headers
        try:
            parent = TraceParent.from_str(incoming.get("traceparent"))
        except Exception:
            # Unparseable header: start a new trace.
            parent = TraceParent.create()
        try:
            state = TraceState.create(parent, incoming.get("tracestate", ""))
        except Exception:
            # Unparseable state: keep only our own entry.
            state = TraceState.create(parent, "")
        bind_contextvars(trace_context=TraceContext(parent, state))
        response = await call_next(request)
        response.headers["traceparent"] = str(parent)
        response.headers["tracestate"] = str(state)
        return response
47 changes: 31 additions & 16 deletions yente/provider/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from yente.logs import get_logger
from yente.search.mapping import make_entity_mapping, INDEX_SETTINGS
from yente.provider.base import SearchProvider, query_semaphore
from yente.middleware.trace_context import get_trace_context

log = get_logger(__name__)
warnings.filterwarnings("ignore", category=ElasticsearchWarning)
Expand Down Expand Up @@ -54,27 +55,40 @@ async def create(cls) -> "ElasticSearchProvider":
raise RuntimeError("Could not connect to ElasticSearch.")

def __init__(self, client: AsyncElasticsearch) -> None:
self.client = client
self._client = client

def client(self, **kwargs: Any) -> AsyncElasticsearch:
    """Return the wrapped client configured with the given options.

    If a trace context is currently bound, its ``traceparent`` and
    ``tracestate`` values are merged into the ``headers`` option, overriding
    any caller-supplied headers with the same names. All keyword arguments
    are passed on to ``options()`` of the wrapped client.
    """
    trace_context = get_trace_context()
    if trace_context:
        trace_headers = {
            "traceparent": str(trace_context.traceparent),
            "tracestate": str(trace_context.tracestate),
        }
        kwargs["headers"] = kwargs.get("headers", {}) | trace_headers
    return self._client.options(**kwargs)

async def close(self) -> None:
await self.client.close()
await self._client.close()

async def refresh(self, index: str) -> None:
"""Refresh the index to make changes visible."""
try:
await self.client.indices.refresh(index=index)
await self.client().indices.refresh(index=index)
except NotFoundError as nfe:
raise YenteNotFoundError(f"Index {index} does not exist.") from nfe

async def get_all_indices(self) -> List[str]:
"""Get a list of all indices in the ElasticSearch cluster."""
indices: Any = await self.client.cat.indices(format="json")
indices: Any = await self.client().cat.indices(format="json")
return [index.get("index") for index in indices]

async def get_alias_indices(self, alias: str) -> List[str]:
"""Get a list of indices that are aliased to the entity query alias."""
try:
resp = await self.client.indices.get_alias(name=alias)
resp = await self.client().indices.get_alias(name=alias)
return list(resp.keys())
except NotFoundError:
return []
Expand All @@ -88,7 +102,7 @@ async def rollover_index(self, alias: str, next_index: str, prefix: str) -> None
actions = []
actions.append({"remove": {"index": f"{prefix}*", "alias": alias}})
actions.append({"add": {"index": next_index, "alias": alias}})
await self.client.indices.update_aliases(actions=actions)
await self.client().indices.update_aliases(actions=actions)
except (ApiError, TransportError) as te:
raise YenteIndexError(f"Could not rollover index: {te}") from te

Expand All @@ -97,19 +111,19 @@ async def clone_index(self, base_version: str, target_version: str) -> None:
if base_version == target_version:
raise ValueError("Cannot clone an index to itself.")
try:
await self.client.indices.put_settings(
await self.client().indices.put_settings(
index=base_version,
settings={"index.blocks.read_only": True},
)
await self.delete_index(target_version)
await self.client.indices.clone(
await self.client().indices.clone(
index=base_version,
target=target_version,
body={
"settings": {"index": {"blocks": {"read_only": False}}},
},
)
await self.client.indices.put_settings(
await self.client().indices.put_settings(
index=base_version,
settings={"index.blocks.read_only": False},
)
Expand All @@ -122,7 +136,7 @@ async def create_index(self, index: str) -> None:
"""Create a new index with the given name."""
log.info("Create index", index=index)
try:
await self.client.indices.create(
await self.client().indices.create(
index=index,
mappings=make_entity_mapping(),
settings=INDEX_SETTINGS,
Expand All @@ -135,7 +149,7 @@ async def create_index(self, index: str) -> None:
async def delete_index(self, index: str) -> None:
"""Delete a given index if it exists."""
try:
await self.client.indices.delete(index=index)
await self.client().indices.delete(index=index)
except NotFoundError:
pass
except (ApiError, TransportError) as te:
Expand All @@ -144,7 +158,7 @@ async def delete_index(self, index: str) -> None:
async def exists_index_alias(self, alias: str, index: str) -> bool:
"""Check if an index exists and is linked into the given alias."""
try:
exists = await self.client.indices.exists_alias(name=alias, index=index)
exists = await self.client().indices.exists_alias(name=alias, index=index)
return True if exists.body else False
except NotFoundError:
return False
Expand All @@ -153,8 +167,9 @@ async def exists_index_alias(self, alias: str, index: str) -> bool:

async def check_health(self, index: str) -> bool:
try:
client = self.client.options(request_timeout=5)
health = await client.cluster.health(index=index, timeout=0)
health = await self.client(request_timeout=5).cluster.health(
index=index, timeout=0
)
return health.get("status") in ("yellow", "green")
except NotFoundError as nfe:
raise YenteNotFoundError(f"Index {index} does not exist.") from nfe
Expand Down Expand Up @@ -182,7 +197,7 @@ async def search(

try:
async with query_semaphore:
response = await self.client.search(
response = await self.client().search(
index=index,
query=query,
size=size,
Expand Down Expand Up @@ -218,7 +233,7 @@ async def bulk_index(self, entities: AsyncIterator[Dict[str, Any]]) -> None:
"""Index a list of entities into the search index."""
try:
await async_bulk(
self.client,
self.client(),
entities,
chunk_size=1000,
yield_ok=False,
Expand Down

0 comments on commit f272834

Please sign in to comment.