Skip to content

Commit

Permalink
Merge pull request #298 from duckdb/jwills_add_retries_for_some_errors
Browse files Browse the repository at this point in the history
Add support for retrying certain types of exceptions we see when running models with DuckDB
  • Loading branch information
jwills authored Dec 23, 2023
2 parents 36b4ec1 + 6c9dffe commit c951148
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 1 deletion.
20 changes: 20 additions & 0 deletions dbt/adapters/duckdb/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import time
from dataclasses import dataclass
from dataclasses import field
from functools import lru_cache
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -61,6 +62,20 @@ class Remote(dbtClassMixin):
password: Optional[str] = None


@dataclass
class Retries(dbtClassMixin):
    """Retry settings for the DuckDB connect call and for query execution."""

    # Total number of tries for the initial duckdb.connect call; retrying
    # gives another process time to release its lock on the DB file.
    connect_attempts: int = 1

    # Total number of tries for a DuckDB query that raises one of the
    # retryable exceptions listed below.
    query_attempts: Optional[int] = None

    # Names of the exception types we are willing to retry on.
    retryable_exceptions: List[str] = field(default_factory=lambda: ["IOException"])


@dataclass
class DuckDBCredentials(Credentials):
database: str = "main"
Expand Down Expand Up @@ -126,6 +141,11 @@ class DuckDBCredentials(Credentials):
# provide helper functions for dbt Python models.
module_paths: Optional[List[str]] = None

# An optional strategy for allowing retries when certain types of
# exceptions occur on a model run (e.g., IOExceptions that were caused
# by networking issues)
retries: Optional[Retries] = None

@classmethod
def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
data = super().__pre_deserialize__(data)
Expand Down
72 changes: 71 additions & 1 deletion dbt/adapters/duckdb/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
import sys
import tempfile
import time
from typing import Dict
from typing import List
from typing import Optional

import duckdb
Expand Down Expand Up @@ -31,6 +33,44 @@ def _ensure_event_loop():
asyncio.set_event_loop(loop)


class RetryableCursor:
    """Cursor wrapper that re-runs ``execute`` when a retryable exception is raised.

    Only ``execute`` is intercepted; every other attribute or method lookup is
    forwarded to the wrapped cursor.
    """

    def __init__(self, cursor, retry_attempts: int, retryable_exceptions: List[str]):
        """Wrap ``cursor``.

        Args:
            cursor: The underlying cursor whose ``execute`` calls are retried.
            retry_attempts: Total number of times ``execute`` may be tried.
            retryable_exceptions: Exception class names, compared against
                ``type(exc).__name__``, that qualify for another attempt.
        """
        self._cursor = cursor
        self._retry_attempts = retry_attempts
        self._retryable_exceptions = retryable_exceptions

    def execute(self, sql: str, bindings=None):
        """Run ``sql`` on the wrapped cursor, retrying on retryable exceptions.

        Between attempts the call sleeps 1, 2, 4, ... seconds. No sleep happens
        after the final failed attempt, since no retry would follow it.

        Args:
            sql: The statement to execute.
            bindings: Optional bindings; passed through only when not None.

        Returns:
            This wrapper, so calls can be chained.

        Raises:
            Exception: The last retryable exception once all attempts are used
                up, or immediately any exception whose class name is not in
                the retryable list.
            RuntimeError: If no attempt was made at all (the configured number
                of attempts is not positive).
        """
        attempt, last_exc = 0, None
        while attempt < self._retry_attempts:
            try:
                if bindings is None:
                    self._cursor.execute(sql)
                else:
                    self._cursor.execute(sql, bindings)
                return self
            except Exception as e:
                exception_name = type(e).__name__
                if exception_name not in self._retryable_exceptions:
                    print(f"Did not retry exception named '{exception_name}'")
                    # bare raise keeps the original traceback untouched
                    raise
                last_exc = e
                attempt += 1
                # Back off only when another attempt will actually follow;
                # sleeping after the last failure merely delays the error.
                if attempt < self._retry_attempts:
                    time.sleep(2 ** (attempt - 1))
        if last_exc is not None:
            raise last_exc
        raise RuntimeError(
            "execute call failed, but no exceptions raised- this should be impossible"
        )

    # forward along all non-execute() methods/attribute look-ups
    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails, so if ``_cursor``
        # itself is missing (instance built without __init__) we must raise
        # here rather than recurse forever through ``self._cursor``.
        if name == "_cursor":
            raise AttributeError(name)
        return getattr(self._cursor, name)


