Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
if TYPE_CHECKING:
from openlineage.client.event_v2 import Dataset

from airflow.models import Operator
from airflow.providers.common.compat.lineage.entities import Table
from airflow.providers.common.compat.sdk import BaseOperator


def _iter_extractor_types() -> Iterator[type[BaseExtractor]]:
Expand Down Expand Up @@ -161,7 +161,7 @@ def extract_metadata(

return OperatorLineage()

def get_extractor_class(self, task: Operator) -> type[BaseExtractor] | None:
def get_extractor_class(self, task: BaseOperator) -> type[BaseExtractor] | None:
if task.task_type in self.extractors:
return self.extractors[task.task_type]

Expand All @@ -172,7 +172,7 @@ def method_exists(method_name):
return self.default_extractor
return None

def _get_extractor(self, task: Operator) -> BaseExtractor | None:
def _get_extractor(self, task: BaseOperator) -> BaseExtractor | None:
# TODO: Re-enable in Extractor PR
# self.instantiate_abstract_extractors(task)
extractor = self.get_extractor_class(task)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

from typing import TYPE_CHECKING

from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.openlineage.extractors.base import OperatorLineage
from airflow.providers.openlineage.version_compat import BaseOperator

if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
from airflow.providers.common.compat.sdk import Context


class EmptyOperator(BaseOperator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow import settings
from airflow.listeners import hookimpl
from airflow.models import DagRun, TaskInstance
from airflow.providers.common.compat.sdk import timeout, timezone
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import ExtractorManager, OperatorLineage
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
Expand All @@ -48,7 +49,6 @@
is_selective_lineage_enabled,
print_warning,
)
from airflow.providers.openlineage.version_compat import timeout, timezone
from airflow.settings import configure_orm
from airflow.stats import Stats
from airflow.utils.state import TaskInstanceState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.providers.common.compat.sdk import TaskInstance


def lineage_job_namespace():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from openlineage.client.facet_v2 import JobFacet, RunFacet
from sqlalchemy.engine import Engine

from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sdk import BaseHook

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@

from airflow.models import Param
from airflow.models.xcom_arg import XComArg
from airflow.providers.common.compat.sdk import DAG

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import BaseOperator, MappedOperator
from airflow.providers.openlineage.utils.utils import AnyOperator
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedDAG

T = TypeVar("T", bound=DAG | BaseOperator | MappedOperator)
else:
try:
from airflow.sdk import DAG
except ImportError:
from airflow.models import DAG

ENABLE_OL_PARAM_NAME = "_selective_enable_ol"
ENABLE_OL_PARAM = Param(True, const=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.providers.common.compat.sdk import Context

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy.engine import Engine
from sqlalchemy.sql import ClauseElement

from airflow.sdk import BaseHook
from airflow.providers.common.compat.sdk import BaseHook


log = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
# TODO: move this maybe to Airflow's logic?
from airflow.models import DagRun, TaskReschedule
from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator
from airflow.providers.common.compat.assets import Asset
from airflow.providers.common.compat.sdk import DAG, BaseOperator, BaseSensorOperator, MappedOperator
from airflow.providers.openlineage import (
__version__ as OPENLINEAGE_PROVIDER_VERSION,
conf,
Expand All @@ -57,11 +59,6 @@
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.utils.module_loading import import_string

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseSensorOperator
else:
from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]

if not AIRFLOW_V_3_0_PLUS:
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -72,9 +69,6 @@
from openlineage.client.facet_v2 import RunFacet, processing_engine_run

from airflow.models import TaskInstance
from airflow.providers.common.compat.assets import Asset
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.execution_time.secrets_masker import (
Redactable,
Redacted,
Expand All @@ -85,21 +79,6 @@

AnyOperator: TypeAlias = BaseOperator | MappedOperator | SerializedBaseOperator | SerializedMappedOperator
else:
try:
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
except ImportError:
from airflow.models import DAG, BaseOperator, MappedOperator

try:
from airflow.providers.common.compat.assets import Asset
except ImportError:
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import Asset
else:
# dataset is renamed to asset since Airflow 3.0
from airflow.datasets import Dataset as Asset

try:
from airflow.sdk._shared.secrets_masker import (
Redactable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
else:
from airflow.models import BaseOperator

try:
from airflow.sdk import timezone
from airflow.sdk.execution_time.timeout import timeout
except ImportError:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
from airflow.utils.timeout import timeout # type: ignore[assignment,attr-defined,no-redef]

__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator", "timeout", "timezone"]

__all__ = ["AIRFLOW_V_3_0_PLUS"]
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@

from airflow import DAG
from airflow.models import Variable
from airflow.models.baseoperator import BaseOperator
from airflow.providers.common.compat.assets import Asset
from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
Expand Down
13 changes: 2 additions & 11 deletions providers/openlineage/tests/system/openlineage/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,10 @@
from dateutil.parser import parse
from jinja2 import Environment

from airflow.models.operator import BaseOperator

try:
from airflow.sdk import Variable
except ImportError:
from airflow.models.variable import Variable
from airflow.providers.common.compat.sdk import BaseOperator, Variable

if TYPE_CHECKING:
try:
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context
from airflow.providers.common.compat.sdk import Context

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import time

from airflow.models.dag import DAG
from airflow.models.operator import BaseOperator
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,8 @@
from openlineage.client.event_v2 import Dataset
from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
else:
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]

from airflow.models.taskinstance import TaskInstanceState
from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.openlineage.extractors.base import (
BaseExtractor,
DefaultExtractor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.lineage.entities import Column, File, Table, User
from airflow.providers.common.compat.sdk import BaseOperator, Context, ObjectStoragePath
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.extractors.manager import ExtractorManager
from airflow.providers.openlineage.utils.utils import Asset
Expand All @@ -43,13 +44,7 @@
if TYPE_CHECKING:
try:
from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, TIRunContext
from airflow.sdk.definitions.context import Context

except ImportError:
# TODO: Remove once provider drops support for Airflow 2
# TIRunContext is only used in Airflow 3 tests
from airflow.utils.context import Context

AssetEventDagRunReference = TIRunContext = Any # type: ignore[misc, assignment]


Expand All @@ -68,23 +63,6 @@ def hook_lineage_collector():
hook._hook_lineage_collector = None


if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator, ObjectStoragePath
from airflow.sdk.api.datamodels._generated import TaskInstance as SDKTaskInstance
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse
else:
from airflow.io.path import ObjectStoragePath # type: ignore[no-redef]
from airflow.models import BaseOperator

SDKTaskInstance = ... # type: ignore
task_runner = ... # type: ignore
StartupDetails = ... # type: ignore
RuntimeTaskInstance = ... # type: ignore
parse = ... # type: ignore


@pytest.mark.parametrize(
("uri", "dataset"),
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@
else:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
else:
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.providers.common.compat.sdk import BaseOperator

EXPECTED_TRY_NUMBER_1 = 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_lineage_parent_id(mock_run_id):
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+")
def test_lineage_root_run_id_with_runtime_task_instance(create_runtime_ti):
"""Test lineage_root_run_id with real RuntimeTaskInstance object doesn't throw AttributeError."""
from airflow.sdk import BaseOperator
from airflow.providers.common.compat.sdk import BaseOperator

task = BaseOperator(task_id="test_task")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pkg_resources import parse_version

from airflow.providers.common.compat.assets import Asset
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
from airflow.providers.openlineage.utils.utils import (
DagInfo,
Expand All @@ -53,9 +54,6 @@

if AIRFLOW_V_3_1_PLUS:
from airflow.models.dag import get_next_data_interval
from airflow.sdk import timezone
else:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk._shared.secrets_masker import DEFAULT_SENSITIVE_FIELDS, SecretsMasker
Expand All @@ -70,11 +68,10 @@
SecretsMasker,
)

from airflow.providers.common.compat.sdk import DAG

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import DAG
from airflow.utils.types import DagRunTriggeredByType
else:
from airflow import DAG


class SafeStrDict(dict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from airflow.providers.common.compat.openlineage.facet import RunFacet

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance, TaskInstanceState
from airflow.providers.common.compat.sdk import TaskInstance, TaskInstanceState


@attrs.define
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@

from pendulum import now

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import dag, task
else:
from airflow.decorators import dag, task # type: ignore[attr-defined,no-redef]
from airflow.models import DAG
from airflow.providers.common.compat.sdk import DAG, dag, task
from airflow.providers.openlineage.utils.selective_enable import (
DISABLE_OL_PARAM,
ENABLE_OL_PARAM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceState
from airflow.providers.common.compat.assets import Asset
from airflow.providers.common.compat.sdk import BaseOperator, TaskGroup, task, timezone
from airflow.providers.openlineage.plugins.facets import AirflowDagRunFacet, AirflowJobFacet
from airflow.providers.openlineage.utils.utils import (
_MAX_DOC_BYTES,
Expand Down Expand Up @@ -57,7 +58,6 @@
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.events import EventsTimetable
from airflow.timetables.trigger import CronTriggerTimetable
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
Expand All @@ -66,13 +66,6 @@
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_3_PLUS, AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator, TaskGroup, task
else:
from airflow.decorators import task # type: ignore[attr-defined,no-redef]
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]

BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash"
PYTHON_OPERATOR_PATH = "airflow.providers.standard.operators.python"

Expand Down