Skip to content

Commit

Permalink
Generalize caching of connection in DbApiHook to improve performance (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dabla authored Sep 4, 2024
1 parent 0a816c6 commit 2e813eb
Show file tree
Hide file tree
Showing 18 changed files with 59 additions and 78 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/common/sql/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
Changelog
---------

1.17.0
......

Features
~~~~~~~~

* ``Connection in DB Hook is now cached to avoid multiple lookups when properties from extras have to be resolved``

1.16.0
......

Expand Down
26 changes: 23 additions & 3 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pandas import DataFrame
from sqlalchemy.engine import URL

from airflow.models import Connection
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo

Expand Down Expand Up @@ -183,14 +184,14 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
self._replace_statement_format: str = kwargs.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
self._connection: Connection | None = kwargs.pop("connection", None)

def get_conn_id(self) -> str:
return getattr(self, self.conn_name_attr)

@cached_property
def placeholder(self):
conn = self.get_connection(self.get_conn_id())
placeholder = conn.extra_dejson.get("placeholder")
placeholder = self.connection_extra.get("placeholder")
if placeholder:
if placeholder in SQL_PLACEHOLDERS:
return placeholder
Expand All @@ -203,9 +204,28 @@ def placeholder(self):
)
return self._placeholder

@property
def connection(self) -> Connection:
if self._connection is None:
self._connection = self.get_connection(self.get_conn_id())
return self._connection

@property
def connection_extra(self) -> dict:
return self.connection.extra_dejson

@cached_property
def connection_extra_lower(self) -> dict:
"""
``connection.extra_dejson`` but where keys are converted to lower case.
This is used internally for case-insensitive access of extra params.
"""
return {k.lower(): v for k, v in self.connection_extra.items()}

def get_conn(self):
"""Return a connection object."""
db = self.get_connection(self.get_conn_id())
db = self.connection
return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)

def get_uri(self) -> str:
Expand Down
7 changes: 7 additions & 0 deletions airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ from airflow.exceptions import (
AirflowProviderDeprecationWarning as AirflowProviderDeprecationWarning,
)
from airflow.hooks.base import BaseHook as BaseHook
from airflow.models import Connection as Connection
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from functools import cached_property as cached_property
Expand Down Expand Up @@ -67,6 +68,12 @@ class DbApiHook(BaseHook):
def get_conn_id(self) -> str: ...
@cached_property
def placeholder(self): ...
@property
def connection(self) -> Connection: ...
@property
def connection_extra(self) -> dict: ...
@cached_property
def connection_extra_lower(self) -> dict: ...
def get_conn(self): ...
def get_uri(self) -> str: ...
@property
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/common/sql/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ state: ready
source-date-epoch: 1723970051
# note that those versions are maintained by release manager - do not update them manually
versions:
- 1.17.0
- 1.16.0
- 1.15.0
- 1.14.2
Expand Down
8 changes: 3 additions & 5 deletions airflow/providers/elasticsearch/hooks/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,11 @@ class ElasticsearchSQLHook(DbApiHook):
def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.schema = schema
self.connection = connection
self._connection = connection

def get_conn(self) -> ESConnection:
"""Return an elasticsearch connection object."""
conn_id = self.get_conn_id()
conn = self.connection or self.get_connection(conn_id)
conn = self.connection

conn_args = {
"host": conn.host,
Expand All @@ -117,8 +116,7 @@ def get_conn(self) -> ESConnection:
return connect(**conn_args)

def get_uri(self) -> str:
conn_id = self.get_conn_id()
conn = self.connection or self.get_connection(conn_id)
conn = self.connection

login = ""
if conn.login:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/elasticsearch/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- elasticsearch>=8.10,<9

integrations:
Expand Down
10 changes: 0 additions & 10 deletions airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"relabeling": {"host": "Connection URL"},
}

@property
def connection_extra_lower(self) -> dict:
"""
``connection.extra_dejson`` but where keys are converted to lower case.
This is used internally for case-insensitive access of jdbc params.
"""
conn = self.get_connection(self.get_conn_id())
return {k.lower(): v for k, v in conn.extra_dejson.items()}

@property
def driver_path(self) -> str | None:
from airflow.configuration import conf
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/jdbc/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- jaydebeapi>=1.1.1

integrations:
Expand Down
24 changes: 1 addition & 23 deletions airflow/providers/microsoft/mssql/hooks/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@

from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import Any

import pymssql
from methodtools import lru_cache
from pymssql import Connection as PymssqlConnection

from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler

if TYPE_CHECKING:
from airflow.models import Connection


