Unpin SQLAlchemy & sqlalchemy-bigquery (#367)
This commit unpins the `SQLAlchemy` and `sqlalchemy-bigquery` dependencies so we can use them with Airflow 2.3.

This PR also fixes the SQLite issues that we had been hacking around with quotes. I have also created a companion PR in Airflow: apache/airflow#23790. Once that PR is merged and released, we can bump the SQLite provider and remove the `get_uri` logic from this repo.
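For context, here is a minimal sketch of the SQLite URL handling this change relies on (the path is illustrative, matching the CI connection in this PR; nothing here is taken verbatim from the repo):

import sqlalchemy

# SQLAlchemy SQLite URLs have the form "sqlite:///<path>", so an absolute
# path such as "/tmp/sqlite.db" produces "sqlite:////tmp/sqlite.db" (four
# slashes). Building the URL from the path stored in the connection's
# `host` field avoids the old string-replace workaround on the hook URI.
host = "/tmp/sqlite.db"
engine = sqlalchemy.create_engine(f"sqlite:///{host}")
print(engine.url)  # sqlite:////tmp/sqlite.db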

closes #351

(cherry picked from commit ba4e6bd)
kaxil committed May 19, 2022
1 parent e077c60 commit ca13ec4
Showing 8 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/ci-test-connections.yaml
@@ -31,7 +31,7 @@ connections:
schema: null
- conn_id: sqlite_conn
conn_type: sqlite
host: ////tmp/sqlite.db
host: /tmp/sqlite.db
schema:
login:
password:
3 changes: 3 additions & 0 deletions noxfile.py
@@ -25,6 +25,9 @@ def test(session: nox.Session) -> None:
"""Run unit tests."""
session.install("-e", ".[all]")
session.install("-e", ".[tests]")
# Log all the installed dependencies
session.log("Installed Dependencies:")
session.run("pip3", "freeze")
session.run("airflow", "db", "init")
session.run("pytest", *session.posargs)

13 changes: 7 additions & 6 deletions pyproject.toml
@@ -20,7 +20,8 @@ dependencies = [
"SQLAlchemy>=1.3.18,<=1.3.24",
"markupsafe>=1.1.1,<2.1.0",
"smart-open",
"pyarrow"
"pyarrow",
"SQLAlchemy>=1.3.18"
]

keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"]
@@ -49,7 +50,7 @@ tests = [
]
google = [
"apache-airflow-providers-google",
"sqlalchemy-bigquery==1.3.0",
"sqlalchemy-bigquery>=1.3.0",
"smart-open[gcs]>=5.2.1",
]
snowflake = [
@@ -68,13 +69,13 @@ amazon = [
all = [
"apache-airflow-providers-amazon",
"apache-airflow-providers-google>=6.4.0",
"sqlalchemy-bigquery==1.3.0",
"smart-open[all]>=5.2.1",
"s3fs",
"apache-airflow-providers-postgres",
"apache-airflow-providers-snowflake",
"smart-open[all]>=5.2.1",
"snowflake-sqlalchemy>=1.2.0,<=1.2.4",
"snowflake-connector-python[pandas]",
"apache-airflow-providers-postgres"
"sqlalchemy-bigquery>=1.3.0",
"s3fs"
]
doc = [
"sphinx==4.4.0"
8 changes: 7 additions & 1 deletion src/astro/sql/operators/agnostic_save_file.py
@@ -9,6 +9,7 @@

from astro.constants import Database
from astro.sql.table import Table
from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds
from astro.utils.database import get_database_from_conn_id
from astro.utils.dependencies import BigQueryHook, PostgresHook, SnowflakeHook
@@ -130,9 +131,14 @@ def convert_sql_table_to_dataframe(
f"Support types: {list(hook_class.keys())}"
)

if database == Database.SQLITE:
con_engine = create_sqlalchemy_engine_with_sqlite(input_hook)
else:
con_engine = input_hook.get_sqlalchemy_engine()

return pd.read_sql(
f"SELECT * FROM {input_table.qualified_name()}",
con=input_hook.get_sqlalchemy_engine(),
con=con_engine,
)

def agnostic_write_file(self, df: pd.DataFrame, output_file_path: str) -> None:
3 changes: 2 additions & 1 deletion src/astro/sql/operators/sql_dataframe.py
@@ -8,6 +8,7 @@
from astro.constants import Database
from astro.settings import SCHEMA
from astro.sql.table import Table, TempTable, create_table_name
from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils import get_hook
from astro.utils.database import get_database_from_conn_id
from astro.utils.dependencies import (
@@ -150,7 +151,7 @@ def _get_dataframe(self, table: Table):
)
elif database == Database.SQLITE:
hook = SqliteHook(sqlite_conn_id=table.conn_id, database=table.database)
engine = hook.get_sqlalchemy_engine()
engine = create_sqlalchemy_engine_with_sqlite(hook)
df = pd.read_sql_table(table.table_name, engine)
elif database == Database.BIGQUERY:
hook = BigQueryHook(gcp_conn_id=table.conn_id)
11 changes: 11 additions & 0 deletions src/astro/sqlite_utils.py
@@ -0,0 +1,11 @@
import sqlalchemy
from airflow.providers.sqlite.hooks.sqlite import SqliteHook


# TODO: This function should be removed after the refactor as this is handled in the Database
def create_sqlalchemy_engine_with_sqlite(hook: SqliteHook) -> sqlalchemy.engine.Engine:
    # Airflow's SqliteHook uses the sqlite3 library rather than SQLAlchemy,
    # and it builds its connection from the host field directly.
airflow_conn = hook.get_connection(getattr(hook, hook.conn_name_attr))
engine = sqlalchemy.create_engine(f"sqlite:///{airflow_conn.host}")
return engine
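As a usage note, the new helper can be exercised roughly like this (a hedged sketch; `sqlite_conn` is the example connection id from the CI config above, and the connection's `host` is assumed to point at a SQLite file):

from airflow.providers.sqlite.hooks.sqlite import SqliteHook

from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite

# Build an engine straight from the connection's host, without going through get_uri().
hook = SqliteHook(sqlite_conn_id="sqlite_conn")
engine = create_sqlalchemy_engine_with_sqlite(hook)
print(engine.url.database)  # the path stored in the connection's `host` field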
13 changes: 5 additions & 8 deletions src/astro/utils/database.py
@@ -2,11 +2,11 @@

from airflow.hooks.base import BaseHook
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.result import ResultProxy
from sqlalchemy import text
from sqlalchemy.engine import Engine, ResultProxy

from astro.constants import CONN_TYPE_TO_DATABASE, Database
from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.dependencies import BigQueryHook, PostgresHook, SnowflakeHook


@@ -64,10 +64,7 @@ def get_sqlalchemy_engine(hook: Union[BaseHook, SqliteHook]) -> Engine:
database = get_database_name(hook)
engine = None
if database == Database.SQLITE:
uri = hook.get_uri()
if "////" not in uri:
uri = hook.get_uri().replace("///", "////")
engine = create_engine(uri)
engine = create_sqlalchemy_engine_with_sqlite(hook)
if engine is None:
engine = hook.get_sqlalchemy_engine()
return engine
@@ -88,7 +85,7 @@ def run_sql(
:param parameters: (optional) Parameters to be passed to the SQL statement
:type parameters: dict
:return: Result of running the statement
:rtype: sqlalchemy.engine.result.ResultProxy
:rtype: sqlalchemy.engine.ResultProxy
"""
if parameters is None:
parameters = {}
8 changes: 4 additions & 4 deletions tests/utils/test_database.py
@@ -7,6 +7,7 @@
from sqlalchemy.engine.base import Engine

from astro.constants import Database
from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.database import (
get_database_from_conn_id,
get_database_name,
@@ -59,7 +60,7 @@ def with_sqlite_hook():
hook = SqliteHook()
db = get_database_name(hook)
assert db == Database.SQLITE
engine = hook.get_sqlalchemy_engine()
engine = create_sqlalchemy_engine_with_sqlite(hook)
db = get_database_name(engine)
assert db == Database.SQLITE

@@ -74,10 +75,9 @@ def with_unsupported_hook():
def describe_get_sqlalchemy_engine():
def with_sqlite():
hook = SqliteHook(sqlite_conn_id="sqlite_conn")
engine = get_sqlalchemy_engine(hook)
engine = create_sqlalchemy_engine_with_sqlite(hook)
assert isinstance(engine, Engine)
url = urlparse(str(engine.url))
assert url.path == "////tmp/sqlite.db"
assert engine.url.database == BaseHook.get_connection("sqlite_conn").host

def with_sqlite_default_conn():
hook = SqliteHook()
