Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: DbApiHook.insert_rows unnecessarily restarting connections #40615

Merged
merged 13 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from contextlib import closing, contextmanager
from datetime import datetime
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -54,6 +55,7 @@
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo


# Generic type variable; its uses are outside this excerpt.
T = TypeVar("T")
# Placeholder strings that a connection's extra may select through its "placeholder" key;
# the `placeholder` property falls back to the hook default for any other value.
SQL_PLACEHOLDERS = frozenset({"%s", "?"})

Expand Down Expand Up @@ -181,24 +183,28 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)

@property
def get_conn_id(self) -> str:
    """Return the connection id held by the attribute that ``conn_name_attr`` names."""
    attr_name = self.conn_name_attr
    return getattr(self, attr_name)

@cached_property
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn = self.get_connection(self.get_conn_id())
placeholder = conn.extra_dejson.get("placeholder")
if placeholder:
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"Placeholder '%s' defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
self.conn_name_attr,
placeholder,
self.get_conn_id(),
self._placeholder,
)
return self._placeholder

def get_conn(self):
"""Return a connection object."""
db = self.get_connection(getattr(self, cast(str, self.conn_name_attr)))
db = self.get_connection(self.get_conn_id())
return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)

def get_uri(self) -> str:
Expand All @@ -207,7 +213,7 @@ def get_uri(self) -> str:

:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn = self.get_connection(self.get_conn_id())
conn.schema = self.__schema or conn.schema
return conn.get_uri()

Expand Down Expand Up @@ -502,7 +508,7 @@ def set_autocommit(self, conn, autocommit):
if not self.supports_autocommit and autocommit:
self.log.warning(
"%s connection doesn't support autocommit but autocommit activated.",
getattr(self, self.conn_name_attr),
self.get_conn_id(),
)
conn.autocommit = autocommit

Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ from airflow.exceptions import (
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from functools import cached_property as cached_property
from pandas import DataFrame as DataFrame
from sqlalchemy.engine import URL as URL
from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload
Expand All @@ -63,7 +64,9 @@ class DbApiHook(BaseHook):
log_sql: Incomplete
descriptions: Incomplete
def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...
@property

def get_conn_id(self) -> str: ...
@cached_property
def placeholder(self): ...
def get_conn(self): ...
def get_uri(self) -> str: ...
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
)

assert self.db_hook.placeholder == "%s"
assert any(
assert (
"Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'." in message
for message in caplog.messages
Expand Down
19 changes: 19 additions & 0 deletions tests/providers/common/sql/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
from __future__ import annotations

import logging
import warnings
from unittest.mock import MagicMock

Expand Down Expand Up @@ -256,3 +257,21 @@ def test_make_common_data_structure_no_deprecated_method(self):
def test_placeholder_config_from_extra(self):
    """Check that a ``placeholder`` set in the connection extra is what the hook reports."""
    conn_params = {"extra": {"placeholder": "?"}}
    hook = mock_hook(DbApiHook, conn_params=conn_params)
    assert hook.placeholder == "?"

@pytest.mark.db_test
def test_placeholder_config_from_extra_when_not_in_default_sql_placeholders(self, caplog):
    """Check the fallback value and the logged warning for a placeholder that is not accepted."""
    expected_warning = (
        "Placeholder '!' defined in Connection 'default_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
        f"and got ignored. Falling back to the default placeholder '{DbApiHook._placeholder}'."
    )
    with caplog.at_level(logging.WARNING, logger="airflow.providers.common.sql.hooks.test_sql"):
        hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "!"}})
        # "!" is rejected, so the hook must report "%s" instead.
        assert hook.placeholder == "%s"
        assert expected_warning in caplog.text

@pytest.mark.db_test
def test_placeholder_multiple_times_and_make_sure_connection_is_only_invoked_once(self):
    """Read ``placeholder`` ten times and check ``get_connection`` was counted only once."""
    hook = mock_hook(DbApiHook)
    # Every read must yield "%s"; the generator stops at the first mismatch.
    assert all(hook.placeholder == "%s" for _ in range(10))
    # The mocked hook bumps this counter on each get_connection call.
    assert hook.connection_invocations == 1
2 changes: 2 additions & 0 deletions tests/providers/common/sql/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):

class MockedHook(hook_class): # type: ignore[misc, valid-type]
conn_name_attr = "test_conn_id"
connection_invocations = 0

@classmethod
def get_connection(cls, conn_id: str):
cls.connection_invocations += 1
return connection

def get_conn(self):
Expand Down