Skip to content

Commit

Permalink
Create SQLAlchemy engine from connection in DB Hook and added autocom…
Browse files Browse the repository at this point in the history
…mit param to insert_rows method (#40669)

* refactor: Refactored get_sqlalchemy_engine method of DbApiHook to use the get_conn result to build the sqlalchemy engine

* refactor: Added autocommit parameter to insert_rows just like with the run method as this parameter will also be needed once we have the SQLInsertRowsOperator

* refactor: Updated the docstring of the insert_rows method

* refactor: Updated sql.pyi

* refactor: Try to fix AttributeError: type object 'SkipDBTestsSession' has no attribute 'get_bind'

* refactor: Implemented the sqlalchemy_url property for JdbcHook

* refactor: Refactored get_sqlalchemy_engine in DbApiHook: if the Hook implements the sqlalchemy_url property then use it, otherwise fall back to the original implementation with get_uri

* refactor: Added SQLAlchemy Inspector property in DbApiHook

* refactor: Reformatted test_sqlalchemy_url_with_sqlalchemy_scheme in TestJdbcHook

* refactor: Fixed static checks in DbApiHook

* refactor: Fixed some static checks

* docs: Updated docstring of JdbcHook and mentioned importance of sqlalchemy_scheme parameter

---------

Co-authored-by: David Blain <david.blain@infrabel.be>
  • Loading branch information
dabla and davidblain-infrabel authored Jul 26, 2024
1 parent 1def3f1 commit f6c7388
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 6 deletions.
23 changes: 20 additions & 3 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import sqlparse
from more_itertools import chunked
from sqlalchemy import create_engine
from sqlalchemy.engine import Inspector

from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -242,7 +243,20 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
if engine_kwargs is None:
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)
engine_kwargs["creator"] = self.get_conn

try:
url = self.sqlalchemy_url
except NotImplementedError:
url = self.get_uri()

self.log.debug("url: %s", url)
self.log.debug("engine_kwargs: %s", engine_kwargs)
return create_engine(url=url, **engine_kwargs)

@property
def inspector(self) -> Inspector:
    """
    Return a SQLAlchemy ``Inspector`` for this hook.

    The property is not cached: every access calls ``get_sqlalchemy_engine()`` with its
    default arguments and passes the resulting engine to ``Inspector.from_engine``.
    """
    engine = self.get_sqlalchemy_engine()
    return Inspector.from_engine(engine)

def get_pandas_df(
self,
Expand Down Expand Up @@ -571,6 +585,7 @@ def insert_rows(
replace=False,
*,
executemany=False,
autocommit=False,
**kwargs,
):
"""
Expand All @@ -585,12 +600,14 @@ def insert_rows(
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
:param executemany: (Deprecated) If True, all rows are inserted at once in
:param executemany: If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if all rows
have same number of column names, but leads to better performance.
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
"""
nb_rows = 0
with self._create_autocommit_connection() as conn:
with self._create_autocommit_connection(autocommit) as conn:
conn.commit()
with closing(conn.cursor()) as cur:
if self.supports_executemany or executemany:
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ from airflow.providers.openlineage.extractors import OperatorLineage as Operator
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from functools import cached_property as cached_property
from pandas import DataFrame as DataFrame
from sqlalchemy.engine import URL as URL
from sqlalchemy.engine import Inspector, URL as URL
from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload

T = TypeVar("T")
Expand All @@ -64,7 +64,6 @@ class DbApiHook(BaseHook):
log_sql: Incomplete
descriptions: Incomplete
def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...

def get_conn_id(self) -> str: ...
@cached_property
def placeholder(self): ...
Expand All @@ -73,6 +72,8 @@ class DbApiHook(BaseHook):
@property
def sqlalchemy_url(self) -> URL: ...
def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ...
@property
def inspector(self) -> Inspector: ...
def get_pandas_df(
self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs
) -> DataFrame: ...
Expand Down Expand Up @@ -123,6 +124,7 @@ class DbApiHook(BaseHook):
replace: bool = False,
*,
executemany: bool = False,
autocommit: bool = False,
**kwargs,
): ...
def bulk_dump(self, table, tmp_file) -> None: ...
Expand Down
26 changes: 25 additions & 1 deletion airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from typing import TYPE_CHECKING, Any

