Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
021df83
Moving BaseHook to task SDK
amoghrajesh Jun 17, 2025
3053da9
changing occurences in providers
amoghrajesh Jun 18, 2025
5c88ae0
changing in public interface
amoghrajesh Jun 18, 2025
b60152e
updating occurences in provider tests
amoghrajesh Jun 18, 2025
d2d4ddf
moving hooks test to task sdk
amoghrajesh Jun 18, 2025
8b640cf
adding backcompat for basehook
amoghrajesh Jun 18, 2025
1225b42
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 18, 2025
b544139
adding import in try catch block for backcompat in providers
amoghrajesh Jun 18, 2025
136f94c
fixing documentation in task sdk
amoghrajesh Jun 18, 2025
ebe6556
this is how to fix things
amoghrajesh Jun 18, 2025
d7a91a0
making it autouse
amoghrajesh Jun 19, 2025
023970c
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 23, 2025
5fa4a22
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 23, 2025
c7659d5
remove autouse
amoghrajesh Jun 23, 2025
e604ffb
🤦🏻
amoghrajesh Jun 23, 2025
a83f655
fixing core tests
amoghrajesh Jun 24, 2025
60696c5
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 25, 2025
a26aab9
fixing those static checks for redef
amoghrajesh Jun 25, 2025
f9427b4
making the fixture as autouse
amoghrajesh Jun 25, 2025
9d41622
fixing core tests
amoghrajesh Jun 25, 2025
96cc643
do not love it, but compat
amoghrajesh Jun 27, 2025
11b649e
older provider test
amoghrajesh Jun 27, 2025
ba6d2c3
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 27, 2025
06f8615
fixing import mypy
amoghrajesh Jun 27, 2025
9c37dc2
mypy wars: part 1
amoghrajesh Jun 27, 2025
c3a0f39
mypy wars: part 2
amoghrajesh Jun 27, 2025
0242e3b
mypy wars: part 3
amoghrajesh Jun 27, 2025
5c9939a
mypy wars: part 4 - all arg-type
amoghrajesh Jun 27, 2025
b7bebc5
mypy wars: part 4 - all arg-type
amoghrajesh Jun 27, 2025
c9f2ff0
mypy wars: part 5 - casting
amoghrajesh Jun 27, 2025
ff3ec38
mypy wars: part 6 - casting
amoghrajesh Jun 27, 2025
2dc98dd
mypy wars: part 7
amoghrajesh Jun 27, 2025
be57175
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 27, 2025
0c7bad0
mypy wars: part 8
amoghrajesh Jun 27, 2025
df8f331
mypy wars: part 9
amoghrajesh Jun 27, 2025
a95c973
mypy wars: part 10
amoghrajesh Jun 27, 2025
7c1f64a
mypy wars: part 11
amoghrajesh Jun 27, 2025
4311513
fixing patch path for generic transfer
amoghrajesh Jun 27, 2025
e51d5ba
fixing failing provider tests
amoghrajesh Jun 27, 2025
f166ee4
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 30, 2025
552389b
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 30, 2025
b11bf9a
final mypy?
amoghrajesh Jun 30, 2025
722e764
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 30, 2025
00567d1
remote a bad rebase file
amoghrajesh Jun 30, 2025
ff8fbde
fixing broken tests
amoghrajesh Jun 30, 2025
a0585e2
removing unneccessary str conversions
amoghrajesh Jun 30, 2025
5b55d78
another mypy fix
amoghrajesh Jun 30, 2025
94256b5
fixing compat tests
amoghrajesh Jun 30, 2025
86bac28
final mypy war!
amoghrajesh Jun 30, 2025
b46bc81
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jun 30, 2025
4906f40
fixing compat
amoghrajesh Jun 30, 2025
3d8c452
compat fix again
amoghrajesh Jun 30, 2025
a1f28e1
fixing tests again
amoghrajesh Jun 30, 2025
0f20865
fixing docker tests
amoghrajesh Jun 30, 2025
34a6b3f
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jul 1, 2025
b71c383
move import inside cli
amoghrajesh Jul 1, 2025
4d8cf85
reverting unwanted change
amoghrajesh Jul 1, 2025
8d70d44
stricter type checking
amoghrajesh Jul 1, 2025
c1db959
adding null check
amoghrajesh Jul 1, 2025
cee9f93
fixing connection type
amoghrajesh Jul 1, 2025
cdf6885
fixing type of connections
amoghrajesh Jul 1, 2025
644a5ea
mypy again
amoghrajesh Jul 1, 2025
ba9a57f
mypy again
amoghrajesh Jul 1, 2025
998f2a4
mypy again
amoghrajesh Jul 1, 2025
d30bc9d
mypy again
amoghrajesh Jul 1, 2025
02df165
Merge branch 'main' into move-basehook-to-task-sdk
amoghrajesh Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/docs/public-airflow-interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Hooks