class MsSqlHook(DbApiHook):
"""
Expand Down Expand Up @@ -59,24 +55,6 @@ def __init__(
self.schema = kwargs.pop("schema", None)
self._sqlalchemy_scheme = sqlalchemy_scheme

@cached_property
def connection(self) -> Connection:
"""
Get the airflow connection object.
:return: The connection object.
"""
return self.get_connection(self.get_conn_id())

@property
def connection_extra_lower(self) -> dict:
"""
``connection.extra_dejson`` but where keys are converted to lower case.
This is used internally for case-insensitive access of mssql params.
"""
return {k.lower(): v for k, v in self.connection.extra_dejson.items()}

@property
def sqlalchemy_scheme(self) -> str:
"""Sqlalchemy scheme either from constructor, connection extras or default."""
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/mssql/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- pymssql>=2.3.0
# The methodtools dependency can be removed with min airflow version >=2.9.1
# as it was added in https://github.com/apache/airflow/pull/37757
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/mysql/hooks/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class MySqlHook(DbApiHook):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self.connection = kwargs.pop("connection", None)
self.local_infile = kwargs.pop("local_infile", False)
self.init_command = kwargs.pop("init_command", None)

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/mysql/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- mysqlclient>=1.4.0
- mysql-connector-python>=8.0.29

Expand Down
20 changes: 1 addition & 19 deletions airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,6 @@ def __init__(
self._connection = None
self._connect_kwargs = connect_kwargs

@property
def connection(self):
"""The Connection object with ID ``odbc_conn_id``."""
if not self._connection:
self._connection = self.get_connection(self.get_conn_id())
return self._connection

@property
def database(self) -> str | None:
"""Database provided in init if exists; otherwise, ``schema`` from ``Connection`` object."""
Expand All @@ -99,15 +92,6 @@ def sqlalchemy_scheme(self) -> str:
raise RuntimeError("sqlalchemy_scheme in connection extra should not contain : or / characters")
return self._sqlalchemy_scheme or extra_scheme or self.DEFAULT_SQLALCHEMY_SCHEME

@property
def connection_extra_lower(self) -> dict:
"""
``connection.extra_dejson`` but where keys are converted to lower case.
This is used internally for case-insensitive access of odbc params.
"""
return {k.lower(): v for k, v in self.connection.extra_dejson.items()}

@property
def driver(self) -> str | None:
"""Driver from init param if given; else try to find one in connection extra."""
Expand Down Expand Up @@ -166,9 +150,7 @@ def odbc_connection_string(self):
conn_str += f"PORT={self.connection.port};"

extra_exclude = {"driver", "dsn", "connect_kwargs", "sqlalchemy_scheme", "placeholder"}
extra_params = {
k: v for k, v in self.connection.extra_dejson.items() if k.lower() not in extra_exclude
}
extra_params = {k: v for k, v in self.connection_extra.items() if k.lower() not in extra_exclude}
for k, v in extra_params.items():
conn_str += f"{k}={v};"

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/odbc/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- pyodbc>=5.0.0

integrations:
Expand Down
4 changes: 1 addition & 3 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def __init__(self, *args, options: str | None = None, **kwargs) -> None:
)
kwargs["database"] = kwargs["schema"]
super().__init__(*args, **kwargs)
self.connection: Connection | None = kwargs.pop("connection", None)
self.conn: connection = None
self.database: str | None = kwargs.pop("database", None)
self.options = options
Expand Down Expand Up @@ -142,8 +141,7 @@ def _get_cursor(self, raw_cursor: str) -> CursorType:

def get_conn(self) -> connection:
"""Establish a connection to a postgres database."""
conn_id = self.get_conn_id()
conn = deepcopy(self.connection or self.get_connection(conn_id))
conn = deepcopy(self.connection)

# check for authentication via AWS IAM
if conn.extra_dejson.get("iam", False):
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/postgres/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.14.1
- apache-airflow-providers-common-sql>=1.17.0
- psycopg2-binary>=2.9.4

additional-extras:
Expand Down
4 changes: 2 additions & 2 deletions dev/breeze/tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_get_documentation_package_path():
"postgres",
"beta0",
"""
"apache-airflow-providers-common-sql>=1.14.1b0",
"apache-airflow-providers-common-sql>=1.17.0b0",
"apache-airflow>=2.8.0b0",
"psycopg2-binary>=2.9.4",
""",
Expand All @@ -214,7 +214,7 @@ def test_get_documentation_package_path():
"postgres",
"",
"""
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"psycopg2-binary>=2.9.4",
""",
Expand Down
12 changes: 6 additions & 6 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@
},
"elasticsearch": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"elasticsearch>=8.10,<9"
],
Expand Down Expand Up @@ -756,7 +756,7 @@
},
"jdbc": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"jaydebeapi>=1.1.1"
],
Expand Down Expand Up @@ -820,7 +820,7 @@
},
"microsoft.mssql": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"methodtools>=0.4.7",
"pymssql>=2.3.0"
Expand Down Expand Up @@ -871,7 +871,7 @@
},
"mysql": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"mysql-connector-python>=8.0.29",
"mysqlclient>=1.4.0"
Expand Down Expand Up @@ -902,7 +902,7 @@
},
"odbc": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"pyodbc>=5.0.0"
],
Expand Down Expand Up @@ -1047,7 +1047,7 @@
},
"postgres": {
"deps": [
"apache-airflow-providers-common-sql>=1.14.1",
"apache-airflow-providers-common-sql>=1.17.0",
"apache-airflow>=2.8.0",
"psycopg2-binary>=2.9.4"
],
Expand Down

0 comments on commit 2e813eb

Please sign in to comment.