Skip to content

Commit

Permalink
Adding retries to new database task sessions (#5448)
Browse files Browse the repository at this point in the history
  • Loading branch information
galvana authored Nov 19, 2024
1 parent 8c59e5d commit c7e645f
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The types of changes are:
### Changed
- Allow hiding systems via a `hidden` parameter and add two flags on the `/system` api endpoint; `show_hidden` and `dnd_relevant`, to display only systems with integrations [#5484](https://github.com/ethyca/fides/pull/5484)
- Updated POST taxonomy endpoints to handle creating resources without specifying fides_key [#5468](https://github.com/ethyca/fides/pull/5468)
- Disabled connection pooling for task workers and added retries and keep-alive configurations for database connections [#5448](https://github.com/ethyca/fides/pull/5448)

### Developer Experience
- Fixing BigQuery integration tests [#5491](https://github.com/ethyca/fides/pull/5491)
Expand Down
3 changes: 3 additions & 0 deletions src/fides/api/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def get_api_session() -> Session:
config=CONFIG,
pool_size=CONFIG.database.api_engine_pool_size,
max_overflow=CONFIG.database.api_engine_max_overflow,
keepalives_idle=CONFIG.database.api_engine_keepalives_idle,
keepalives_interval=CONFIG.database.api_engine_keepalives_interval,
keepalives_count=CONFIG.database.api_engine_keepalives_count,
)
SessionLocal = get_db_session(CONFIG, engine=_engine)
db = SessionLocal()
Expand Down
42 changes: 34 additions & 8 deletions src/fides/api/db/session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import Any, Dict

from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool

from fides.api.common_exceptions import MissingConfig
from fides.api.db.util import custom_json_deserializer, custom_json_serializer
Expand All @@ -17,6 +20,10 @@ def get_db_engine(
database_uri: str | URL | None = None,
pool_size: int = 50,
max_overflow: int = 50,
keepalives_idle: int | None = None,
keepalives_interval: int | None = None,
keepalives_count: int | None = None,
disable_pooling: bool = False,
) -> Engine:
"""Return a database engine.
Expand All @@ -32,14 +39,33 @@ def get_db_engine(
database_uri = config.database.sqlalchemy_test_database_uri
else:
database_uri = config.database.sqlalchemy_database_uri
return create_engine(
database_uri,
pool_pre_ping=True,
pool_size=pool_size,
max_overflow=max_overflow,
json_serializer=custom_json_serializer,
json_deserializer=custom_json_deserializer,
)

engine_args: Dict[str, Any] = {
"json_serializer": custom_json_serializer,
"json_deserializer": custom_json_deserializer,
}

# keepalives settings
connect_args = {}
if keepalives_idle:
connect_args["keepalives_idle"] = keepalives_idle
if keepalives_interval:
connect_args["keepalives_interval"] = keepalives_interval
if keepalives_count:
connect_args["keepalives_count"] = keepalives_count

if connect_args:
connect_args["keepalives"] = 1
engine_args["connect_args"] = connect_args

if disable_pooling:
engine_args["poolclass"] = NullPool
else:
engine_args["pool_pre_ping"] = True
engine_args["pool_size"] = pool_size
engine_args["max_overflow"] = max_overflow

return create_engine(database_uri, **engine_args)


def get_db_session(
Expand Down
34 changes: 32 additions & 2 deletions src/fides/api/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from celery import Celery, Task
from loguru import logger
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from tenacity import (
RetryCallState,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from fides.api.db.session import get_db_engine, get_db_session
from fides.api.util.logger import setup as setup_logging
Expand All @@ -11,6 +19,7 @@
MESSAGING_QUEUE_NAME = "fidesops.messaging"
PRIVACY_PREFERENCES_QUEUE_NAME = "fides.privacy_preferences" # This queue is used in Fidesplus for saving privacy preferences and notices served

NEW_SESSION_RETRIES = 5

autodiscover_task_locations: List[str] = [
"fides.api.tasks",
Expand All @@ -20,10 +29,29 @@
]


def log_retry_attempt(retry_state: RetryCallState) -> None:
"""Log database connection retry attempts."""

logger.warning(
"Database connection attempt {} failed. Retrying in {} seconds...",
retry_state.attempt_number,
retry_state.next_action.sleep, # type: ignore[union-attr]
)


class DatabaseTask(Task): # pylint: disable=W0223
_task_engine = None
_sessionmaker = None

# This retry will attempt to connect 5 times with an exponential backoff (2, 4, 8, 16 seconds between each attempt).
# The original error will be re-raised if the retries are successful. All attempts are shown in the logs.
@retry(
stop=stop_after_attempt(NEW_SESSION_RETRIES),
wait=wait_exponential(multiplier=1, min=1),
retry=retry_if_exception_type(OperationalError),
reraise=True,
before_sleep=log_retry_attempt,
)
def get_new_session(self) -> ContextManager[Session]:
"""
Creates a new Session to be used for each task invocation.
Expand All @@ -36,8 +64,10 @@ def get_new_session(self) -> ContextManager[Session]:
if self._task_engine is None:
self._task_engine = get_db_engine(
config=CONFIG,
pool_size=CONFIG.database.task_engine_pool_size,
max_overflow=CONFIG.database.task_engine_max_overflow,
keepalives_idle=CONFIG.database.task_engine_keepalives_idle,
keepalives_interval=CONFIG.database.task_engine_keepalives_interval,
keepalives_count=CONFIG.database.task_engine_keepalives_count,
disable_pooling=True,
)

# same for the sessionmaker
Expand Down
24 changes: 24 additions & 0 deletions src/fides/config/database_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ class DatabaseSettings(FidesSettings):
default=50,
description="Number of additional 'overflow' concurrent database connections Fides will use for API requests if the pool reaches the limit. These overflow connections are discarded afterwards and not maintained.",
)
api_engine_keepalives_idle: int = Field(
default=30,
description="Number of seconds of inactivity before the client sends a TCP keepalive packet to verify the database connection is still alive.",
)
api_engine_keepalives_interval: int = Field(
default=10,
description="Number of seconds between TCP keepalive retries if the initial keepalive packet receives no response. These are client-side retries.",
)
api_engine_keepalives_count: int = Field(
default=5,
description="Maximum number of TCP keepalive retries before the client considers the connection dead and closes it.",
)
db: str = Field(
default="default_db", description="The name of the application database."
)
Expand Down Expand Up @@ -61,6 +73,18 @@ class DatabaseSettings(FidesSettings):
default=50,
description="Number of additional 'overflow' concurrent database connections Fides will use for executing privacy request tasks, either locally or on each worker, if the pool reaches the limit. These overflow connections are discarded afterwards and not maintained.",
)
task_engine_keepalives_idle: int = Field(
default=30,
description="Number of seconds of inactivity before the client sends a TCP keepalive packet to verify the database connection is still alive.",
)
task_engine_keepalives_interval: int = Field(
default=10,
description="Number of seconds between TCP keepalive retries if the initial keepalive packet receives no response. These are client-side retries.",
)
task_engine_keepalives_count: int = Field(
default=5,
description="Maximum number of TCP keepalive retries before the client considers the connection dead and closes it.",
)
test_db: str = Field(
default="default_test_db",
description="Used instead of the 'db' value when the FIDES_TEST_MODE environment variable is set to True. Avoids overwriting production data.",
Expand Down
37 changes: 1 addition & 36 deletions tests/ops/tasks/test_celery.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,7 @@
# pylint: disable=protected-access
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from sqlalchemy.pool import QueuePool

from fides.api.tasks import DatabaseTask, _create_celery
from fides.api.tasks import _create_celery
from fides.config import CONFIG, CelerySettings, get_config


@pytest.fixture
def mock_config_changed_db_engine_settings():
pool_size = CONFIG.database.task_engine_pool_size
CONFIG.database.task_engine_pool_size = pool_size + 5
max_overflow = CONFIG.database.task_engine_max_overflow
CONFIG.database.task_engine_max_overflow = max_overflow + 5
yield
CONFIG.database.task_engine_pool_size = pool_size
CONFIG.database.task_engine_max_overflow = max_overflow


def test_create_task(celery_session_app, celery_session_worker):
@celery_session_app.task
def multiply(x, y):
Expand Down Expand Up @@ -70,21 +53,3 @@ def test_celery_config_override() -> None:
celery_app = _create_celery(config=config)
assert celery_app.conf["event_queue_prefix"] == "overridden_fides_worker"
assert celery_app.conf["task_default_queue"] == "overridden_fides"


@pytest.mark.parametrize(
"config_fixture", [None, "mock_config_changed_db_engine_settings"]
)
def test_get_task_session(config_fixture, request):
if config_fixture is not None:
request.getfixturevalue(
config_fixture
) # used to invoke config fixture if provided
pool_size = CONFIG.database.task_engine_pool_size
max_overflow = CONFIG.database.task_engine_max_overflow
t = DatabaseTask()
session: Session = t.get_new_session()
engine: Engine = session.get_bind()
pool: QueuePool = engine.pool
assert pool.size() == pool_size
assert pool._max_overflow == max_overflow
78 changes: 78 additions & 0 deletions tests/ops/tasks/test_database_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# pylint: disable=protected-access

from unittest import mock

import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.pool import NullPool

from fides.api.tasks import NEW_SESSION_RETRIES, DatabaseTask
from fides.config import CONFIG


class TestDatabaseTask:
@pytest.fixture
def mock_config_changed_db_engine_settings(self):
pool_size = CONFIG.database.task_engine_pool_size
CONFIG.database.task_engine_pool_size = pool_size + 5
max_overflow = CONFIG.database.task_engine_max_overflow
CONFIG.database.task_engine_max_overflow = max_overflow + 5
yield
CONFIG.database.task_engine_pool_size = pool_size
CONFIG.database.task_engine_max_overflow = max_overflow

@pytest.fixture
def recovering_session_maker(self):
"""Fixture that fails twice then succeeds"""
mock_session = mock.Mock()
mock_maker = mock.Mock()
mock_maker.side_effect = [
OperationalError("connection failed", None, None),
OperationalError("connection failed", None, None),
mock_session,
]
return mock_maker, mock_session

@pytest.fixture
def always_failing_session_maker(self):
"""Fixture that always fails with OperationalError"""
mock_maker = mock.Mock()
mock_maker.side_effect = OperationalError("connection failed", None, None)
return mock_maker

@pytest.mark.parametrize(
"config_fixture", [None, "mock_config_changed_db_engine_settings"]
)
def test_get_task_session(self, config_fixture, request):
if config_fixture is not None:
request.getfixturevalue(
config_fixture
) # used to invoke config fixture if provided
pool_size = CONFIG.database.task_engine_pool_size
max_overflow = CONFIG.database.task_engine_max_overflow
t = DatabaseTask()
session: Session = t.get_new_session()
engine: Engine = session.get_bind()
assert isinstance(engine.pool, NullPool)

def test_retry_on_operational_error(self, recovering_session_maker):
"""Test that session creation retries on OperationalError"""

mock_maker, mock_session = recovering_session_maker

task = DatabaseTask()
with mock.patch.object(task, "_sessionmaker", mock_maker):
session = task.get_new_session()
assert session == mock_session
assert mock_maker.call_count == 3

def test_max_retries_exceeded(mock_db_task, always_failing_session_maker):
"""Test that retries stop after max attempts"""
task = DatabaseTask()
with mock.patch.object(task, "_sessionmaker", always_failing_session_maker):
with pytest.raises(OperationalError):
with task.get_new_session():
pass
assert always_failing_session_maker.call_count == NEW_SESSION_RETRIES

0 comments on commit c7e645f

Please sign in to comment.