Hooks are interfaces to external platforms and databases, implementing a common
interface when possible and acting as building blocks for operators. All hooks
are derived from :class:`~airflow.hooks.base.BaseHook`.
are derived from :class:`~airflow.sdk.bases.hook.BaseHook`.

Airflow has a set of Hooks that are considered public. You are free to extend their functionality
by extending them:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from airflow.cli.utils import is_stdout, print_export_output
from airflow.configuration import conf
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers_manager import ProvidersManager
from airflow.secrets.local_filesystem import load_connections_dict
Expand Down Expand Up @@ -67,6 +66,8 @@ def _connection_mapper(conn: Connection) -> dict[str, Any]:
def connections_get(args):
"""Get a connection."""
try:
from airflow.sdk import BaseHook

conn = BaseHook.get_connection(args.conn_id)
except AirflowNotFoundException:
raise SystemExit("Connection not found.")
Expand Down
3 changes: 3 additions & 0 deletions airflow-core/src/airflow/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@
"subprocess": {
"SubprocessHook": "airflow.providers.standard.hooks.subprocess.SubprocessHook",
},
"base": {
"BaseHook": "airflow.sdk.bases.hook.BaseHook",
},
}
add_deprecated_classes(__deprecated_classes, __name__)
70 changes: 1 addition & 69 deletions airflow-core/src/airflow/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,75 +19,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Protocol

from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models.connection import Connection # Avoid circular imports.

log = logging.getLogger(__name__)


class BaseHook(LoggingMixin):
"""
Abstract base class for hooks.

Hooks are meant as an interface to
interact with external systems. MySqlHook, HiveHook, PigHook return
object that can handle the connection and interaction to specific
instances of these systems, and expose consistent methods to interact
with them.

:param logger_name: Name of the logger used by the Hook to emit logs.
If set to `None` (default), the logger name will fall back to
`airflow.task.hooks.{class.__module__}.{class.__name__}` (e.g. DbApiHook will have
*airflow.task.hooks.airflow.providers.common.sql.hooks.sql.DbApiHook* as logger).
"""

def __init__(self, logger_name: str | None = None):
super().__init__()
self._log_config_logger_name = "airflow.task.hooks"
self._logger_name = logger_name

@classmethod
def get_connection(cls, conn_id: str) -> Connection:
"""
Get connection, given connection id.

:param conn_id: connection id
:return: connection
"""
from airflow.models.connection import Connection

conn = Connection.get_connection_from_secrets(conn_id)
log.info("Connection Retrieved '%s'", conn.conn_id)
return conn

@classmethod
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook:
"""
Return default hook for this connection id.

:param conn_id: connection id
:param hook_params: hook parameters
:return: default hook for this connection
"""
connection = cls.get_connection(conn_id)
return connection.get_hook(hook_params=hook_params)

def get_conn(self) -> Any:
"""Return connection for the hook."""
raise NotImplementedError()

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
return {}

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
return {}
from typing import Any, Protocol


class DiscoverableHook(Protocol):
Expand Down
3 changes: 1 addition & 2 deletions airflow-core/src/airflow/lineage/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
from airflow.sdk import ObjectStoragePath
from airflow.sdk import BaseHook, ObjectStoragePath

