Skip to content

Commit

Permalink
add 'execute_sql' command on caches
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Oct 1, 2024
1 parent 4cd8167 commit 0ec04f7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
17 changes: 16 additions & 1 deletion airbyte/_processors/sql/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from duckdb_engine import DuckDBEngineWarning
from overrides import overrides
from pydantic import Field
from sqlalchemy import text

from airbyte._writers.jsonl import JsonlWriter
from airbyte.secrets.base import SecretString
Expand All @@ -19,7 +20,7 @@


if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Connection, Engine


# @dataclass
Expand Down Expand Up @@ -161,3 +162,17 @@ def _write_files_to_new_table(
)
self._execute_sql(insert_statement)
return temp_table_name

def _close_connection(
self,
connection: Connection,
) -> None:
"""Close the given connection.
We override this method to ensure that the DuckDB WAL is checkpointed before closing.
For more info:
- https://duckdb.org/docs/sql/statements/checkpoint.html
"""
connection.execute(text("CHECKPOINT"))
super()._close_connection(connection)
25 changes: 25 additions & 0 deletions airbyte/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr
from sqlalchemy import text

from airbyte_protocol.models import ConfiguredAirbyteCatalog

Expand Down Expand Up @@ -110,6 +111,30 @@ def config_hash(self) -> str | None:
"""
return super(SqlConfig, self).config_hash

def execute_sql(self, sql: str | list[str]) -> None:
"""Execute one or more SQL statements against the cache's SQL backend.
If multiple SQL statements are given, they are executed in order,
within the same transaction.
This method is useful for creating tables, indexes, and other
schema objects in the cache. It does not return any results and it
automatically closes the connection after executing all statements.
This method is not intended for querying data. For that, use the `get_records`
method - or for a low-level interface, use the `get_sql_engine` method.
If any of the statements fail, the transaction is canceled and an exception
is raised. Most databases will rollback the transaction in this case.
"""
if isinstance(sql, str):
# Coerce to a list if a single string is given
sql = [sql]

with self.processor.get_sql_connection() as connection:
for sql_statement in sql:
connection.execute(text(sql_statement))

@final
@property
def processor(self) -> SqlProcessorBase:
Expand Down
13 changes: 12 additions & 1 deletion airbyte/shared/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ def get_sql_engine(self) -> Engine:
"""Return a new SQL engine to use."""
return self.sql_config.get_sql_engine()

def _close_connection(
self,
connection: Connection,
) -> None:
"""Close the given connection.
Subclasses can override this method to perform additional cleanup, such
as WAL checkpointing.
"""
connection.close()

@contextmanager
def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, None]:
"""A context manager which returns a new SQL connection for running queries.
Expand All @@ -378,7 +389,7 @@ def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, No
self._init_connection_settings(connection)
yield connection

connection.close()
self._close_connection(connection)
del connection

def get_sql_table_name(
Expand Down

0 comments on commit 0ec04f7

Please sign in to comment.