Unpin SQLAlchemy & sqlalchemy-bigquery (#367)
This commit unpins the `SQLAlchemy` and `sqlalchemy-bigquery` dependencies so they can be used with Airflow 2.3.

This PR also fixes the SQLite issues that we had been hacking around with quotes. I have also created a companion PR in Airflow: apache/airflow#23790. Once that PR is merged and released, we can bump the SQLite provider and remove the `get_uri`-related logic from this repo.

closes #351
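
For context on the SQLite changes below: instead of rewriting the slashes in the URI returned by `SqliteHook.get_uri()`, the code now builds the SQLAlchemy engine directly from the connection's `host` field, since Airflow's `SqliteHook` passes that value straight to `sqlite3`. A minimal sketch of the approach (the `sqlite_conn` conn id is illustrative; this mirrors the new helper in `src/astro/sqlite_utils.py` rather than being a drop-in API):

```python
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from sqlalchemy import create_engine

# Build the engine from the connection's host (the SQLite file path)
# instead of patching "///" vs "////" in hook.get_uri().
hook = SqliteHook(sqlite_conn_id="sqlite_conn")  # illustrative conn id
airflow_conn = hook.get_connection(getattr(hook, hook.conn_name_attr))
engine = create_engine(f"sqlite:///{airflow_conn.host}")
```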
kaxil authored May 19, 2022
1 parent 153b886 commit ba4e6bd
Showing 10 changed files with 49 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/ci-test-connections.yaml
@@ -39,7 +39,7 @@ connections:
schema: null
- conn_id: sqlite_conn
conn_type: sqlite
-host: ////tmp/sqlite.db
+host: /tmp/sqlite.db
schema:
login:
password:
3 changes: 3 additions & 0 deletions noxfile.py
@@ -25,6 +25,9 @@ def test(session: nox.Session) -> None:
"""Run unit tests."""
session.install("-e", ".[all]")
session.install("-e", ".[tests]")
+# Log all the installed dependencies
+session.log("Installed Dependencies:")
+session.run("pip3", "freeze")
session.run("airflow", "db", "init")
session.run("pytest", *session.posargs)

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
"pyarrow",
"python-frontmatter",
"smart-open",
-"SQLAlchemy>=1.3.18,<=1.3.24"
+"SQLAlchemy>=1.3.18"
]

keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"]
@@ -49,7 +49,7 @@ tests = [
]
google = [
"apache-airflow-providers-google",
-"sqlalchemy-bigquery==1.3.0",
+"sqlalchemy-bigquery>=1.3.0",
"smart-open[gcs]>=5.2.1",
]
snowflake = [
@@ -73,7 +73,7 @@ all = [
"smart-open[all]>=5.2.1",
"snowflake-connector-python[pandas]",
"snowflake-sqlalchemy>=1.2.0,<=1.2.4",
-"sqlalchemy-bigquery==1.3.0",
+"sqlalchemy-bigquery>=1.3.0",
"s3fs"
]
doc = [
10 changes: 5 additions & 5 deletions src/astro/databases/sqlite.py
@@ -18,17 +18,17 @@ def __init__(self, conn_id: str = DEFAULT_CONN_ID):
super().__init__(conn_id)

@property
-def hook(self):
+def hook(self) -> SqliteHook:
"""Retrieve Airflow hook to interface with the Sqlite database."""
return SqliteHook(sqlite_conn_id=self.conn_id)

@property
def sqlalchemy_engine(self) -> Engine:
"""Return SQAlchemy engine."""
-uri = self.hook.get_uri()
-if "////" not in uri:
-uri = uri.replace("///", "////")
-return create_engine(uri)
+# Airflow uses sqlite3 library and not SqlAlchemy for SqliteHook
+# and it only uses the hostname directly.
+airflow_conn = self.hook.get_connection(self.conn_id)
+return create_engine(f"sqlite:///{airflow_conn.host}")

@property
def default_metadata(self) -> Metadata:
9 changes: 8 additions & 1 deletion src/astro/sql/operators/agnostic_save_file.py
@@ -9,6 +9,7 @@
from astro.databases import create_database
from astro.files import File
from astro.sql.table import Table
+from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.database import create_database_from_conn_id
from astro.utils.dependencies import BigQueryHook, PostgresHook, SnowflakeHook
from astro.utils.task_id_helper import get_task_id
@@ -111,9 +111,15 @@ def convert_sql_table_to_dataframe(

db = create_database(input_table.conn_id)
table_name = db.get_table_qualified_name(input_table)
+
+if database == Database.SQLITE:
+con_engine = create_sqlalchemy_engine_with_sqlite(input_hook)
+else:
+con_engine = input_hook.get_sqlalchemy_engine()
+
return pd.read_sql(
f"SELECT * FROM {table_name}",
-con=input_hook.get_sqlalchemy_engine(),
+con=con_engine,
)


3 changes: 2 additions & 1 deletion src/astro/sql/operators/sql_dataframe.py
@@ -9,6 +9,7 @@
from astro.databases import create_database
from astro.settings import SCHEMA
from astro.sql.table import Table
+from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils import get_hook
from astro.utils.database import create_database_from_conn_id
from astro.utils.dependencies import (
@@ -147,7 +148,7 @@ def _get_dataframe(self, table: Table):
)
elif database == Database.SQLITE:
hook = SqliteHook(sqlite_conn_id=table.conn_id)
-engine = hook.get_sqlalchemy_engine()
+engine = create_sqlalchemy_engine_with_sqlite(hook)
df = pd.read_sql_table(table.name, engine)
elif database == Database.BIGQUERY:
db = create_database(table.conn_id)
11 changes: 11 additions & 0 deletions src/astro/sqlite_utils.py
@@ -0,0 +1,11 @@
import sqlalchemy
from airflow.providers.sqlite.hooks.sqlite import SqliteHook


# TODO: This function should be removed after the refactor as this is handled in the Database
def create_sqlalchemy_engine_with_sqlite(hook: SqliteHook) -> sqlalchemy.engine.Engine:
# Airflow uses sqlite3 library and not SqlAlchemy for SqliteHook
# and it only uses the hostname directly.
airflow_conn = hook.get_connection(getattr(hook, hook.conn_name_attr))
engine = sqlalchemy.create_engine(f"sqlite:///{airflow_conn.host}")
return engine
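
A quick usage sketch of the new helper (the conn id and table name are illustrative), matching how the save-file and dataframe operators now build their SQLite engines:

```python
import pandas as pd
from airflow.providers.sqlite.hooks.sqlite import SqliteHook

from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite

hook = SqliteHook(sqlite_conn_id="sqlite_conn")  # illustrative conn id
engine = create_sqlalchemy_engine_with_sqlite(hook)
df = pd.read_sql("SELECT * FROM my_table", con=engine)  # my_table is illustrative
```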
13 changes: 5 additions & 8 deletions src/astro/utils/database.py
@@ -2,11 +2,11 @@

from airflow.hooks.base import BaseHook
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import Engine
-from sqlalchemy.engine.result import ResultProxy
+from sqlalchemy import text
+from sqlalchemy.engine import Engine, ResultProxy

from astro.constants import CONN_TYPE_TO_DATABASE, Database
+from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.dependencies import BigQueryHook, PostgresHook, SnowflakeHook


@@ -64,10 +64,7 @@ def get_sqlalchemy_engine(hook: Union[BaseHook, SqliteHook]) -> Engine:
database = get_database_name(hook)
engine = None
if database == Database.SQLITE:
-uri = hook.get_uri()
-if "////" not in uri:
-uri = hook.get_uri().replace("///", "////")
-engine = create_engine(uri)
+engine = create_sqlalchemy_engine_with_sqlite(hook)
if engine is None:
engine = hook.get_sqlalchemy_engine()
return engine
@@ -88,7 +85,7 @@ def run_sql(
:param parameters: (optional) Parameters to be passed to the SQL statement
:type parameters: dict
:return: Result of running the statement
-:rtype: sqlalchemy.engine.result.ResultProxy
+:rtype: sqlalchemy.engine.ResultProxy
"""
if parameters is None:
parameters = {}
15 changes: 7 additions & 8 deletions tests/databases/test_sqlite.py
@@ -1,11 +1,11 @@
"""Tests specific to the Sqlite Database implementation."""
import os
import pathlib
-from urllib.parse import urlparse

import pandas as pd
import pytest
import sqlalchemy
+from airflow.hooks.base import BaseHook

from astro.constants import Database
from astro.databases import create_database
@@ -33,20 +33,19 @@ def test_create_database(conn_id):


@pytest.mark.parametrize(
-"conn_id,expected_uri",
+"conn_id,expected_db_path",
[
-(DEFAULT_CONN_ID, "//tmp/sqlite_default.db"),
-(CUSTOM_CONN_ID, "////tmp/sqlite.db"),
+(DEFAULT_CONN_ID, BaseHook.get_connection(DEFAULT_CONN_ID).host),
+(CUSTOM_CONN_ID, "/tmp/sqlite.db"),
],
ids=SUPPORTED_CONN_IDS,
)
-def test_sqlite_sqlalchemy_engine(conn_id, expected_uri):
-"""Confirm that the SQLALchemy is created successfully."""
+def test_sqlite_sqlalchemy_engine(conn_id, expected_db_path):
+"""Confirm that the SQLAlchemy is created successfully and verify DB path."""
database = SqliteDatabase(conn_id)
engine = database.sqlalchemy_engine
assert isinstance(engine, sqlalchemy.engine.base.Engine)
-url = urlparse(str(engine.url))
-assert url.path == expected_uri
+assert engine.url.database == expected_db_path


@pytest.mark.integration
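For context on the updated assertions: SQLAlchemy exposes the SQLite file path as `engine.url.database`, so the tests can compare the path directly instead of parsing slashes out of the URL string. A small illustration (the path is illustrative):

```python
from sqlalchemy import create_engine

engine = create_engine("sqlite:////tmp/sqlite.db")
print(engine.url.database)  # "/tmp/sqlite.db" -- the file path without the sqlite:/// prefix
```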
8 changes: 4 additions & 4 deletions tests/utils/test_database.py
@@ -7,6 +7,7 @@
from sqlalchemy.engine.base import Engine

from astro.constants import Database
+from astro.sqlite_utils import create_sqlalchemy_engine_with_sqlite
from astro.utils.database import (
create_database_from_conn_id,
get_database_name,
@@ -59,7 +60,7 @@ def with_sqlite_hook():
hook = SqliteHook()
db = get_database_name(hook)
assert db == Database.SQLITE
-engine = hook.get_sqlalchemy_engine()
+engine = create_sqlalchemy_engine_with_sqlite(hook)
db = get_database_name(engine)
assert db == Database.SQLITE

@@ -74,10 +75,9 @@ def with_unsupported_hook():
def describe_get_sqlalchemy_engine():
def with_sqlite():
hook = SqliteHook(sqlite_conn_id="sqlite_conn")
-engine = get_sqlalchemy_engine(hook)
+engine = create_sqlalchemy_engine_with_sqlite(hook)
assert isinstance(engine, Engine)
-url = urlparse(str(engine.url))
-assert url.path == "////tmp/sqlite.db"
+assert engine.url.database == BaseHook.get_connection("sqlite_conn").host

def with_sqlite_default_conn():
hook = SqliteHook()