# Store context what sent lineage.
LineageContext: TypeAlias = BaseHook | ObjectStoragePath
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def ensure_prefix(field):
if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.hooks.base import BaseHook
from airflow.sdk import BaseHook
from airflow.sdk.bases.decorator import TaskDecorator
from airflow.sdk.definitions.asset import Asset

Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/utils/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def send_mime_email(

if conn_id is not None:
try:
from airflow.hooks.base import BaseHook
from airflow.sdk import BaseHook

airflow_conn = BaseHook.get_connection(conn_id)
smtp_user = airflow_conn.login
Expand Down
14 changes: 12 additions & 2 deletions airflow-core/tests/unit/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from cryptography.fernet import Fernet

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection, crypto
from airflow.sdk import BaseHook

from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4

Expand Down Expand Up @@ -640,8 +640,18 @@ def test_param_setup(self):
assert conn.port is None

@pytest.mark.db_test
def test_env_var_priority(self):
def test_env_var_priority(self, mock_supervisor_comms):
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from airflow.sdk.execution_time.comms import ConnectionResult

conn = ConnectionResult(
conn_id="airflow_db",
conn_type="mysql",
host="mysql",
login="root",
)

mock_supervisor_comms.send.return_value = conn

conn = SqliteHook.get_connection(conn_id="airflow_db")
assert conn.host != "ec2.compute.com"
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from airflow.hooks.base import BaseHook
from airflow.models import Connection, DagBag
from airflow.sdk import BaseHook
from airflow.utils import yaml

from tests_common.test_utils.asserts import assert_queries_count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from airflow.cli import cli_config, cli_parser
from airflow.cli.commands import connection_command
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models import Connection
from airflow.utils.db import merge_conn
from airflow.utils.session import create_session
Expand Down Expand Up @@ -61,7 +61,8 @@ def test_cli_connection_get(self):
stdout = stdout.getvalue()
assert "google-cloud-platform:///default" in stdout

def test_cli_connection_get_invalid(self):
def test_cli_connection_get_invalid(self, mock_supervisor_comms):
mock_supervisor_comms.send.side_effect = AirflowNotFoundException
with pytest.raises(SystemExit, match=re.escape("Connection not found.")):
connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from airflow.cli import cli_parser
from airflow.cli.commands import rotate_fernet_key_command
from airflow.hooks.base import BaseHook
from airflow.models import Connection, Variable
from airflow.sdk import BaseHook
from airflow.utils.session import provide_session

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_should_rotate_variable(self, session):
assert Variable.get(key=var2_key) == "value"

@provide_session
def test_should_rotate_connection(self, session):
def test_should_rotate_connection(self, session, mock_supervisor_comms):
fernet_key1 = Fernet.generate_key()
fernet_key2 = Fernet.generate_key()
var1_key = f"{__file__}_var1"
Expand All @@ -111,6 +111,26 @@ def test_should_rotate_connection(self, session):
args = self.parser.parse_args(["rotate-fernet-key"])
rotate_fernet_key_command.rotate_fernet_key(args)

def mock_get_connection(conn_id):
conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
if conn:
from airflow.sdk.execution_time.comms import ConnectionResult

return ConnectionResult(
conn_id=conn.conn_id,
conn_type=conn.conn_type or "mysql", # Provide a default conn_type
host=conn.host,
login=conn.login,
password=conn.password,
schema_=conn.schema,
port=conn.port,
extra=conn.extra,
)
raise Exception(f"Connection {conn_id} not found")

# Mock the send method to return our connection data
mock_supervisor_comms.send.return_value = mock_get_connection(var1_key)

# Assert correctness using a new fernet key
with (
conf_vars({("core", "fernet_key"): fernet_key2.decode()}),
Expand All @@ -119,5 +139,7 @@ def test_should_rotate_connection(self, session):
# Unencrypted variable should be unchanged
conn1: Connection = BaseHook.get_connection(var1_key)
assert conn1.password == "pass"
assert conn1._password == "pass"

# Mock for the second connection
mock_supervisor_comms.send.return_value = mock_get_connection(var2_key)
assert BaseHook.get_connection(var2_key).password == "pass"
6 changes: 2 additions & 4 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ def test_top_level_connection_access(
logger_filehandle = MagicMock()

def dag_in_a_fn():
from airflow.hooks.base import BaseHook
from airflow.sdk import DAG
from airflow.sdk import DAG, BaseHook

with DAG(f"test_{BaseHook.get_connection(conn_id='my_conn').conn_id}"):
...
Expand Down Expand Up @@ -312,8 +311,7 @@ def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inp
logger_filehandle = MagicMock()

def dag_in_a_fn():
from airflow.hooks.base import BaseHook
from airflow.sdk import DAG
from airflow.sdk import DAG, BaseHook

with DAG(f"test_{BaseHook.get_connection(conn_id='my_conn').conn_id}"):
...
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from asgiref.sync import sync_to_async

from airflow.executors import workloads
from airflow.hooks.base import BaseHook
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import (
TriggerCommsDecoder,
Expand All @@ -49,6 +48,7 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.sdk import BaseHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.triggers.testing import FailureTrigger, SuccessTrigger
from airflow.utils import timezone
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/lineage/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import pytest

from airflow import plugins_manager
from airflow.hooks.base import BaseHook
from airflow.lineage import hook
from airflow.lineage.hook import (
AssetLineageInfo,
Expand All @@ -32,6 +31,7 @@
NoOpCollector,
get_hook_lineage_collector,
)
from airflow.sdk import BaseHook
from airflow.sdk.definitions.asset import Asset

from tests_common.test_utils.mock_plugins import mock_plugin_manager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
ParamValidationError,
SerializationError,
)
from airflow.hooks.base import BaseHook
from airflow.models.asset import AssetModel
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
Expand All @@ -64,7 +63,7 @@
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.sdk import AssetAlias, teardown
from airflow.sdk import AssetAlias, BaseHook, teardown
from airflow.sdk.bases.decorator import DecoratedOperator
from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
Expand Down
2 changes: 1 addition & 1 deletion devel-common/src/tests_common/test_utils/common_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from airflow.models import Connection

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
from airflow.sdk import BaseHook


def mock_db_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from requests import Session

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]

T = TypeVar("T", bound=Any)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
from alibabacloud_tea_openapi.models import Config

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]
from airflow.utils.log.logging_mixin import LoggingMixin


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from typing import Any, NamedTuple

from airflow.hooks.base import BaseHook
try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]


class AccessKeyCredentials(NamedTuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@
from oss2.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]

if TYPE_CHECKING:
from airflow.models.connection import Connection
try:
from airflow.sdk import Connection
except ImportError:
from airflow.models.connection import Connection # type: ignore[assignment]

T = TypeVar("T", bound=Callable)

Expand Down
Loading