Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add 'execute_sql' command on caches, add DuckDB WAL cleanup step #407

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion airbyte/_processors/sql/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from duckdb_engine import DuckDBEngineWarning
from overrides import overrides
from pydantic import Field
from sqlalchemy import text

from airbyte._writers.jsonl import JsonlWriter
from airbyte.secrets.base import SecretString
Expand All @@ -19,7 +20,7 @@


if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Connection, Engine


# @dataclass
Expand Down Expand Up @@ -161,3 +162,23 @@ def _write_files_to_new_table(
)
self._execute_sql(insert_statement)
return temp_table_name

def _do_checkpoint(
self,
connection: Connection | None,
aaronsteers marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Checkpoint the given connection.

We override this method to ensure that the DuckDB WAL is checkpointed explicitly.
Otherwise DuckDB will lazily flush the WAL to disk, which can cause issues for users
who want to manipulate the DB files after writing them.

For more info:
- https://duckdb.org/docs/sql/statements/checkpoint.html
"""
if connection is not None:
connection.execute(text("CHECKPOINT"))
return

with self.get_sql_connection() as new_conn:
new_conn.execute(text("CHECKPOINT"))
25 changes: 25 additions & 0 deletions airbyte/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr
from sqlalchemy import text

from airbyte_protocol.models import ConfiguredAirbyteCatalog

Expand Down Expand Up @@ -110,6 +111,30 @@ def config_hash(self) -> str | None:
"""
return super(SqlConfig, self).config_hash

def execute_sql(self, sql: str | list[str]) -> None:
"""Execute one or more SQL statements against the cache's SQL backend.

If multiple SQL statements are given, they are executed in order,
within the same transaction.

This method is useful for creating tables, indexes, and other
schema objects in the cache. It does not return any results and it
automatically closes the connection after executing all statements.

This method is not intended for querying data. For that, use the `get_records`
method - or for a low-level interface, use the `get_sql_engine` method.

If any of the statements fail, the transaction is canceled and an exception
is raised. Most databases will rollback the transaction in this case.
"""
if isinstance(sql, str):
# Coerce to a list if a single string is given
sql = [sql]

with self.processor.get_sql_connection() as connection:
for sql_statement in sql:
connection.execute(text(sql_statement))

@final
@property
def processor(self) -> SqlProcessorBase:
Expand Down
14 changes: 13 additions & 1 deletion airbyte/shared/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract
"""
pass

def _do_checkpoint( # noqa: B027 # Intentionally empty, not abstract
self,
connection: Connection | None,
) -> None:
"""Checkpoint the given connection.

If the WAL log needs to be, it will be flushed.

For most SQL databases, this is a no-op. However, it exists so that
subclasses can override this method to perform a checkpoint operation.
"""
pass

# Public interface:

@property
Expand Down Expand Up @@ -378,7 +391,6 @@ def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, No
self._init_connection_settings(connection)
yield connection

connection.close()
del connection

def get_sql_table_name(
Expand Down
4 changes: 4 additions & 0 deletions airbyte/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,10 @@ def _read_to_cache( # noqa: PLR0913 # Too many arguments
state_writer=state_writer,
progress_tracker=progress_tracker,
)

# Flush the WAL, if applicable
cache.processor._do_checkpoint() # noqa: SLF001 # Non-public API

return ReadResult(
source_name=self.name,
progress_tracker=progress_tracker,
Expand Down
Loading