class Environment(abc.ABC):
"""An Environment is an abstraction to describe *where* the code you execute in your dbt-duckdb project
actually runs. This could be the local Python process that runs dbt (which is the default),
Expand Down Expand Up @@ -74,7 +114,32 @@ def initialize_db(
cls, creds: DuckDBCredentials, plugins: Optional[Dict[str, BasePlugin]] = None
):
config = creds.config_options or {}
conn = duckdb.connect(creds.path, read_only=False, config=config)

if creds.retries:
success, attempt, exc = False, 0, None
while not success and attempt < creds.retries.connect_attempts:
try:
conn = duckdb.connect(creds.path, read_only=False, config=config)
success = True
except Exception as e:
exception_name = type(e).__name__
if exception_name in creds.retries.retryable_exceptions:
time.sleep(2**attempt)
exc = e
attempt += 1
else:
print(f"Did not retry exception named '{exception_name}'")
raise e
if not success:
if exc:
raise exc
else:
raise RuntimeError(
"connect call failed, but no exceptions raised- this should be impossible"
)

else:
conn = duckdb.connect(creds.path, read_only=False, config=config)

# install any extensions on the connection
if creds.extensions is not None:
Expand Down Expand Up @@ -127,6 +192,11 @@ def initialize_cursor(
for df_name, df in registered_df.items():
cursor.register(df_name, df)

if creds.retries and creds.retries.query_attempts:
cursor = RetryableCursor(
cursor, creds.retries.query_attempts, creds.retries.retryable_exceptions
)

return cursor

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/functional/plugins/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def profiles_config_update(self, dbt_profile_target, sqlite_test_db):
"type": "duckdb",
"path": dbt_profile_target.get("path", ":memory:"),
"plugins": plugins,
"retries": {"query_attempts": 2},
}
},
"target": "dev",
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_retries_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from unittest.mock import patch

from duckdb.duckdb import IOException

from dbt.adapters.duckdb.credentials import DuckDBCredentials
from dbt.adapters.duckdb.credentials import Retries
from dbt.adapters.duckdb.environments import Environment

class TestConnectRetries:
    """Unit tests for the connect-retry loop in ``Environment.initialize_db``."""

    @pytest.fixture
    def creds(self):
        """Credentials allowing two connect attempts for two retryable exception names."""
        return DuckDBCredentials(
            path="foo.db",
            retries=Retries(
                connect_attempts=2,
                retryable_exceptions=["IOException", "ArithmeticError"],
            ),
        )

    @pytest.mark.parametrize("exception", [None, IOException, ArithmeticError, ValueError])
    def test_initialize_db(self, creds, exception):
        """Check how many times duckdb.connect is called for each kind of failure.

        ``exception`` is raised by the first connect call (when not None); the
        second call returns normally.
        """
        # time.sleep is patched so the backoff between attempts does not make
        # the retry cases wait for real.
        with patch("duckdb.connect") as mock_connect, patch("time.sleep"):
            if exception:
                mock_connect.side_effect = [exception, None]

            if exception == ValueError:
                # ValueError is not in retryable_exceptions: it must surface
                # after a single connect call.
                with pytest.raises(ValueError):
                    Environment.initialize_db(creds)
                assert mock_connect.call_count == 1
            else:
                Environment.initialize_db(creds)
                if exception in {IOException, ArithmeticError}:
                    assert mock_connect.call_count == creds.retries.connect_attempts
                else:
                    mock_connect.assert_called_once_with(creds.path, read_only=False, config={})
57 changes: 57 additions & 0 deletions tests/unit/test_retries_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from unittest.mock import MagicMock
from unittest.mock import patch

import duckdb

from dbt.adapters.duckdb.credentials import Retries
from dbt.adapters.duckdb.environments import RetryableCursor

class TestRetryableCursor:
    """Unit tests for ``RetryableCursor.execute`` using a mocked cursor."""

    @pytest.fixture
    def mock_cursor(self):
        """A MagicMock standing in for the wrapped cursor."""
        return MagicMock()

    @pytest.fixture
    def mock_retries(self):
        """Retry settings with three query attempts and the default exception list."""
        return Retries(query_attempts=3)

    @pytest.fixture
    def retry_cursor(self, mock_cursor, mock_retries):
        """A RetryableCursor wrapping ``mock_cursor`` with the settings above."""
        return RetryableCursor(
            mock_cursor,
            mock_retries.query_attempts,
            mock_retries.retryable_exceptions,
        )

    def test_successful_execute(self, mock_cursor, retry_cursor):
        """Test that execute successfully runs the SQL query."""
        sql_query = "SELECT * FROM table"
        retry_cursor.execute(sql_query)
        mock_cursor.execute.assert_called_once_with(sql_query)

    def test_retry_on_failure(self, mock_cursor, retry_cursor):
        """Test that execute retries the SQL query on failure."""
        mock_cursor.execute.side_effect = [duckdb.duckdb.IOException, None]
        sql_query = "SELECT * FROM table"
        # patch the backoff so this test does not really wait between attempts
        with patch("time.sleep"):
            retry_cursor.execute(sql_query)
        assert mock_cursor.execute.call_count == 2

    def test_no_retry_on_non_retryable_exception(self, mock_cursor, retry_cursor):
        """Test that a non-retryable exception is not retried."""
        mock_cursor.execute.side_effect = ValueError
        sql_query = "SELECT * FROM table"
        with pytest.raises(ValueError):
            retry_cursor.execute(sql_query)
        mock_cursor.execute.assert_called_once_with(sql_query)

    def test_raises_after_exhausting_attempts(self, mock_cursor, retry_cursor, mock_retries):
        """Test that the retryable exception surfaces once every attempt has failed."""
        mock_cursor.execute.side_effect = duckdb.duckdb.IOException
        sql_query = "SELECT * FROM table"
        with patch("time.sleep"):
            with pytest.raises(duckdb.duckdb.IOException):
                retry_cursor.execute(sql_query)
        assert mock_cursor.execute.call_count == mock_retries.query_attempts

    def test_exponential_backoff(self, mock_cursor, retry_cursor):
        """Test that exponential backoff is applied between retries."""
        mock_cursor.execute.side_effect = [
            duckdb.duckdb.IOException,
            duckdb.duckdb.IOException,
            None,
        ]
        sql_query = "SELECT * FROM table"

        with patch("time.sleep") as mock_sleep:
            retry_cursor.execute(sql_query)
            assert mock_sleep.call_count == 2
            mock_sleep.assert_any_call(1)
            mock_sleep.assert_any_call(2)

0 comments on commit c951148

Please sign in to comment.