Skip to content

Commit

Permalink
feat: Support SQLAlchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Mar 6, 2023
1 parent 1450116 commit 73c65e3
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 53 deletions.
3 changes: 3 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def tests(session: Session) -> None:
"-v",
"--durations=10",
*session.posargs,
env={
"SQLALCHEMY_WARN_20": "1",
},
)
finally:
if session.interactive:
Expand Down
17 changes: 0 additions & 17 deletions samples/sample_tap_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from typing import Any

import sqlalchemy

from singer_sdk import SQLConnector, SQLStream, SQLTap
from singer_sdk import typing as th

Expand All @@ -22,21 +20,6 @@ def get_sqlalchemy_url(self, config: dict[str, Any]) -> str:
"""Generates a SQLAlchemy URL for SQLite."""
return f"sqlite:///{config[DB_PATH_CONFIG]}"

def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
"""Return a new SQLAlchemy connection using the provided config.
This override simply provides a more helpful error message on failure.
Returns:
A newly created SQLAlchemy engine object.
"""
try:
return super().create_sqlalchemy_connection()
except Exception as ex:
raise RuntimeError(
f"Error connecting to DB at '{self.config[DB_PATH_CONFIG]}': {ex}"
) from ex


class SQLiteStream(SQLStream):
"""The Stream class for SQLite.
Expand Down
17 changes: 0 additions & 17 deletions samples/sample_target_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from typing import Any

import sqlalchemy

from singer_sdk import SQLConnector, SQLSink, SQLTarget
from singer_sdk import typing as th

Expand All @@ -26,21 +24,6 @@ def get_sqlalchemy_url(self, config: dict[str, Any]) -> str:
"""Generates a SQLAlchemy URL for SQLite."""
return f"sqlite:///{config[DB_PATH_CONFIG]}"

def create_sqlalchemy_connection(self) -> sqlalchemy.engine.Connection:
"""Return a new SQLAlchemy connection using the provided config.
This override simply provides a more helpful error message on failure.
Returns:
A newly created SQLAlchemy engine object.
"""
try:
return super().create_sqlalchemy_connection()
except Exception as ex:
raise RuntimeError(
f"Error connecting to DB at '{self.config[DB_PATH_CONFIG]}'"
) from ex


class SQLiteSink(SQLSink):
"""The Sink class for SQLite.
Expand Down
2 changes: 1 addition & 1 deletion singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def _create_empty_column(
column_add_ddl = self.get_column_add_ddl(
table_name=full_table_name, column_name=column_name, column_type=sql_type
)
with self._connect() as conn:
with self._connect() as conn, conn.begin():
conn.execute(column_add_ddl)

def prepare_schema(self, schema_name: str) -> None:
Expand Down
15 changes: 9 additions & 6 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def bulk_insert_records(
else (self.conform_record(record) for record in records)
)
self.logger.info("Inserting with SQL: %s", insert_sql)
self.connector.connection.execute(insert_sql, conformed_records)
with self.connector._connect() as conn, conn.begin():
conn.execute(insert_sql, conformed_records)
return len(conformed_records) if isinstance(conformed_records, list) else None

def merge_upsert_from_table(
Expand Down Expand Up @@ -371,10 +372,11 @@ def activate_version(self, new_version: int) -> None:
)

if self.config.get("hard_delete", True):
self.connection.execute(
f"DELETE FROM {self.full_table_name} "
f"WHERE {self.version_column_name} <= {new_version}"
)
with self.connector._connect() as conn, conn.begin():
conn.execute(
f"DELETE FROM {self.full_table_name} "
f"WHERE {self.version_column_name} <= {new_version}"
)
return

if not self.connector.column_exists(
Expand All @@ -397,7 +399,8 @@ def activate_version(self, new_version: int) -> None:
bindparam("deletedate", value=deleted_at, type_=sqlalchemy.types.DateTime),
bindparam("version", value=new_version, type_=sqlalchemy.types.Integer),
)
self.connector.connection.execute(query)
with self.connector._connect() as conn, conn.begin():
conn.execute(query)


__all__ = ["SQLSink", "SQLConnector"]
5 changes: 3 additions & 2 deletions singer_sdk/streams/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
if self._MAX_RECORDS_LIMIT is not None:
query = query.limit(self._MAX_RECORDS_LIMIT)

for record in self.connector.connection.execute(query):
yield dict(record)
with self.connector._connect() as conn:
for record in conn.execute(query):
yield dict(record._mapping)


__all__ = ["SQLStream", "SQLConnector"]
2 changes: 1 addition & 1 deletion tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_connect_calls_connect(self, connector):
def test_connect_raises_on_operational_failure(self, connector):
with pytest.raises(sqlalchemy.exc.OperationalError) as _:
with connector._connect() as conn:
conn.execute("SELECT * FROM fake_table")
conn.execute(sqlalchemy.text("SELECT * FROM fake_table"))

def test_rename_column_uses_connect_correctly(self, connector):
attached_engine = connector._engine
Expand Down
18 changes: 9 additions & 9 deletions tests/samples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import pytest
from sqlalchemy import text

from samples.sample_tap_sqlite import SQLiteConnector, SQLiteTap
from singer_sdk._singerlib import Catalog
Expand All @@ -18,17 +19,16 @@ def csv_config(outdir: str) -> dict:


@pytest.fixture
def sqlite_sample_db(sqlite_connector):
def sqlite_sample_db(sqlite_connector: SQLiteConnector):
"""Return a path to a newly constructed sample DB."""
for t in range(3):
sqlite_connector.connection.execute(f"DROP TABLE IF EXISTS t{t}")
sqlite_connector.connection.execute(
f"CREATE TABLE t{t} (c1 int PRIMARY KEY, c2 varchar(10))"
)
for x in range(100):
sqlite_connector.connection.execute(
f"INSERT INTO t{t} VALUES ({x}, 'x={x}')"
with sqlite_connector._connect() as conn, conn.begin():
for t in range(3):
conn.execute(text(f"DROP TABLE IF EXISTS t{t}"))
conn.execute(
text(f"CREATE TABLE t{t} (c1 int PRIMARY KEY, c2 varchar(10))")
)
for x in range(100):
conn.execute(text(f"INSERT INTO t{t} VALUES ({x}, 'x={x}')"))


@pytest.fixture
Expand Down

0 comments on commit 73c65e3

Please sign in to comment.