import jaydebeapi
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,7 +62,12 @@ class JdbcHook(DbApiHook):
"providers.jdbc" section of the Airflow configuration. If you're enabling these options in Airflow
configuration, you should make sure that you trust the users who can edit connections in the UI
to not use it maliciously.
4. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the
4. Define the "sqlalchemy_scheme" property in the extra of the connection if you want to use the
SQLAlchemy engine from the JdbcHook. When using the JdbcHook, the "sqlalchemy_scheme" will by
default have the "jdbc" value, which is a protocol, not a database scheme or dialect. So in order
to be able to use SQLAlchemy with the JdbcHook, you need to define the "sqlalchemy_scheme"
property in the extra of the connection.
5. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the
``local_settings.py`` file.
See :doc:`/connections/jdbc` for full documentation.
Expand Down Expand Up @@ -149,6 +156,23 @@ def driver_class(self) -> str | None:
self._driver_class = self.default_driver_class
return self._driver_class

@property
def sqlalchemy_url(self) -> URL:
    """
    Build the SQLAlchemy URL for this JDBC connection.

    The driver name is read from the ``sqlalchemy_scheme`` key of the connection extra; the
    remaining URL components are the connection's login, password, host, port and schema.

    :return: The URL assembled with ``URL.create``.
    :raises AirflowException: If ``sqlalchemy_scheme`` is missing or empty in the extra.
    """
    # Resolve the connection id the same way get_conn() does.
    conn = self.get_connection(self.get_conn_id())
    sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
    # An empty scheme is as unusable as a missing one: fail here with a clear message
    # instead of handing SQLAlchemy a URL without a driver name.
    if not sqlalchemy_scheme:
        raise AirflowException(
            "The parameter 'sqlalchemy_scheme' must be defined in extra for JDBC connections!"
        )
    return URL.create(
        drivername=sqlalchemy_scheme,
        username=conn.login,
        password=conn.password,
        host=conn.host,
        port=conn.port,
        database=conn.schema,
    )

def get_conn(self) -> jaydebeapi.Connection:
conn: Connection = self.get_connection(self.get_conn_id())
host: str = conn.host
Expand Down
10 changes: 10 additions & 0 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,16 @@ def __init__(self):
def remove(*args, **kwargs):
pass

def get_bind(
    self,
    mapper=None,
    clause=None,
    bind=None,
    _sa_skip_events=None,
    _sa_skip_for_implicit_returning=False,
):
    """
    Do nothing and return ``None``.

    Placeholder so that looking up ``get_bind`` on this stub session succeeds; every
    argument is accepted and ignored.
    """
    return None


class TracebackSession:
"""
Expand Down
15 changes: 15 additions & 0 deletions tests/providers/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import jaydebeapi
import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.jdbc.hooks.jdbc import JdbcHook, suppress_and_warn
from airflow.utils import db
Expand Down Expand Up @@ -186,3 +187,17 @@ def test_suppress_and_warn_when_raised_exception_is_not_suppressed(self):
with pytest.raises(RuntimeError, match="Spam Egg"):
with suppress_and_warn(KeyError):
raise RuntimeError("Spam Egg")

def test_sqlalchemy_url_without_sqlalchemy_scheme(self):
    """Reading ``sqlalchemy_url`` raises for a hook built without connection overrides."""
    jdbc_hook = get_hook(
        hook_params={"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"},
    )

    with pytest.raises(AirflowException):
        # Property access alone is expected to trigger the error.
        jdbc_hook.sqlalchemy_url

def test_sqlalchemy_url_with_sqlalchemy_scheme(self):
    """``sqlalchemy_url`` uses the ``sqlalchemy_scheme`` from the connection extra as driver name."""
    extra = json.dumps({"sqlalchemy_scheme": "mssql"})
    jdbc_hook = get_hook(
        conn_params={"extra": extra},
        hook_params={"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"},
    )
    expected_url = "mssql://login:password@host:1234/schema"

    assert str(jdbc_hook.sqlalchemy_url) == expected_url

0 comments on commit f6c7388

Please sign in to comment.