diff --git a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py index cc899b335e8c5..75a32d48bcf87 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py +++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py @@ -41,8 +41,8 @@ if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset - from airflow.models import Operator from airflow.providers.common.compat.lineage.entities import Table + from airflow.providers.common.compat.sdk import BaseOperator def _iter_extractor_types() -> Iterator[type[BaseExtractor]]: @@ -161,7 +161,7 @@ def extract_metadata( return OperatorLineage() - def get_extractor_class(self, task: Operator) -> type[BaseExtractor] | None: + def get_extractor_class(self, task: BaseOperator) -> type[BaseExtractor] | None: if task.task_type in self.extractors: return self.extractors[task.task_type] @@ -172,7 +172,7 @@ def method_exists(method_name): return self.default_extractor return None - def _get_extractor(self, task: Operator) -> BaseExtractor | None: + def _get_extractor(self, task: BaseOperator) -> BaseExtractor | None: # TODO: Re-enable in Extractor PR # self.instantiate_abstract_extractors(task) extractor = self.get_extractor_class(task) diff --git a/providers/openlineage/src/airflow/providers/openlineage/operators/empty.py b/providers/openlineage/src/airflow/providers/openlineage/operators/empty.py index 6ac8754797f99..69e35542cb0cf 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/operators/empty.py +++ b/providers/openlineage/src/airflow/providers/openlineage/operators/empty.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING +from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.openlineage.extractors.base import OperatorLineage -from airflow.providers.openlineage.version_compat import BaseOperator if TYPE_CHECKING: - from airflow.sdk.definitions.context import Context + from airflow.providers.common.compat.sdk import Context class EmptyOperator(BaseOperator): diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index 734f4c5761aa1..78c3fffda897e 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -29,6 +29,7 @@ from airflow import settings from airflow.listeners import hookimpl from airflow.models import DagRun, TaskInstance +from airflow.providers.common.compat.sdk import timeout, timezone from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import ExtractorManager, OperatorLineage from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState @@ -48,7 +49,6 @@ is_selective_lineage_enabled, print_warning, ) -from airflow.providers.openlineage.version_compat import timeout, timezone from airflow.settings import configure_orm from airflow.stats import Stats from airflow.utils.state import TaskInstanceState diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py index 699800fb66eb5..ac9c75b4a941b 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/macros.py @@ -24,7 +24,7 @@ from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: - from airflow.models import TaskInstance + from airflow.providers.common.compat.sdk import TaskInstance def lineage_job_namespace(): diff --git a/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py b/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py index d89852acbda14..0ac80fc9d7341 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py +++ b/providers/openlineage/src/airflow/providers/openlineage/sqlparser.py @@ -39,8 +39,8 @@ from openlineage.client.facet_v2 import JobFacet, RunFacet from sqlalchemy.engine import Engine + from airflow.providers.common.compat.sdk import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook - from airflow.sdk import BaseHook log = logging.getLogger(__name__) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py index edf6f9c858f84..5f49231dc8f2d 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py @@ -22,19 +22,14 @@ from airflow.models import Param from airflow.models.xcom_arg import XComArg +from airflow.providers.common.compat.sdk import DAG if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import BaseOperator, MappedOperator from airflow.providers.openlineage.utils.utils import AnyOperator - from airflow.sdk import DAG, BaseOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedDAG T = TypeVar("T", bound=DAG | BaseOperator | MappedOperator) -else: - try: - from airflow.sdk import DAG - except ImportError: - from airflow.models import DAG ENABLE_OL_PARAM_NAME = "_selective_enable_ol" ENABLE_OL_PARAM = Param(True, const=True) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py index 4925cd00f3923..a92ac25eab274 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py @@ -30,7 +30,7 @@ ) if TYPE_CHECKING: - from airflow.utils.context import Context + from airflow.providers.common.compat.sdk import Context log = logging.getLogger(__name__) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py index 904206170e961..7f07d39003929 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/sql.py @@ -31,7 +31,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.sql import ClauseElement - from airflow.sdk import BaseHook + from airflow.providers.common.compat.sdk import BaseHook log = logging.getLogger(__name__) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index b2e483315fcb1..8ef5c23ced88c 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -35,6 +35,8 @@ # TODO: move this maybe to Airflow's logic? from airflow.models import DagRun, TaskReschedule from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator +from airflow.providers.common.compat.assets import Asset +from airflow.providers.common.compat.sdk import DAG, BaseOperator, BaseSensorOperator, MappedOperator from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, conf, @@ -57,11 +59,6 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.utils.module_loading import import_string -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseSensorOperator -else: - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] - if not AIRFLOW_V_3_0_PLUS: from airflow.utils.session import NEW_SESSION, provide_session @@ -72,9 +69,6 @@ from openlineage.client.facet_v2 import RunFacet, processing_engine_run from airflow.models import TaskInstance - from airflow.providers.common.compat.assets import Asset - from airflow.sdk import DAG, BaseOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.execution_time.secrets_masker import ( Redactable, Redacted, @@ -85,21 +79,6 @@ AnyOperator: TypeAlias = BaseOperator | MappedOperator | SerializedBaseOperator | SerializedMappedOperator else: - try: - from airflow.sdk import DAG, BaseOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator - except ImportError: - from airflow.models import DAG, BaseOperator, MappedOperator - - try: - from airflow.providers.common.compat.assets import Asset - except ImportError: - if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import Asset - else: - # dataset is renamed to asset since Airflow 3.0 - from airflow.datasets import Dataset as Asset - try: from airflow.sdk._shared.secrets_masker import ( Redactable, diff --git a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py index ddf39a1898aaa..4a2c6ca5c6c7f 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py +++ b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py @@ -34,16 +34,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator -else: - from airflow.models import BaseOperator - -try: - from airflow.sdk import timezone - from airflow.sdk.execution_time.timeout import timeout -except ImportError: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] - from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef] - -__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator", "timeout", "timezone"] + +__all__ = ["AIRFLOW_V_3_0_PLUS"] diff --git a/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py b/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py index c3cb654a5e35e..de1636e1b9108 100644 --- a/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py +++ b/providers/openlineage/tests/system/openlineage/example_openlineage_base_complex_dag.py @@ -35,8 +35,8 @@ from airflow import DAG from airflow.models import Variable -from airflow.models.baseoperator import BaseOperator from airflow.providers.common.compat.assets import Asset +from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator diff --git a/providers/openlineage/tests/system/openlineage/operator.py b/providers/openlineage/tests/system/openlineage/operator.py index 6cbfd3c0c3c05..000696a95b1f0 100644 --- a/providers/openlineage/tests/system/openlineage/operator.py +++ b/providers/openlineage/tests/system/openlineage/operator.py @@ -29,19 +29,10 @@ from dateutil.parser import parse from jinja2 import Environment -from airflow.models.operator import BaseOperator - -try: - from airflow.sdk import Variable -except ImportError: - from airflow.models.variable import Variable +from airflow.providers.common.compat.sdk import BaseOperator, Variable if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.common.compat.sdk import Context log = logging.getLogger(__name__) diff --git a/providers/openlineage/tests/unit/openlineage/dags/test_openlineage_execution.py b/providers/openlineage/tests/unit/openlineage/dags/test_openlineage_execution.py index f8db91611e848..cc710d91e335f 100644 --- a/providers/openlineage/tests/unit/openlineage/dags/test_openlineage_execution.py +++ b/providers/openlineage/tests/unit/openlineage/dags/test_openlineage_execution.py @@ -21,8 +21,8 @@ import time from airflow.models.dag import DAG -from airflow.models.operator import BaseOperator from airflow.providers.common.compat.openlineage.facet import Dataset +from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py index 129a86b00f7f2..d5647e3f3b0af 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py @@ -24,14 +24,8 @@ from openlineage.client.event_v2 import Dataset from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator -else: - from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - from airflow.models.taskinstance import TaskInstanceState +from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.openlineage.extractors.base import ( BaseExtractor, DefaultExtractor, diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index f18f7e22a9ef6..a54830055fd76 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -31,6 +31,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.lineage.entities import Column, File, Table, User +from airflow.providers.common.compat.sdk import BaseOperator, Context, ObjectStoragePath from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.utils.utils import Asset @@ -43,13 +44,7 @@ if TYPE_CHECKING: try: from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, TIRunContext - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - # TIRunContext is only used in Airflow 3 tests - from airflow.utils.context import Context - AssetEventDagRunReference = TIRunContext = Any # type: ignore[misc, assignment] @@ -68,23 +63,6 @@ def hook_lineage_collector(): hook._hook_lineage_collector = None -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator, ObjectStoragePath - from airflow.sdk.api.datamodels._generated import TaskInstance as SDKTaskInstance - from airflow.sdk.execution_time import task_runner - from airflow.sdk.execution_time.comms import StartupDetails - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse -else: - from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] - from airflow.models import BaseOperator - - SDKTaskInstance = ... # type: ignore - task_runner = ... # type: ignore - StartupDetails = ... # type: ignore - RuntimeTaskInstance = ... # type: ignore - parse = ... # type: ignore - - @pytest.mark.parametrize( ("uri", "dataset"), ( diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 941666b8be42c..76b743f1a82f7 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -52,10 +52,7 @@ else: from airflow.utils import timezone # type: ignore[attr-defined,no-redef] -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator -else: - from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] +from airflow.providers.common.compat.sdk import BaseOperator EXPECTED_TRY_NUMBER_1 = 1 diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py index a0b090ec53813..520421be76862 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_macros.py @@ -114,7 +114,7 @@ def test_lineage_parent_id(mock_run_id): @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+") def test_lineage_root_run_id_with_runtime_task_instance(create_runtime_ti): """Test lineage_root_run_id with real RuntimeTaskInstance object doesn't throw AttributeError.""" - from airflow.sdk import BaseOperator + from airflow.providers.common.compat.sdk import BaseOperator task = BaseOperator(task_id="test_task") diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py index 40c27f0b87852..b1cff03bcadaf 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py @@ -29,6 +29,7 @@ from pkg_resources import parse_version from airflow.providers.common.compat.assets import Asset +from airflow.providers.common.compat.sdk import timezone from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( DagInfo, @@ -53,9 +54,6 @@ if AIRFLOW_V_3_1_PLUS: from airflow.models.dag import get_next_data_interval - from airflow.sdk import timezone -else: - from airflow.utils import timezone # type: ignore[attr-defined,no-redef] if AIRFLOW_V_3_1_PLUS: from airflow.sdk._shared.secrets_masker import DEFAULT_SENSITIVE_FIELDS, SecretsMasker @@ -70,11 +68,10 @@ SecretsMasker, ) +from airflow.providers.common.compat.sdk import DAG + if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import DAG from airflow.utils.types import DagRunTriggeredByType -else: - from airflow import DAG class SafeStrDict(dict): diff --git a/providers/openlineage/tests/unit/openlineage/utils/custom_facet_fixture.py b/providers/openlineage/tests/unit/openlineage/utils/custom_facet_fixture.py index 040c8c774c31f..e34af2ebb68ca 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/custom_facet_fixture.py +++ b/providers/openlineage/tests/unit/openlineage/utils/custom_facet_fixture.py @@ -23,7 +23,7 @@ from airflow.providers.common.compat.openlineage.facet import RunFacet if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstance, TaskInstanceState + from airflow.providers.common.compat.sdk import TaskInstance, TaskInstanceState @attrs.define diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_selective_enable.py b/providers/openlineage/tests/unit/openlineage/utils/test_selective_enable.py index 4a1bcb39ba78b..082fec0768799 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_selective_enable.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_selective_enable.py @@ -19,13 +19,7 @@ from pendulum import now -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import dag, task -else: - from airflow.decorators import dag, task # type: ignore[attr-defined,no-redef] -from airflow.models import DAG +from airflow.providers.common.compat.sdk import DAG, dag, task from airflow.providers.openlineage.utils.selective_enable import ( DISABLE_OL_PARAM, ENABLE_OL_PARAM, diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 2d82b6b6ca5e6..3bbea4aafac4d 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -30,6 +30,7 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.common.compat.assets import Asset +from airflow.providers.common.compat.sdk import BaseOperator, TaskGroup, task, timezone from airflow.providers.openlineage.plugins.facets import AirflowDagRunFacet, AirflowJobFacet from airflow.providers.openlineage.utils.utils import ( _MAX_DOC_BYTES, @@ -57,7 +58,6 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.timetables.events import EventsTimetable from airflow.timetables.trigger import CronTriggerTimetable -from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -66,13 +66,6 @@ from tests_common.test_utils.mock_operators import MockOperator from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_3_PLUS, AIRFLOW_V_3_0_PLUS -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator, TaskGroup, task -else: - from airflow.decorators import task # type: ignore[attr-defined,no-redef] - from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] - from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] - BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash" PYTHON_OPERATOR_PATH = "airflow.providers.standard.operators.python"