Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/retry graph commits #216

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mex/backend/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ErrorResponse(BaseModel):

def handle_detailed_error(request: Request, exc: Exception) -> Response:
"""Handle detailed errors and provide debugging info."""
logger.exception("%s %s", type(exc), exc)
logger.exception("%s %s", type(exc).__name__, exc)
return Response(
content=ErrorResponse(
message=str(exc).strip(" "),
Expand Down
61 changes: 43 additions & 18 deletions mex/backend/graph/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
from string import Template
from typing import Annotated, Any, Literal, cast

import backoff
from backoff.types import Details as BackoffDetails
from neo4j import Driver, GraphDatabase, NotificationDisabledCategory
from neo4j.exceptions import DriverError
from pydantic import Field

from mex.backend.fields import SEARCHABLE_CLASSES, SEARCHABLE_FIELDS
from mex.backend.graph.models import Result
from mex.backend.graph.query import QueryBuilder
from mex.backend.graph.query import Query, QueryBuilder
from mex.backend.graph.transform import expand_references_in_search_result
from mex.backend.settings import BackendSettings
from mex.common.connector import BaseConnector
Expand Down Expand Up @@ -141,25 +144,47 @@ def close(self) -> None:
"""Close the connector's underlying requests session."""
self.driver.close()

def commit(self, query: str, **parameters: Any) -> Result:
"""Send and commit a single graph transaction."""
message = Template(query).safe_substitute(
{
k: json.dumps(v, ensure_ascii=False)
for k, v in (parameters or {}).items()
}
)
@staticmethod
def _should_giveup_commit(error: Exception) -> bool:
"""When to give up on committing."""
return not cast(DriverError, error).is_retryable()

@staticmethod
def _on_commit_backoff(event: BackoffDetails) -> None:
"""Re-connect to the graph database."""
self = cast(GraphConnector, event["args"][0])
try:
with self.driver.session() as session:
result = Result(session.run(query, parameters))
except Exception as error:
logger.error("\n%s\n%s", message, error)
raise
if counters := result.get_update_counters():
logger.debug("\n%s\n%s", message, json.dumps(counters, indent=4))
self.close()
except DriverError as error:
logger.error("error closing before reconnect %s", error)
self.driver = self._init_driver()
self._check_connectivity_and_authentication()

@staticmethod
def _on_commit_giveup(event: BackoffDetails) -> None:
"""Log the query when giving up on committing."""
query = cast(Query, event["args"][1])
kwargs = event["kwargs"]
settings = BackendSettings.get()
if settings.debug:
params = {k: json.dumps(v, ensure_ascii=False) for k, v in kwargs.items()}
message = f"\n{Template(str(query)).safe_substitute(params)}"
else:
logger.debug("\n%s", message)
return result
message = f": {query!r}"
logger.error("error committing query: %s%s", message)

@backoff.on_exception(
backoff.fibo,
DriverError,
giveup=_should_giveup_commit,
on_backoff=_on_commit_backoff,
on_giveup=_on_commit_giveup,
max_time=1000,
)
def commit(self, query: Query, **parameters: Any) -> Result:
"""Send and commit a single graph transaction."""
with self.driver.session() as session:
return Result(session.run(str(query), parameters))

def _fetch_extracted_or_rule_items(
self,
Expand Down
38 changes: 34 additions & 4 deletions mex/backend/graph/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from collections.abc import Callable
from typing import Annotated

from jinja2 import Environment, PackageLoader, StrictUndefined, select_autoescape
from black import Mode, format_str
from jinja2 import (
Environment,
PackageLoader,
StrictUndefined,
Template,
select_autoescape,
)
from pydantic import StringConstraints, validate_call

from mex.backend.settings import BackendSettings
Expand Down Expand Up @@ -30,6 +37,29 @@ def render_constraints(
return ", ".join(f"{f}: ${f}" for f in fields)


class Query:
"""Factory for rendering queries."""

REPR_MODE = Mode(line_length=1024)

def __init__(
self, name: str, template: Template, kwargs: dict[str, object]
) -> None:
"""Create a new query instance."""
self.name = name
self.template = template
self.kwargs = kwargs

def __str__(self) -> str:
"""Render the query for database execution."""
return self.template.render(**self.kwargs)

def __repr__(self) -> str:
"""Render the call to the query builder for logging and testing."""
kwargs_repr = ",".join(f"{k}={v!r}" for k, v in self.kwargs.items())
return format_str(f"{self.name}({kwargs_repr})", mode=self.REPR_MODE).strip()


class QueryBuilder(BaseConnector):
"""Wrapper around jinja template loading and rendering."""

Expand Down Expand Up @@ -64,10 +94,10 @@ def __init__(self) -> None:
rule_labels=list(RULE_MODEL_CLASSES_BY_NAME),
)

def __getattr__(self, name: str) -> Callable[..., str]:
"""Load the template with the given `name` and return its `render` method."""
def __getattr__(self, name: str) -> Callable[..., Query]:
"""Load the template with the given `name` and return a query factory."""
template = self._env.get_template(f"{name}.cql")
return template.render
return lambda **kwargs: Query(name, template, kwargs)

def close(self) -> None:
"""Clean up the connector."""
Expand Down
44 changes: 22 additions & 22 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ license = { file = "LICENSE" }
urls = { Repository = "https://github.com/robert-koch-institut/mex-backend" }
requires-python = ">=3.11,<3.13"
dependencies = [
"backoff>=2,<3",
"black>=24,<25",
"fastapi>=0.115,<1",
"httpx>=0.27,<1",
"jinja2>=3,<4",
Expand All @@ -18,7 +20,6 @@ dependencies = [
"uvicorn[standard]>=0.30,<1",
]
optional-dependencies.dev = [
"black>=24,<25",
"ipdb>=0.13,<1",
"mypy>=1,<2",
"pytest-cov>=6,<7",
Expand Down
18 changes: 6 additions & 12 deletions tests/graph/test_connector.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from collections.abc import Callable
from unittest.mock import Mock

import pytest
from black import Mode, format_str
from black import DEFAULT_LINE_LENGTH
from pytest import MonkeyPatch

from mex.backend.graph import connector as connector_module
from mex.backend.graph.connector import MEX_EXTRACTED_PRIMARY_SOURCE, GraphConnector
from mex.backend.graph.query import QueryBuilder
from mex.backend.graph.query import Query
from mex.common.exceptions import MExError
from mex.common.models import (
MEX_PRIMARY_SOURCE_IDENTIFIER,
Expand All @@ -22,17 +21,12 @@


@pytest.fixture
def mocked_query_builder(monkeypatch: MonkeyPatch) -> None:
def __getattr__(_: QueryBuilder, query: str) -> Callable[..., str]:
return lambda **parameters: format_str(
f"{query}({','.join(f'{k}={v!r}' for k, v in parameters.items())})",
mode=Mode(line_length=78),
).strip()
def mocked_query_class(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(Query, "__str__", Query.__repr__)
monkeypatch.setattr(Query.REPR_MODE, "line_length", DEFAULT_LINE_LENGTH)

monkeypatch.setattr(QueryBuilder, "__getattr__", __getattr__)


@pytest.mark.usefixtures("mocked_query_builder")
@pytest.mark.usefixtures("mocked_query_class")
def test_check_connectivity_and_authentication(mocked_graph: MockedGraph) -> None:
mocked_graph.return_value = [{"currentStatus": "online"}]
graph = GraphConnector.get()
Expand Down
Loading