diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fefaecec8b4d2..15b5a3605d5ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -427,7 +427,7 @@ repos: types_or: [python, pyi] args: [--fix] require_serial: true - additional_dependencies: ['ruff==0.11.13'] + additional_dependencies: ['ruff==0.12.1'] exclude: ^airflow-core/tests/unit/dags/test_imports\.py$|^performance/tests/test_.*\.py$ - id: ruff-format name: Run 'ruff format' @@ -437,7 +437,7 @@ repos: types_or: [python, pyi] args: [] require_serial: true - additional_dependencies: ['ruff==0.11.13'] + additional_dependencies: ['ruff==0.12.1'] exclude: ^airflow-core/tests/unit/dags/test_imports\.py$ - id: replace-bad-characters name: Replace bad characters @@ -1590,7 +1590,7 @@ repos: name: Check imports in providers entry: ./scripts/ci/pre_commit/check_imports_in_providers.py language: python - additional_dependencies: ['rich>=12.4.4', 'ruff==0.11.13'] + additional_dependencies: ['rich>=12.4.4', 'ruff==0.12.1'] files: ^providers/.*/src/airflow/providers/.*version_compat.*\.py$ require_serial: true ## ONLY ADD PRE-COMMITS HERE THAT REQUIRE CI IMAGE diff --git a/airflow-core/src/airflow/utils/log/logging_mixin.py b/airflow-core/src/airflow/utils/log/logging_mixin.py index 2e22c6999cb0b..1b3e24442dfa1 100644 --- a/airflow-core/src/airflow/utils/log/logging_mixin.py +++ b/airflow-core/src/airflow/utils/log/logging_mixin.py @@ -24,7 +24,7 @@ import sys from io import TextIOBase, UnsupportedOperation from logging import Handler, StreamHandler -from typing import IO, TYPE_CHECKING, Any, Optional, TypeVar, cast +from typing import IO, TYPE_CHECKING, Any, TypeVar, cast if TYPE_CHECKING: from logging import Logger @@ -72,9 +72,9 @@ class LoggingMixin: # Parent logger used by this class. It should match one of the loggers defined in the # `logging_config_class`. By default, this attribute is used to create the final name of the logger, and # will prefix the `_logger_name` with a separating dot. - _log_config_logger_name: Optional[str] = None # noqa: UP007 + _log_config_logger_name: str | None = None - _logger_name: Optional[str] = None # noqa: UP007 + _logger_name: str | None = None def __init__(self, context=None): self._set_context(context) diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index 4f2de48b70bac..d0a2037f9bdce 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -72,12 +72,12 @@ def test_providers_are_loaded(self): assert self._caplog.records == [] def test_hooks_deprecation_warnings_generated(self): + providers_manager = ProvidersManager() + providers_manager._provider_dict["test-package"] = ProviderInfo( + version="0.0.1", + data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]}, + ) with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records: - providers_manager = ProvidersManager() - providers_manager._provider_dict["test-package"] = ProviderInfo( - version="0.0.1", - data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]}, - ) providers_manager._discover_hooks() assert warning_records diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 3c107f61863f2..d9cdf9542aae3 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1909,7 +1909,7 @@ def test_get_task_group_states_with_multiple_task(self, client, session, dag_mak }, } - def test_get_task_group_states_with_logical_dates(self, client, session, dag_maker, serialized=True): + def test_get_task_group_states_with_logical_dates(self, client, session, dag_maker): with dag_maker("test_get_task_group_states_with_logical_dates", serialized=True): with TaskGroup("group1"): EmptyOperator(task_id="task1") diff --git a/airflow-core/tests/unit/core/test_configuration.py b/airflow-core/tests/unit/core/test_configuration.py index 8b64640f93532..2de31a5dd1063 100644 --- a/airflow-core/tests/unit/core/test_configuration.py +++ b/airflow-core/tests/unit/core/test_configuration.py @@ -1040,11 +1040,14 @@ def test_deprecated_options(self): # Remove it so we are sure we use the right setting conf.remove_option("celery", "worker_concurrency") - with pytest.warns(DeprecationWarning): + with pytest.warns(DeprecationWarning, match="celeryd_concurrency"): with mock.patch.dict("os.environ", AIRFLOW__CELERY__CELERYD_CONCURRENCY="99"): assert conf.getint("celery", "worker_concurrency") == 99 - with pytest.warns(DeprecationWarning), conf_vars({("celery", "celeryd_concurrency"): "99"}): + with ( + pytest.warns(DeprecationWarning, match="celeryd_concurrency"), + conf_vars({("celery", "celeryd_concurrency"): "99"}), + ): assert conf.getint("celery", "worker_concurrency") == 99 @pytest.mark.parametrize( @@ -1109,13 +1112,13 @@ def test_deprecated_options_cmd(self): ): conf.remove_option("celery", "result_backend") with conf_vars({("celery", "celery_result_backend_cmd"): "/bin/echo 99"}): - with pytest.warns(DeprecationWarning): - tmp = None - if "AIRFLOW__CELERY__RESULT_BACKEND" in os.environ: - tmp = os.environ.pop("AIRFLOW__CELERY__RESULT_BACKEND") + tmp = None + if "AIRFLOW__CELERY__RESULT_BACKEND" in os.environ: + tmp = os.environ.pop("AIRFLOW__CELERY__RESULT_BACKEND") + with pytest.warns(DeprecationWarning, match="result_backend"): assert conf.getint("celery", "result_backend") == 99 - if tmp: - os.environ["AIRFLOW__CELERY__RESULT_BACKEND"] = tmp + if tmp: + os.environ["AIRFLOW__CELERY__RESULT_BACKEND"] = tmp def test_deprecated_values_from_conf(self): test_conf = AirflowConfigParser( @@ -1135,7 +1138,7 @@ def test_deprecated_values_from_conf(self): with pytest.warns(FutureWarning): test_conf.validate() - assert test_conf.get("core", "hostname_callable") == "airflow.utils.net.getfqdn" + assert test_conf.get("core", "hostname_callable") == "airflow.utils.net.getfqdn" @pytest.mark.parametrize( "old, new", @@ -1160,19 +1163,19 @@ def test_deprecated_env_vars_upgraded_and_removed(self, old, new): old_env_var = test_conf._env_var_name(old_section, old_key) new_env_var = test_conf._env_var_name(new_section, new_key) - with pytest.warns(FutureWarning): - with mock.patch.dict("os.environ", **{old_env_var: old_value}): - # Can't start with the new env var existing... - os.environ.pop(new_env_var, None) + with mock.patch.dict("os.environ", **{old_env_var: old_value}): + # Can't start with the new env var existing... + os.environ.pop(new_env_var, None) + with pytest.warns(FutureWarning): test_conf.validate() - assert test_conf.get(new_section, new_key) == new_value - # We also need to make sure the deprecated env var is removed - # so that any subprocesses don't use it in place of our updated - # value. - assert old_env_var not in os.environ - # and make sure we track the old value as well, under the new section/key - assert test_conf.upgraded_values[(new_section, new_key)] == old_value + assert test_conf.get(new_section, new_key) == new_value + # We also need to make sure the deprecated env var is removed + # so that any subprocesses don't use it in place of our updated + # value. + assert old_env_var not in os.environ + # and make sure we track the old value as well, under the new section/key + assert test_conf.upgraded_values[(new_section, new_key)] == old_value @pytest.mark.parametrize( "conf_dict", @@ -1200,10 +1203,10 @@ def make_config(): test_conf.validate() return test_conf - with pytest.warns(FutureWarning): - with mock.patch.dict("os.environ", AIRFLOW__CORE__HOSTNAME_CALLABLE="airflow.utils.net:getfqdn"): + with mock.patch.dict("os.environ", AIRFLOW__CORE__HOSTNAME_CALLABLE="airflow.utils.net:getfqdn"): + with pytest.warns(FutureWarning): test_conf = make_config() - assert test_conf.get("core", "hostname_callable") == "airflow.utils.net.getfqdn" + assert test_conf.get("core", "hostname_callable") == "airflow.utils.net.getfqdn" with reset_warning_registry(): with warnings.catch_warnings(record=True) as warning: diff --git a/airflow-core/tests/unit/datasets/test_dataset.py b/airflow-core/tests/unit/datasets/test_dataset.py index 31211ad0dd247..4ad0946549937 100644 --- a/airflow-core/tests/unit/datasets/test_dataset.py +++ b/airflow-core/tests/unit/datasets/test_dataset.py @@ -73,12 +73,14 @@ ), ) def test_backward_compat_import_before_airflow_3_2(module_path, attr_name, expected_value, warning_message): - with pytest.warns() as record: - import importlib + import importlib + with pytest.warns() as record: mod = importlib.import_module(module_path, __name__) attr = getattr(mod, attr_name) - assert f"{attr.__module__}.{attr.__name__}" == expected_value - + assert f"{attr.__module__}.{attr.__name__}" == expected_value assert record[0].category is DeprecationWarning assert str(record[0].message) == warning_message + + +# ruff: noqa: PT031 diff --git a/airflow-core/tests/unit/listeners/test_listeners.py b/airflow-core/tests/unit/listeners/test_listeners.py index 3fceaaf0843cc..9a4c85cb504a1 100644 --- a/airflow-core/tests/unit/listeners/test_listeners.py +++ b/airflow-core/tests/unit/listeners/test_listeners.py @@ -69,7 +69,7 @@ def clean_listener_manager(): @provide_session -def test_listener_gets_calls(create_task_instance, session=None): +def test_listener_gets_calls(create_task_instance, session): lm = get_listener_manager() lm.add_listener(full_listener) @@ -84,7 +84,7 @@ def test_listener_gets_calls(create_task_instance, session=None): @provide_session -def test_multiple_listeners(create_task_instance, session=None): +def test_multiple_listeners(create_task_instance, session): lm = get_listener_manager() lm.add_listener(full_listener) lm.add_listener(lifecycle_listener) @@ -105,7 +105,7 @@ def test_multiple_listeners(create_task_instance, session=None): @provide_session -def test_listener_gets_only_subscribed_calls(create_task_instance, session=None): +def test_listener_gets_only_subscribed_calls(create_task_instance, session): lm = get_listener_manager() lm.add_listener(partial_listener) @@ -130,7 +130,7 @@ def test_listener_suppresses_exceptions(create_task_instance, session, cap_struc @provide_session -def test_listener_captures_failed_taskinstances(create_task_instance_of_operator, session=None): +def test_listener_captures_failed_taskinstances(create_task_instance_of_operator, session): lm = get_listener_manager() lm.add_listener(full_listener) @@ -145,7 +145,7 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator @provide_session -def test_listener_captures_longrunning_taskinstances(create_task_instance_of_operator, session=None): +def test_listener_captures_longrunning_taskinstances(create_task_instance_of_operator, session): lm = get_listener_manager() lm.add_listener(full_listener) @@ -159,7 +159,7 @@ def test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope @provide_session -def test_class_based_listener(create_task_instance, session=None): +def test_class_based_listener(create_task_instance, session): lm = get_listener_manager() listener = class_listener.ClassBasedListener() lm.add_listener(listener) diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index c3a0fcd929f8e..0e863bd17a58c 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -1032,7 +1032,7 @@ def test_capture_warnings(self): with pytest.warns(UserWarning, match="(Foo|Bar|Baz)") as ctx: with _capture_with_reraise() as cw: self.raise_warnings() - assert len(cw) == 3 + assert len(cw) == 3 assert len(ctx.list) == 3 def test_capture_warnings_with_parent_error_filter(self): diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 8f0de3d07c95e..a25924e4ab17f 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -404,7 +404,7 @@ def test_pool_slots_property(self): ) @provide_session - def test_ti_updates_with_task(self, create_task_instance, session=None): + def test_ti_updates_with_task(self, create_task_instance, session): """ test that updating the executor_config propagates to the TaskInstance DB """ @@ -1269,7 +1269,7 @@ def test_respects_prev_dagrun_dep(self, create_task_instance): ) @provide_session def test_are_dependents_done( - self, downstream_ti_state, expected_are_dependents_done, create_task_instance, session=None + self, downstream_ti_state, expected_are_dependents_done, create_task_instance, session ): ti = create_task_instance(session=session) dag = ti.task.dag @@ -2288,7 +2288,7 @@ def test_template_with_json_variable_missing(self, create_task_instance, session ti.task.render_template('{{ var.json.get("missing_variable") }}', context) @provide_session - def test_handle_failure(self, dag_maker, session=None): + def test_handle_failure(self, dag_maker, session): class CustomOp(BaseOperator): def execute(self, context): ... diff --git a/airflow-core/tests/unit/models/test_timestamp.py b/airflow-core/tests/unit/models/test_timestamp.py index 912347c602e77..5aaa956829d72 100644 --- a/airflow-core/tests/unit/models/test_timestamp.py +++ b/airflow-core/tests/unit/models/test_timestamp.py @@ -54,7 +54,7 @@ def add_log(execdate, session, dag_maker, timezone_override=None): @provide_session -def test_timestamp_behaviour(dag_maker, session=None): +def test_timestamp_behaviour(dag_maker, session): execdate = timezone.utcnow() with time_machine.travel(execdate, tick=False): current_time = timezone.utcnow() @@ -66,7 +66,7 @@ def test_timestamp_behaviour(dag_maker, session=None): @provide_session -def test_timestamp_behaviour_with_timezone(dag_maker, session=None): +def test_timestamp_behaviour_with_timezone(dag_maker, session): execdate = timezone.utcnow() with time_machine.travel(execdate, tick=False): current_time = timezone.utcnow() diff --git a/devel-common/pyproject.toml b/devel-common/pyproject.toml index 2af4184c0e2e5..1a37f7cafcfac 100644 --- a/devel-common/pyproject.toml +++ b/devel-common/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "kgb>=7.2.0", "requests_mock>=1.11.0", "rich>=13.6.0", - "ruff==0.11.13", + "ruff==0.12.1", "semver>=3.0.2", "time-machine>=2.13.0", "wheel>=0.42.0", diff --git a/helm-tests/tests/helm_tests/airflow_aux/test_container_lifecycle.py b/helm-tests/tests/helm_tests/airflow_aux/test_container_lifecycle.py index 77e8cb09a33ff..440dc68b4da7e 100644 --- a/helm-tests/tests/helm_tests/airflow_aux/test_container_lifecycle.py +++ b/helm-tests/tests/helm_tests/airflow_aux/test_container_lifecycle.py @@ -223,3 +223,6 @@ def test_log_groomer_sidecar_container_setting(self, hook_type="preStop"): assert lifecycle_hook_params["lifecycle_parsed"] == jmespath.search( f"spec.template.spec.containers[1].lifecycle.{hook_type}", doc ) + + +# ruff: noqa: PT028 diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py index ed95a4f75e5d0..ebd45661cd315 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py @@ -1360,3 +1360,6 @@ def assert_is_valid_uri(value: str) -> None: assert all([result.scheme, result.netloc, result.path]) assert REGION in value + + +# ruff: noqa: PT028 diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_connection_wrapper.py b/providers/amazon/tests/unit/amazon/aws/utils/test_connection_wrapper.py index 02275cccaea59..796d105724159 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_connection_wrapper.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_connection_wrapper.py @@ -120,7 +120,7 @@ def test_unexpected_aws_connection_type(self, conn_type): warning_message = f"expected connection type 'aws', got '{conn_type}'" with pytest.warns(UserWarning, match=warning_message): wrap_conn = AwsConnectionWrapper(conn=mock_connection_factory(conn_type=conn_type)) - assert wrap_conn.conn_type == conn_type + assert wrap_conn.conn_type == conn_type @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"]) @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"]) diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 974dc2b06f72e..d71fa9a64bb56 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -260,7 +260,7 @@ def test_cleanup_stuck_queued_tasks(self, mock_fail): executor.running = {ti.key} executor.tasks = {ti.key: AsyncResult("231")} assert executor.has_task(ti) - with pytest.warns(DeprecationWarning): + with pytest.warns(DeprecationWarning, match="cleanup_stuck_queued_tasks"): executor.cleanup_stuck_queued_tasks(tis=tis) executor.sync() assert executor.tasks == {} diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py index 062c889089e1a..31899080ceaf1 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py @@ -1250,7 +1250,7 @@ def test_cleanup_stuck_queued_tasks(self, mock_kube_dynamic_client, dag_maker, c executor.kube_scheduler = mock.MagicMock() ti.refresh_from_db() tis = [ti] - with pytest.warns(DeprecationWarning): + with pytest.warns(DeprecationWarning, match="cleanup_stuck_queued_tasks"): executor.cleanup_stuck_queued_tasks(tis=tis) executor.kube_scheduler.delete_pod.assert_called_once() assert executor.running == set() diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py index 20e2c081f740d..f7dfc1d73ac69 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py @@ -573,3 +573,6 @@ def test_get_df(df_type, df_class, description): assert df.row(1)[0] == result_sets[1][0] assert isinstance(df, df_class) + + +# ruff: noqa: PT028 diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index 0b0e8d94ce2a6..e46142e88690c 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -788,7 +788,7 @@ def test_create_empty_table_view(self, mock_bq_client, mock_table): view=view, retry=DEFAULT_RETRY, ) - assert_warning("create_empty_table", warnings) + assert_warning("create_empty_table", warnings) body = {"tableReference": TABLE_REFERENCE_REPR, "view": view} mock_table.from_api_repr.assert_called_once_with(body) @@ -827,7 +827,7 @@ def test_create_table_view(self, mock_bq_client, mock_table): def test_create_empty_table_succeed(self, mock_bq_client, mock_table): with pytest.warns(AirflowProviderDeprecationWarning) as warnings: self.hook.create_empty_table(project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID) - assert_warning("create_empty_table", warnings) + assert_warning("create_empty_table", warnings) body = { "tableReference": { @@ -880,22 +880,22 @@ def test_create_empty_table_with_extras_succeed(self, mock_bq_client, mock_table time_partitioning=time_partitioning, cluster_fields=cluster_fields, ) - assert_warning("create_empty_table", warnings) + assert_warning("create_empty_table", warnings) - body = { - "tableReference": { - "tableId": TABLE_ID, - "projectId": PROJECT_ID, - "datasetId": DATASET_ID, - }, - "schema": {"fields": schema_fields}, - "timePartitioning": time_partitioning, - "clustering": {"fields": cluster_fields}, - } - mock_table.from_api_repr.assert_called_once_with(body) - mock_bq_client.return_value.create_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY - ) + body = { + "tableReference": { + "tableId": TABLE_ID, + "projectId": PROJECT_ID, + "datasetId": DATASET_ID, + }, + "schema": {"fields": schema_fields}, + "timePartitioning": time_partitioning, + "clustering": {"fields": cluster_fields}, + } + mock_table.from_api_repr.assert_called_once_with(body) + mock_bq_client.return_value.create_table.assert_called_once_with( + table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY + ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table.from_api_repr") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") diff --git a/providers/google/tests/unit/google/cloud/operators/test_automl.py b/providers/google/tests/unit/google/cloud/operators/test_automl.py index 34c0397a43906..18f4de4af709d 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_automl.py +++ b/providers/google/tests/unit/google/cloud/operators/test_automl.py @@ -91,7 +91,7 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) - op.execute(context=mock.MagicMock()) + op.execute(context=mock.MagicMock()) mock_hook.return_value.create_model.assert_called_once_with( model=MODEL, @@ -194,7 +194,9 @@ def test_hook_type(self): task_id=TASK_ID, operation_params={"TEST_KEY": "TEST_VALUE"}, ) + with pytest.warns(AirflowProviderDeprecationWarning): assert isinstance(op.hook, CloudAutoMLHook) + with pytest.warns(AirflowProviderDeprecationWarning): op = AutoMLPredictOperator( endpoint_id="endpoint_id", location=GCP_LOCATION, @@ -203,7 +205,7 @@ def test_hook_type(self): task_id=TASK_ID, operation_params={"TEST_KEY": "TEST_VALUE"}, ) - assert isinstance(op.hook, PredictionServiceHook) + assert isinstance(op.hook, PredictionServiceHook) class TestAutoMLCreateImportOperator: @@ -218,7 +220,7 @@ def test_execute(self, mock_hook): project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) - op.execute(context=mock.MagicMock()) + op.execute(context=mock.MagicMock()) mock_hook.return_value.create_dataset.assert_called_once_with( dataset=DATASET, diff --git a/providers/google/tests/unit/google/cloud/operators/test_bigquery.py b/providers/google/tests/unit/google/cloud/operators/test_bigquery.py index 114c6f899a595..dc4ef3135d907 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/operators/test_bigquery.py @@ -348,7 +348,7 @@ def test_deprecation_warning(self): project_id=TEST_GCP_PROJECT_ID, table_id=TEST_TABLE_ID, ) - assert_warning("BigQueryCreateEmptyTableOperator", warnings) + assert_warning("BigQueryCreateEmptyTableOperator", warnings) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute(self, mock_hook): @@ -360,22 +360,22 @@ def test_execute(self, mock_hook): table_id=TEST_TABLE_ID, ) - operator.execute(context=MagicMock()) + operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=None, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=None, - materialized_view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=None, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=None, + materialized_view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_create_view(self, mock_hook): @@ -388,22 +388,22 @@ def test_create_view(self, mock_hook): view=VIEW_DEFINITION, ) - operator.execute(context=MagicMock()) + operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=None, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=VIEW_DEFINITION, - materialized_view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=None, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=VIEW_DEFINITION, + materialized_view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_create_materialized_view(self, mock_hook): @@ -416,22 +416,22 @@ def test_create_materialized_view(self, mock_hook): materialized_view=MATERIALIZED_VIEW_DEFINITION, ) - operator.execute(context=MagicMock()) + operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=None, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=None, - materialized_view=MATERIALIZED_VIEW_DEFINITION, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=None, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=None, + materialized_view=MATERIALIZED_VIEW_DEFINITION, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_create_clustered_empty_table(self, mock_hook): @@ -453,21 +453,21 @@ def test_create_clustered_empty_table(self, mock_hook): cluster_fields=cluster_fields, ) - operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=schema_fields, - time_partitioning=time_partitioning, - cluster_fields=cluster_fields, - labels=None, - view=None, - materialized_view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + operator.execute(context=MagicMock()) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=schema_fields, + time_partitioning=time_partitioning, + cluster_fields=cluster_fields, + labels=None, + view=None, + materialized_view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) @pytest.mark.parametrize( "if_exists, is_conflict, expected_error, log_msg", @@ -492,17 +492,17 @@ def test_create_existing_table(self, mock_hook, caplog, if_exists, is_conflict, view=VIEW_DEFINITION, if_exists=if_exists, ) - if is_conflict: - mock_hook.return_value.create_empty_table.side_effect = Conflict("any") - else: - mock_hook.return_value.create_empty_table.side_effect = None - if expected_error is not None: - with pytest.raises(expected_error): - operator.execute(context=MagicMock()) - else: + if is_conflict: + mock_hook.return_value.create_empty_table.side_effect = Conflict("any") + else: + mock_hook.return_value.create_empty_table.side_effect = None + if expected_error is not None: + with pytest.raises(expected_error): operator.execute(context=MagicMock()) - if log_msg is not None: - assert log_msg in caplog.text + else: + operator.execute(context=MagicMock()) + if log_msg is not None: + assert log_msg in caplog.text @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_get_openlineage_facets_on_complete(self, mock_hook): @@ -528,22 +528,22 @@ def test_get_openlineage_facets_on_complete(self, mock_hook): table_id=TEST_TABLE_ID, schema_fields=schema_fields, ) - operator.execute(context=MagicMock()) + operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=schema_fields, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=None, - materialized_view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=schema_fields, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=None, + materialized_view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) result = operator.get_openlineage_facets_on_complete(None) assert not result.run_facets @@ -576,7 +576,7 @@ def test_deprecation_warning(self): }, }, ) - assert_warning("BigQueryCreateExternalTableOperator", warnings) + assert_warning("BigQueryCreateExternalTableOperator", warnings) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute_with_csv_format(self, mock_hook): @@ -615,14 +615,14 @@ def test_execute_with_csv_format(self, mock_hook): table_resource=table_resource, ) - mock_hook.return_value.split_tablename.return_value = ( - TEST_GCP_PROJECT_ID, - TEST_DATASET, - TEST_TABLE_ID, - ) + mock_hook.return_value.split_tablename.return_value = ( + TEST_GCP_PROJECT_ID, + TEST_DATASET, + TEST_TABLE_ID, + ) - operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) + operator.execute(context=MagicMock()) + mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute_with_parquet_format(self, mock_hook): @@ -654,14 +654,14 @@ def test_execute_with_parquet_format(self, mock_hook): table_resource=table_resource, ) - mock_hook.return_value.split_tablename.return_value = ( - TEST_GCP_PROJECT_ID, - TEST_DATASET, - TEST_TABLE_ID, - ) + mock_hook.return_value.split_tablename.return_value = ( + TEST_GCP_PROJECT_ID, + TEST_DATASET, + TEST_TABLE_ID, + ) - operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) + operator.execute(context=MagicMock()) + mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_get_openlineage_facets_on_complete(self, mock_hook): @@ -694,14 +694,14 @@ def test_get_openlineage_facets_on_complete(self, mock_hook): table_resource=table_resource, ) - mock_hook.return_value.split_tablename.return_value = ( - TEST_GCP_PROJECT_ID, - TEST_DATASET, - TEST_TABLE_ID, - ) + mock_hook.return_value.split_tablename.return_value = ( + TEST_GCP_PROJECT_ID, + TEST_DATASET, + TEST_TABLE_ID, + ) - operator.execute(context=MagicMock()) - mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) + operator.execute(context=MagicMock()) + mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource) result = operator.get_openlineage_facets_on_complete(None) assert not result.run_facets diff --git a/providers/google/tests/unit/google/cloud/operators/test_life_sciences.py b/providers/google/tests/unit/google/cloud/operators/test_life_sciences.py index e49c428d1a44d..d125bdf10dcee 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_life_sciences.py +++ b/providers/google/tests/unit/google/cloud/operators/test_life_sciences.py @@ -47,8 +47,8 @@ def test_executes(self, mock_hook): operator = LifeSciencesRunPipelineOperator( task_id="task-id", body=TEST_BODY, location=TEST_LOCATION, project_id=TEST_PROJECT_ID ) - context = mock.MagicMock() - result = operator.execute(context=context) + context = mock.MagicMock() + result = operator.execute(context=context) assert result == TEST_OPERATION @@ -62,6 +62,6 @@ def test_executes_without_project_id(self, mock_hook): body=TEST_BODY, location=TEST_LOCATION, ) - context = mock.MagicMock() - result = operator.execute(context=context) + context = mock.MagicMock() + result = operator.execute(context=context) assert result == TEST_OPERATION diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py index d6a9fae4bdcaf..2168744edcade 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py @@ -63,17 +63,17 @@ def test_execute_with_old_result_processor_signature(self): ): results, events = execute_operator(operator) - assert len(results) == 30 - assert results == users.get("value") + next_users.get("value") - assert len(events) == 2 - assert isinstance(events[0], TriggerEvent) - assert events[0].payload["status"] == "success" - assert events[0].payload["type"] == "builtins.dict" - assert events[0].payload["response"] == json.dumps(users) - assert isinstance(events[1], TriggerEvent) - assert events[1].payload["status"] == "success" - assert events[1].payload["type"] == "builtins.dict" - assert events[1].payload["response"] == json.dumps(next_users) + assert len(results) == 30 + assert results == users.get("value") + next_users.get("value") + assert len(events) == 2 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(users) + assert isinstance(events[1], TriggerEvent) + assert events[1].payload["status"] == "success" + assert events[1].payload["type"] == "builtins.dict" + assert events[1].payload["response"] == json.dumps(next_users) def test_execute_with_new_result_processor_signature(self): users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json") @@ -124,17 +124,17 @@ def test_execute_with_old_paginate_function_signature(self): ): results, events = execute_operator(operator) - assert len(results) == 30 - assert results == users.get("value") + next_users.get("value") - assert len(events) == 2 - assert isinstance(events[0], TriggerEvent) - assert events[0].payload["status"] == "success" - assert events[0].payload["type"] == "builtins.dict" - assert events[0].payload["response"] == json.dumps(users) - assert isinstance(events[1], TriggerEvent) - assert events[1].payload["status"] == "success" - assert events[1].payload["type"] == "builtins.dict" - assert events[1].payload["response"] == json.dumps(next_users) + assert len(results) == 30 + assert results == users.get("value") + next_users.get("value") + assert len(events) == 2 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(users) + assert isinstance(events[1], TriggerEvent) + assert events[1].payload["status"] == "success" + assert events[1].payload["type"] == "builtins.dict" + assert events[1].payload["response"] == json.dumps(next_users) def test_execute_when_do_xcom_push_is_false(self): users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json") @@ -193,11 +193,11 @@ def custom_event_handler(context: Context, event: dict[Any, Any] | None = None): ): results, events = execute_operator(operator) - assert not results - assert len(events) == 1 - assert isinstance(events[0], TriggerEvent) - assert events[0].payload["status"] == "failure" - assert events[0].payload["message"] == "An error occurred" + assert not results + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "failure" + assert events[0].payload["message"] == "An error occurred" def test_execute_when_an_exception_occurs_on_custom_event_handler_with_new_signature(self): with self.patch_hook_and_request_adapter(AirflowException("An error occurred")): diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py index 055165d4fc70a..7361697b009d6 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py @@ -54,20 +54,20 @@ def test_execute_with_result_processor_with_old_signature(self): ): results, events = execute_operator(sensor) - assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"} - assert isinstance(results, str) - assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef" - assert len(events) == 3 - assert isinstance(events[0], TriggerEvent) - assert events[0].payload["status"] == "success" - assert events[0].payload["type"] == "builtins.dict" - assert events[0].payload["response"] == json.dumps(status[0]) - assert isinstance(events[1], TriggerEvent) - assert isinstance(events[1].payload, datetime) - assert isinstance(events[2], TriggerEvent) - assert events[2].payload["status"] == "success" - assert events[2].payload["type"] == "builtins.dict" - assert events[2].payload["response"] == json.dumps(status[1]) + assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"} + assert isinstance(results, str) + assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef" + assert len(events) == 3 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(status[0]) + assert isinstance(events[1], TriggerEvent) + assert isinstance(events[1].payload, datetime) + assert isinstance(events[2], TriggerEvent) + assert events[2].payload["status"] == "success" + assert events[2].payload["type"] == "builtins.dict" + assert events[2].payload["response"] == json.dumps(status[1]) def test_execute_with_result_processor_with_new_signature(self): status = load_json_from_resources(dirname(__file__), "..", "resources", "status.json") diff --git a/providers/slack/tests/unit/slack/hooks/test_slack.py b/providers/slack/tests/unit/slack/hooks/test_slack.py index 904ea02bf329f..cd7e3321191c4 100644 --- a/providers/slack/tests/unit/slack/hooks/test_slack.py +++ b/providers/slack/tests/unit/slack/hooks/test_slack.py @@ -111,8 +111,8 @@ def test_resolve_token(self): """Test that we only use token from Slack API Connection ID.""" with pytest.warns(UserWarning, match="Provide `token` as part of .* parameters is disallowed"): hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID, token="foo-bar") - assert "token" not in hook.extra_client_args - assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN + assert "token" not in hook.extra_client_args + assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN def test_empty_password(self): """Test password field defined in the connection.""" @@ -330,7 +330,7 @@ def test_backcompat_prefix_both_causes_warning(self, monkeypatch): hook = SlackHook(slack_conn_id="my_conn") with pytest.warns(Warning, match="Using value for `timeout`"): params = hook._get_conn_params() - assert params["timeout"] == 222 + assert params["timeout"] == 222 def test_empty_string_ignored_prefixed(self, monkeypatch): monkeypatch.setenv( diff --git a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py index cf2d406533df5..8dfa021b9113f 100644 --- a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py +++ b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py @@ -199,8 +199,8 @@ def test_ignore_webhook_token(self): UserWarning, match="Provide `webhook_token` as part of .* parameters is disallowed" ): hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID, webhook_token="foo-bar") - assert "webhook_token" not in hook.extra_client_args - assert hook._get_conn_params()["url"] == TEST_WEBHOOK_URL + assert "webhook_token" not in hook.extra_client_args + assert hook._get_conn_params()["url"] == TEST_WEBHOOK_URL @pytest.mark.parametrize("conn_id", ["conn_token_in_host_1", "conn_token_in_host_2"]) def test_wrong_connections(self, conn_id): @@ -479,7 +479,7 @@ def test_backcompat_prefix_both_causes_warning(self): hook = SlackWebhookHook(slack_webhook_conn_id="my_conn") with pytest.warns(Warning, match="Using value for `timeout`"): params = hook._get_conn_params() - assert params["timeout"] == 222 + assert params["timeout"] == 222 def test_empty_string_ignored_prefixed(self): with patch.dict( diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py b/providers/standard/tests/unit/standard/decorators/test_bash.py index c5ba14e70daaa..f4303b345ebc4 100644 --- a/providers/standard/tests/unit/standard/decorators/test_bash.py +++ b/providers/standard/tests/unit/standard/decorators/test_bash.py @@ -412,9 +412,9 @@ def bash(): ): bash_task = bash() - assert bash_task.operator.bash_command == SET_DURING_EXECUTION + assert bash_task.operator.bash_command == SET_DURING_EXECUTION - ti, _ = self.execute_task(bash_task) + ti, _ = self.execute_task(bash_task) assert bash_task.operator.multiple_outputs is False self.validate_bash_command_rtif(ti, "echo") diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 6e74a4e16a3c6..221aaf9b06fe8 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -154,18 +154,18 @@ def t2(x: "FakeTypeCheckingOnlyClass", y: int) -> "dict[int, int]": # type: ign assert t2(5, 5).operator.multiple_outputs is True - with pytest.warns(UserWarning, match="Cannot infer multiple_outputs.*t3") as recwarn: - - @task_decorator - def t3( # type: ignore[empty-body] - x: "FakeTypeCheckingOnlyClass", - y: int, - ) -> "UnresolveableName[int, int]": ... + @task_decorator + def t3( # type: ignore[empty-body] + x: "FakeTypeCheckingOnlyClass", + y: int, + ) -> "UnresolveableName[int, int]": ... + with pytest.warns(UserWarning, match="Cannot infer multiple_outputs.*t3") as recwarn: line = sys._getframe().f_lineno - 5 if PY38 else sys._getframe().f_lineno - 2 - if PY311: - # extra line explaining the error location in Py311 - line = line - 1 + + if PY311: + # extra line explaining the error location in Py311 + line = line - 1 warn = recwarn[0] assert warn.filename == __file__ diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index e3ae3672718a2..a04575e8bff31 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -1940,8 +1940,8 @@ def test_context_removed_after_exit(self): with set_current_context(example_context): pass if AIRFLOW_V_3_0_PLUS: - with pytest.warns(AirflowProviderDeprecationWarning): - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError): + with pytest.warns(AirflowProviderDeprecationWarning): get_current_context() else: with pytest.raises(RuntimeError): @@ -1963,13 +1963,13 @@ def test_nested_context(self): ctx_obj.__enter__() ctx_list.append(ctx_obj) if AIRFLOW_V_3_0_PLUS: - with pytest.warns(AirflowProviderDeprecationWarning): - for i in reversed(range(max_stack_depth)): - # Iterate over contexts in reverse order - stack is LIFO + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + with pytest.warns(AirflowProviderDeprecationWarning): ctx = get_current_context() - assert ctx["ContextId"] == i - # End of with statement - ctx_list[i].__exit__(None, None, None) + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) else: for i in reversed(range(max_stack_depth)): # Iterate over contexts in reverse order - stack is LIFO diff --git a/providers/standard/tests/unit/standard/sensors/test_time_delta.py b/providers/standard/tests/unit/standard/sensors/test_time_delta.py index 2014e9b1275e2..bc70aca8a9b2b 100644 --- a/providers/standard/tests/unit/standard/sensors/test_time_delta.py +++ b/providers/standard/tests/unit/standard/sensors/test_time_delta.py @@ -170,20 +170,18 @@ def setup_method(self): ) @mock.patch(DEFER_PATH) def test_timedelta_sensor(self, defer_mock, should_defer): + delta = timedelta(hours=1) with pytest.warns(AirflowProviderDeprecationWarning): - delta = timedelta(hours=1) op = TimeDeltaSensorAsync(task_id="timedelta_sensor_check", delta=delta, dag=self.dag) - if should_defer: - data_interval_end = pendulum.now("UTC").add(hours=1) - else: - data_interval_end = ( - pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1) - ) - op.execute({"data_interval_end": data_interval_end}) - if should_defer: - defer_mock.assert_called_once() - else: - defer_mock.assert_not_called() + if should_defer: + data_interval_end = pendulum.now("UTC").add(hours=1) + else: + data_interval_end = pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1) + op.execute({"data_interval_end": data_interval_end}) + if should_defer: + defer_mock.assert_called_once() + else: + defer_mock.assert_not_called() @pytest.mark.parametrize( "should_defer", @@ -213,32 +211,32 @@ def test_wait_sensor(self, sleep_mock, defer_mock, should_defer): ) def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_end, dag_maker): """Interval end should be used as base time when present else run_after""" - with pytest.warns(AirflowProviderDeprecationWarning): - if not AIRFLOW_V_3_0_PLUS and not interval_end: - pytest.skip("not applicable") - - context = {} - if interval_end: - context["data_interval_end"] = interval_end - with dag_maker() as dag: - kwargs = {} - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType - - kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) - - dr = dag.create_dagrun( - run_id="abcrhroceuh", - run_type=DagRunType.MANUAL, - state=None, - **kwargs, - ) - context.update(dag_run=dr) - delta = timedelta(seconds=1) + if not AIRFLOW_V_3_0_PLUS and not interval_end: + pytest.skip("not applicable") + + context = {} + if interval_end: + context["data_interval_end"] = interval_end + with dag_maker() as dag: + kwargs = {} + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.types import DagRunTriggeredByType + + kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after) + + dr = dag.create_dagrun( + run_id="abcrhroceuh", + run_type=DagRunType.MANUAL, + state=None, + **kwargs, + ) + context.update(dag_run=dr) + delta = timedelta(seconds=1) + with pytest.warns(AirflowProviderDeprecationWarning): op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag) - base_time = interval_end or run_after - expected_time = base_time + delta - with pytest.raises(TaskDeferred) as caught: - op.execute(context) + base_time = interval_end or run_after + expected_time = base_time + delta + with pytest.raises(TaskDeferred) as caught: + op.execute(context) - assert caught.value.trigger.moment == expected_time + assert caught.value.trigger.moment == expected_time diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index fb2c514349556..0c3cee49059f0 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -138,13 +138,15 @@ codegen = [ "datamodel-code-generator[http]==0.28.2", "openapi-spec-validator>=0.7.1", "svcs>=25.1.0", - "ruff==0.11.13", + "ruff==0.12.1", "rich>=12.4.4", ] dev = [ "apache-airflow-providers-common-sql", "apache-airflow-providers-standard", "apache-airflow-devel-common", + "pandas>=2.1.2; python_version <\"3.13\"", + "pandas>=2.2.3; python_version >=\"3.13\"" ] docs = [ "apache-airflow-devel-common[docs]", diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 2bae008a0a6ff..690411979865e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1015,11 +1015,11 @@ def _send_heartbeat_if_needed(self): self._terminal_state = SERVER_TERMINATED else: # If we get any other error, we'll just log it and try again next time - self._handle_heartbeat_failures() - except Exception: - self._handle_heartbeat_failures() + self._handle_heartbeat_failures(e) + except Exception as e: + self._handle_heartbeat_failures(e) - def _handle_heartbeat_failures(self): + def _handle_heartbeat_failures(self, exc: Exception | None): """Increment the failed heartbeats counter and kill the process if too many failures.""" self.failed_heartbeats += 1 log.warning( @@ -1027,7 +1027,7 @@ def _handle_heartbeat_failures(self): failed_heartbeats=self.failed_heartbeats, ti_id=self.id, max_retries=MAX_FAILED_HEARTBEATS, - exc_info=True, + exception=exc, ) # If we've failed to heartbeat too many times, kill the process if self.failed_heartbeats >= MAX_FAILED_HEARTBEATS: diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py index dafce2400c8e0..b9a4c14e31750 100644 --- a/task-sdk/tests/task_sdk/bases/test_operator.py +++ b/task-sdk/tests/task_sdk/bases/test_operator.py @@ -250,13 +250,13 @@ def get_weight(self, ti): ) def test_warnings_are_properly_propagated(self): - with pytest.warns(DeprecationWarning) as warnings: + with pytest.warns(DeprecationWarning, match="deprecated") as warnings: DeprecatedOperator(task_id="test") - assert len(warnings) == 1 - warning = warnings[0] - # Here we check that the trace points to the place - # where the deprecated class was used - assert warning.filename == __file__ + assert len(warnings) == 1 + warning = warnings[0] + # Here we check that the trace points to the place + # where the deprecated class was used + assert warning.filename == __file__ def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency): op = MockOperator(task_id="test_task") @@ -680,8 +680,8 @@ class StringTemplateFieldsOperator(BaseOperator): with pytest.warns(UserWarning, match=warning_message) as warnings: task = StringTemplateFieldsOperator(task_id="op1") - assert len(warnings) == 1 - assert isinstance(task.template_fields, list) + assert len(warnings) == 1 + assert isinstance(task.template_fields, list) def test_jinja_invalid_expression_is_just_propagated(self): """Test render_template propagates Jinja invalid expression errors.""" diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index fd70882e96ada..34992c73f8673 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -117,7 +117,7 @@ def test_uri_with_scheme(uri: str, normalized: str) -> None: def test_uri_with_auth() -> None: - with pytest.warns(UserWarning) as record: + with pytest.warns(UserWarning, match="username") as record: asset = Asset("ftp://user@localhost/foo.txt") assert len(record) == 1 assert str(record[0].message) == (