35 changes: 2 additions & 33 deletions airflow-core/tests/unit/listeners/class_listener.py
@@ -20,7 +20,7 @@
from airflow.listeners import hookimpl
from airflow.utils.state import DagRunState, TaskInstanceState

from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:

@@ -64,8 +64,7 @@ def on_dag_run_success(self, dag_run, msg: str):
@hookimpl
def on_dag_run_failed(self, dag_run, msg: str):
self.state.append(DagRunState.FAILED)

elif AIRFLOW_V_2_10_PLUS:
else:

class ClassBasedListener: # type: ignore[no-redef]
def __init__(self):
@@ -95,36 +94,6 @@ def on_task_instance_success(self, previous_state, task_instance):
@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
self.state.append(TaskInstanceState.FAILED)
else:

class ClassBasedListener: # type: ignore[no-redef]
def __init__(self):
self.started_component = None
self.stopped_component = None
self.state = []

@hookimpl
def on_starting(self, component):
self.started_component = component
self.state.append(DagRunState.RUNNING)

@hookimpl
def before_stopping(self, component):
global stopped_component
stopped_component = component
self.state.append(DagRunState.SUCCESS)

@hookimpl
def on_task_instance_running(self, previous_state, task_instance, session):
self.state.append(TaskInstanceState.RUNNING)

@hookimpl
def on_task_instance_success(self, previous_state, task_instance, session):
self.state.append(TaskInstanceState.SUCCESS)

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, session):
self.state.append(TaskInstanceState.FAILED)


def clear():
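Note: with the AIRFLOW_V_2_10_PLUS branch removed, the module keeps only the Airflow 3.0+ listener plus a single pre-3.0 fallback. A minimal sketch of the surviving shape, with hook names and state handling taken from the hunks above (the pre-3.0 fallback is elided):

from airflow.listeners import hookimpl
from airflow.utils.state import DagRunState, TaskInstanceState


class ClassBasedListener:
    def __init__(self):
        self.started_component = None
        self.stopped_component = None
        self.state = []

    @hookimpl
    def on_starting(self, component):
        # Remember which component started us and mark the run as RUNNING.
        self.started_component = component
        self.state.append(DagRunState.RUNNING)

    @hookimpl
    def before_stopping(self, component):
        self.stopped_component = component
        self.state.append(DagRunState.SUCCESS)

    @hookimpl
    def on_dag_run_failed(self, dag_run, msg: str):
        self.state.append(DagRunState.FAILED)

    @hookimpl
    def on_task_instance_failed(self, previous_state, task_instance, error):
        self.state.append(TaskInstanceState.FAILED)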
34 changes: 12 additions & 22 deletions devel-common/src/tests_common/test_utils/compat.py
@@ -23,8 +23,6 @@
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.utils.helpers import prune_dict

from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS

try:
# ImportError has been renamed to ParseImportError in airflow 2.10.0, and since our provider tests should
# run on all supported versions of Airflow, this compatibility shim falls back to the old ImportError so
@@ -86,35 +84,27 @@
except ModuleNotFoundError:
# dataset is renamed to asset since Airflow 3.0
from airflow.models.dataset import (
DagScheduleDatasetAliasReference as DagScheduleAssetAliasReference,
DagScheduleDatasetReference as DagScheduleAssetReference,
DatasetAliasModel as AssetAliasModel,
DatasetDagRunQueue as AssetDagRunQueue,
DatasetEvent as AssetEvent,
DatasetModel as AssetModel,
TaskOutletDatasetReference as TaskOutletAssetReference,
)

if AIRFLOW_V_2_10_PLUS:
from airflow.models.dataset import (
DagScheduleDatasetAliasReference as DagScheduleAssetAliasReference,
DatasetAliasModel as AssetAliasModel,
)


def deserialize_operator(serialized_operator: dict[str, Any]) -> Operator:
if AIRFLOW_V_2_10_PLUS:
# In airflow 2.10+ we can deserialize operator using regular deserialize method.
# We do not need to use deserialize_operator method explicitly but some tests are deserializing the
# operator and in the future they could use regular ``deserialize`` method. This method is a shim
# to make deserialization of operator works for tests run against older Airflow versions and tests
# should use that method instead of calling ``BaseSerialization.deserialize`` directly.
# We can remove this method and switch to the regular ``deserialize`` method as long as all providers
# are updated to airflow 2.10+.
from airflow.serialization.serialized_objects import BaseSerialization

return BaseSerialization.deserialize(serialized_operator)
from airflow.serialization.serialized_objects import SerializedBaseOperator

return SerializedBaseOperator.deserialize_operator(serialized_operator)
# In airflow 2.10+ we can deserialize operator using regular deserialize method.
# We do not need to use deserialize_operator method explicitly but some tests are deserializing the
# operator and in the future they could use regular ``deserialize`` method. This method is a shim
# to make deserialization of operator works for tests run against older Airflow versions and tests
# should use that method instead of calling ``BaseSerialization.deserialize`` directly.
# We can remove this method and switch to the regular ``deserialize`` method as long as all providers
# are updated to airflow 2.10+.
from airflow.serialization.serialized_objects import BaseSerialization

return BaseSerialization.deserialize(serialized_operator)


def connection_to_dict(
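Note: since the shim now always routes through BaseSerialization, a test round trip looks roughly like the sketch below. This is a hedged illustration only: BaseOperator is used purely to keep the example independent of provider import paths, and the exact encoding helper a given test uses may differ.

from airflow.models.baseoperator import BaseOperator
from airflow.serialization.serialized_objects import BaseSerialization

from tests_common.test_utils.compat import deserialize_operator

op = BaseOperator(task_id="example")
encoded = BaseSerialization.serialize(op)  # on 2.10+ this wraps the operator with its type marker
restored = deserialize_operator(encoded)   # the shim hands the dict to BaseSerialization.deserialize
assert restored.task_id == "example"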
9 changes: 4 additions & 5 deletions devel-common/src/tests_common/test_utils/db.py
@@ -51,7 +51,7 @@
ParseImportError,
TaskOutletAssetReference,
)
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from pathlib import Path
@@ -159,11 +159,10 @@ def clear_db_assets():
session.query(AssetDagRunQueue).delete()
session.query(DagScheduleAssetReference).delete()
session.query(TaskOutletAssetReference).delete()
if AIRFLOW_V_2_10_PLUS:
from tests_common.test_utils.compat import AssetAliasModel, DagScheduleAssetAliasReference
from tests_common.test_utils.compat import AssetAliasModel, DagScheduleAssetAliasReference

session.query(AssetAliasModel).delete()
session.query(DagScheduleAssetAliasReference).delete()
session.query(AssetAliasModel).delete()
session.query(DagScheduleAssetAliasReference).delete()
if AIRFLOW_V_3_0_PLUS:
from airflow.models.asset import (
AssetActive,
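Note: with the gate gone, the alias tables are cleared on every run. Consolidated, the helper reads roughly as below; this is a hedged sketch, since the hunk shows only part of the function and the create_session import is assumed from the surrounding file.

from airflow.utils.session import create_session

from tests_common.test_utils.compat import (
    AssetAliasModel,
    AssetDagRunQueue,
    DagScheduleAssetAliasReference,
    DagScheduleAssetReference,
    TaskOutletAssetReference,
)


def clear_db_assets():
    with create_session() as session:
        # ... earlier table clean-up elided ...
        session.query(AssetDagRunQueue).delete()
        session.query(DagScheduleAssetReference).delete()
        session.query(TaskOutletAssetReference).delete()
        # The alias tables are now cleared unconditionally instead of behind AIRFLOW_V_2_10_PLUS.
        session.query(AssetAliasModel).delete()
        session.query(DagScheduleAssetAliasReference).delete()
        # Airflow 3.0+ additionally clears the asset-active bookkeeping tables (elided here).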
1 change: 0 additions & 1 deletion devel-common/src/tests_common/test_utils/version_compat.py
@@ -32,6 +32,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
return airflow_version.major, airflow_version.minor, airflow_version.micro


AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0)
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
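Note: only the 3.0 marker remains, so version-dependent imports in tests collapse to a single branch. A hedged sketch of the usual gating pattern, mirroring the asset/dataset fallback shown in compat.py above:

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.models.asset import AssetModel
else:
    # Before 3.0 the asset models still live under their dataset names.
    from airflow.models.dataset import DatasetModel as AssetModel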
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_2_10_PLUS
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
@@ -260,6 +259,4 @@ def get_openlineage_database_dialect(self, connection: Connection) -> str:

def get_openlineage_default_schema(self) -> str | None:
"""Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
if AIRFLOW_V_2_10_PLUS:
return self.get_first("SELECT CURRENT_SCHEMA();")[0]
return super().get_openlineage_default_schema()
return self.get_first("SELECT CURRENT_SCHEMA();")[0]
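Note: the OpenLineage default-schema lookup now always queries the database instead of falling back to the DbApiHook default. Illustrative sketch only; the hook class below is a hypothetical stand-in (the hunk does not show the real file name), and get_first returning the first result row is the behaviour inherited from DbApiHook.

class _ExampleSqlHook:  # hypothetical stand-in for the AWS SQL hook changed above
    def get_first(self, sql: str):
        # DbApiHook.get_first returns the first row of the result set,
        # e.g. ("public",) for SELECT CURRENT_SCHEMA();
        return ("public",)

    def get_openlineage_default_schema(self) -> str | None:
        """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
        return self.get_first("SELECT CURRENT_SCHEMA();")[0]


assert _ExampleSqlHook().get_openlineage_default_schema() == "public"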
Original file line number Diff line number Diff line change
@@ -32,5 +32,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
return airflow_version.major, airflow_version.minor, airflow_version.micro


AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0)
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
Original file line number Diff line number Diff line change
@@ -18,8 +18,6 @@

from datetime import datetime

import pytest

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
@@ -28,7 +26,6 @@
)

from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS

"""
Prerequisites: The account which runs this test must manually have the following:
@@ -42,8 +39,6 @@
Then, the SageMakerNotebookOperator will run a test notebook. This should spin up a SageMaker training job, run the notebook, and exit successfully.
"""

pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+")

DAG_ID = "example_sagemaker_unified_studio"

# Externally fetched variables:
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@

from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

pytestmark = pytest.mark.db_test

@@ -380,7 +380,6 @@ def test_stopped_tasks(self):
class TestAwsEcsExecutor:
"""Tests the AWS ECS Executor."""

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+")
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd):
"""Test execution from end-to-end."""
25 changes: 8 additions & 17 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
@@ -43,8 +43,6 @@
)
from airflow.utils.timezone import datetime

from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS


@pytest.fixture
def mocked_s3_res():
@@ -59,19 +57,17 @@ def s3_bucket(mocked_s3_res):
return bucket


if AIRFLOW_V_2_10_PLUS:

@pytest.fixture
def hook_lineage_collector():
from airflow.lineage import hook
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
@pytest.fixture
def hook_lineage_collector():
from airflow.lineage import hook
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

hook._hook_lineage_collector = None
hook._hook_lineage_collector = hook.HookLineageCollector()
hook._hook_lineage_collector = None
hook._hook_lineage_collector = hook.HookLineageCollector()

yield get_hook_lineage_collector()
yield get_hook_lineage_collector()

hook._hook_lineage_collector = None
hook._hook_lineage_collector = None


class TestAwsS3Hook:
@@ -448,7 +444,6 @@ def test_load_string(self, s3_bucket):
resource = boto3.resource("s3").Object(s3_bucket, "my_key")
assert resource.get()["Body"].read() == b"Cont\xc3\xa9nt"

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
def test_load_string_exposes_lineage(self, s3_bucket, hook_lineage_collector):
hook = S3Hook()

@@ -1023,7 +1018,6 @@ def test_load_file_gzip(self, s3_bucket, tmp_path):
resource = boto3.resource("s3").Object(s3_bucket, "my_key")
assert gz.decompress(resource.get()["Body"].read()) == b"Content"

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
def test_load_file_exposes_lineage(self, s3_bucket, tmp_path, hook_lineage_collector):
hook = S3Hook()
path = tmp_path / "testfile"
@@ -1091,7 +1085,6 @@ def test_copy_object_no_acl(
ACL="private",
)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
@mock_aws
def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector):
mock_hook = S3Hook()
@@ -1230,7 +1223,6 @@ def test_download_file(self, mock_temp_file, tmp_path):

assert path.name == output_file

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
def test_download_file_exposes_lineage(self, mock_temp_file, tmp_path, hook_lineage_collector):
path = tmp_path / "airflow_tmp_test_s3_hook"
@@ -1273,7 +1265,6 @@ def test_download_file_with_preserve_name(self, mock_open, tmp_path):

mock_open.assert_called_once_with(path, "wb")

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
def test_download_file_with_preserve_name_exposes_lineage(
self, mock_open, tmp_path, hook_lineage_collector
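Note: with the version guard removed, the lineage fixture is defined unconditionally at module level. Consolidated from the hunk above (dedent only, behaviour unchanged):

import pytest


@pytest.fixture
def hook_lineage_collector():
    from airflow.lineage import hook
    from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

    # Reset the module-level collector so each test observes only its own lineage events.
    hook._hook_lineage_collector = None
    hook._hook_lineage_collector = hook.HookLineageCollector()

    yield get_hook_lineage_collector()

    hook._hook_lineage_collector = None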
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
from airflow.utils.timezone import datetime

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS


def get_time_str(time_in_milliseconds):
@@ -270,19 +270,13 @@ def test_read(self, monkeypatch):
{"timestamp": current_time, "message": "Third"},
],
)
if AIRFLOW_V_2_10_PLUS:
monkeypatch.setattr(self.cloudwatch_task_handler, "_read_from_logs_server", lambda a, b: ([], []))
msg_template = textwrap.dedent("""
INFO - ::group::Log message source details
*** Reading remote log from Cloudwatch log_group: {} log_stream: {}
INFO - ::endgroup::
{}
""")[1:][:-1] # Strip off leading and trailing new lines, but not spaces
else:
msg_template = textwrap.dedent("""
*** Reading remote log from Cloudwatch log_group: {} log_stream: {}
{}
""").strip()
monkeypatch.setattr(self.cloudwatch_task_handler, "_read_from_logs_server", lambda a, b: ([], []))
msg_template = textwrap.dedent("""
INFO - ::group::Log message source details
*** Reading remote log from Cloudwatch log_group: {} log_stream: {}
INFO - ::endgroup::
{}
""")[1:][:-1] # Strip off leading and trailing new lines, but not spaces

logs, metadata = self.cloudwatch_task_handler.read(self.ti)
if AIRFLOW_V_3_0_PLUS:
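Note: only the grouped-header template survives the removal of the 2.10 branch. A hedged illustration of how the dedent-and-slice shapes the final string; the group, stream, and log lines below are made up.

import textwrap

msg_template = textwrap.dedent("""
    INFO - ::group::Log message source details
    *** Reading remote log from Cloudwatch log_group: {} log_stream: {}
    INFO - ::endgroup::
    {}
    """)[1:][:-1]  # Strip off leading and trailing new lines, but not spaces

print(msg_template.format("example-group", "example-stream", "First\nSecond\nThird"))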