diff --git a/airflow-core/tests/unit/cli/commands/test_asset_command.py b/airflow-core/tests/unit/cli/commands/test_asset_command.py index 88f2131575115..0543a828aad2e 100644 --- a/airflow-core/tests/unit/cli/commands/test_asset_command.py +++ b/airflow-core/tests/unit/cli/commands/test_asset_command.py @@ -18,8 +18,6 @@ from __future__ import annotations -import contextlib -import io import json import os import typing @@ -56,34 +54,34 @@ def parser() -> ArgumentParser: return cli_parser.get_parser() -def test_cli_assets_list(parser: ArgumentParser) -> None: +def test_cli_assets_list(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_list(args) - asset_list = json.loads(temp_stdout.getvalue()) + asset_list = json.loads(capture.getvalue()) assert len(asset_list) > 0 assert set(asset_list[0]) == {"name", "uri", "group", "extra"} assert any(asset["uri"] == "s3://dag1/output_1.txt" for asset in asset_list), asset_list -def test_cli_assets_alias_list(parser: ArgumentParser) -> None: +def test_cli_assets_alias_list(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "list", "--alias", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_list(args) - alias_list = json.loads(temp_stdout.getvalue()) + alias_list = json.loads(capture.getvalue()) assert len(alias_list) > 0 assert set(alias_list[0]) == {"name", "group"} assert any(alias["name"] == "example-alias" for alias in alias_list), alias_list -def test_cli_assets_details(parser: ArgumentParser) -> None: +def test_cli_assets_details(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--name=asset1_producer", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_details(args) - asset_detail_list = json.loads(temp_stdout.getvalue()) + asset_detail_list = json.loads(capture.getvalue()) assert len(asset_detail_list) == 1 # No good way to statically compare these. @@ -106,12 +104,12 @@ def test_cli_assets_details(parser: ArgumentParser) -> None: } -def test_cli_assets_alias_details(parser: ArgumentParser) -> None: +def test_cli_assets_alias_details(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "details", "--alias", "--name=example-alias", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_details(args) - alias_detail_list = json.loads(temp_stdout.getvalue()) + alias_detail_list = json.loads(capture.getvalue()) assert len(alias_detail_list) == 1 # No good way to statically compare these. @@ -121,15 +119,17 @@ def test_cli_assets_alias_details(parser: ArgumentParser) -> None: @mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr") -def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser) -> None: +def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser, stdout_capture) -> None: mock_hasattr.return_value = False args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_materialize(args) - output = temp_stdout.getvalue() - # Skip the first line of `temp_stdout` since the current `DAGRunResponse` requires `DagBundlesManager`, which logs `INFO - DAG bundles loaded: dags-folder, example_dags`. - output = "\n".join(output.splitlines()[1:]) + output = capture.getvalue() + + # Check if output is empty first + assert output, "No output captured from asset_materialize command" + run_list = json.loads(output) assert len(run_list) == 1 @@ -162,12 +162,12 @@ def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser) -> None: } -def test_cli_assets_materialize_with_view_url_template(parser: ArgumentParser) -> None: +def test_cli_assets_materialize_with_view_url_template(parser: ArgumentParser, stdout_capture) -> None: args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"]) - with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: + with stdout_capture as capture: asset_command.asset_materialize(args) - output = temp_stdout.getvalue() + output = capture.getvalue() run_list = json.loads(output) assert len(run_list) == 1 diff --git a/airflow-core/tests/unit/cli/commands/test_cheat_sheet_command.py b/airflow-core/tests/unit/cli/commands/test_cheat_sheet_command.py index 34d5aa6fc0148..b00bb56059304 100644 --- a/airflow-core/tests/unit/cli/commands/test_cheat_sheet_command.py +++ b/airflow-core/tests/unit/cli/commands/test_cheat_sheet_command.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import contextlib -from io import StringIO from unittest import mock from airflow.cli import cli_parser @@ -94,8 +92,8 @@ def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.cli_parser.airflow_commands", MOCK_COMMANDS) - def test_should_display_index(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_should_display_index(self, stdout_capture): + with stdout_capture as temp_stdout: args = self.parser.parse_args(["cheat-sheet"]) args.func(args) output = temp_stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/commands/test_config_command.py b/airflow-core/tests/unit/cli/commands/test_config_command.py index bbbe233100f58..66a21c8b9244d 100644 --- a/airflow-core/tests/unit/cli/commands/test_config_command.py +++ b/airflow-core/tests/unit/cli/commands/test_config_command.py @@ -16,11 +16,9 @@ # under the License. from __future__ import annotations -import contextlib import os import re import shutil -from io import StringIO from unittest import mock import pytest @@ -75,22 +73,22 @@ def test_cli_show_config_should_write_data_specific_section(self, mock_conf, moc ) @conf_vars({("core", "testkey"): "test_value"}) - def test_cli_show_config_should_display_key(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_should_display_key(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config(self.parser.parse_args(["config", "list", "--color", "off"])) output = temp_stdout.getvalue() assert "[core]" in output assert "testkey = test_value" in temp_stdout.getvalue() - def test_cli_show_config_should_only_show_comments_when_no_defaults(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_should_only_show_comments_when_no_defaults(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config(self.parser.parse_args(["config", "list", "--color", "off"])) output = temp_stdout.getvalue() lines = output.splitlines() assert all(not line.startswith("#") or line.endswith("= ") for line in lines if line) - def test_cli_show_config_shows_descriptions(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_shows_descriptions(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--include-descriptions"]) ) @@ -102,8 +100,8 @@ def test_cli_show_config_shows_descriptions(self): assert all(not line.startswith("# Example:") for line in lines if line) assert all(not line.startswith("# Variable:") for line in lines if line) - def test_cli_show_config_shows_examples(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_shows_examples(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--include-examples"]) ) @@ -114,8 +112,8 @@ def test_cli_show_config_shows_examples(self): assert any(line.startswith("# Example:") for line in lines if line) assert all(not line.startswith("# Variable:") for line in lines if line) - def test_cli_show_config_shows_variables(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_shows_variables(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--include-env-vars"]) ) @@ -126,8 +124,8 @@ def test_cli_show_config_shows_variables(self): assert all(not line.startswith("# Example:") for line in lines if line) assert any(line.startswith("# Variable:") for line in lines if line) - def test_cli_show_config_shows_sources(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_shows_sources(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--include-sources"]) ) @@ -138,8 +136,8 @@ def test_cli_show_config_shows_sources(self): assert all(not line.startswith("# Example:") for line in lines if line) assert all(not line.startswith("# Variable:") for line in lines if line) - def test_cli_show_config_defaults(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_defaults(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--defaults"]) ) @@ -155,8 +153,8 @@ def test_cli_show_config_defaults(self): ) @conf_vars({("core", "hostname_callable"): "testfn"}) - def test_cli_show_config_defaults_not_show_conf_changes(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_defaults_not_show_conf_changes(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--defaults"]) ) @@ -167,8 +165,8 @@ def test_cli_show_config_defaults_not_show_conf_changes(self): ) @mock.patch("os.environ", {"AIRFLOW__CORE__HOSTNAME_CALLABLE": "test_env"}) - def test_cli_show_config_defaults_do_not_show_env_changes(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_config_defaults_do_not_show_env_changes(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--defaults"]) ) @@ -179,30 +177,30 @@ def test_cli_show_config_defaults_do_not_show_env_changes(self): ) @conf_vars({("core", "hostname_callable"): "testfn"}) - def test_cli_show_changed_defaults_when_overridden_in_conf(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_changed_defaults_when_overridden_in_conf(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config(self.parser.parse_args(["config", "list", "--color", "off"])) output = temp_stdout.getvalue() lines = output.splitlines() assert any(line.startswith("hostname_callable = testfn") for line in lines if line) @mock.patch("os.environ", {"AIRFLOW__CORE__HOSTNAME_CALLABLE": "test_env"}) - def test_cli_show_changed_defaults_when_overridden_in_env(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_show_changed_defaults_when_overridden_in_env(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config(self.parser.parse_args(["config", "list", "--color", "off"])) output = temp_stdout.getvalue() lines = output.splitlines() assert any(line.startswith("hostname_callable = test_env") for line in lines if line) - def test_cli_has_providers(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_has_providers(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config(self.parser.parse_args(["config", "list", "--color", "off"])) output = temp_stdout.getvalue() lines = output.splitlines() assert any(line.startswith("celery_config_options") for line in lines if line) - def test_cli_comment_out_everything(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_cli_comment_out_everything(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.show_config( self.parser.parse_args(["config", "list", "--color", "off", "--comment-out-everything"]) ) @@ -218,8 +216,8 @@ def setup_class(cls): cls.parser = cli_parser.get_parser() @conf_vars({("core", "test_key"): "test_value"}) - def test_should_display_value(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_should_display_value(self, stdout_capture): + with stdout_capture as temp_stdout: config_command.get_value(self.parser.parse_args(["config", "get-value", "core", "test_key"])) assert temp_stdout.getvalue().strip() == "test_value" @@ -246,9 +244,9 @@ class TestConfigLint: @pytest.mark.parametrize( "removed_config", [config for config in config_command.CONFIGS_CHANGES if config.was_removed] ) - def test_lint_detects_removed_configs(self, removed_config): + def test_lint_detects_removed_configs(self, removed_config, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -264,9 +262,9 @@ def test_lint_detects_removed_configs(self, removed_config): "default_changed_config", [config for config in config_command.CONFIGS_CHANGES if config.default_change], ) - def test_lint_detects_default_changed_configs(self, default_changed_config): + def test_lint_detects_default_changed_configs(self, default_changed_config, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -297,9 +295,9 @@ def test_lint_detects_default_changed_configs(self, default_changed_config): ), ], ) - def test_lint_with_specific_removed_configs(self, section, option, suggestion): + def test_lint_with_specific_removed_configs(self, section, option, suggestion, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -311,9 +309,9 @@ def test_lint_with_specific_removed_configs(self, section, option, suggestion): assert suggestion in normalized_output - def test_lint_specific_section_option(self): + def test_lint_specific_section_option(self, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config( cli_parser.get_parser().parse_args( ["config", "lint", "--section", "core", "--option", "check_slas"] @@ -329,9 +327,9 @@ def test_lint_specific_section_option(self): in normalized_output ) - def test_lint_with_invalid_section_option(self): + def test_lint_with_invalid_section_option(self, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=False): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config( cli_parser.get_parser().parse_args( ["config", "lint", "--section", "invalid_section", "--option", "invalid_option"] @@ -344,13 +342,13 @@ def test_lint_with_invalid_section_option(self): assert "No issues found in your airflow.cfg." in normalized_output - def test_lint_detects_multiple_issues(self): + def test_lint_detects_multiple_issues(self, stdout_capture): with mock.patch( "airflow.configuration.conf.has_option", side_effect=lambda section, option, lookup_from_deprecated: option in ["check_slas", "strict_dataset_uri_validation"], ): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -388,9 +386,9 @@ def test_lint_detects_multiple_issues(self): ], ], ) - def test_lint_detects_multiple_removed_configs(self, removed_configs): + def test_lint_detects_multiple_removed_configs(self, removed_configs, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -421,9 +419,9 @@ def test_lint_detects_multiple_removed_configs(self, removed_configs): ], ], ) - def test_lint_detects_renamed_configs(self, renamed_configs): + def test_lint_detects_renamed_configs(self, renamed_configs, stdout_capture): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -458,10 +456,16 @@ def test_lint_detects_renamed_configs(self, renamed_configs): ), ], ) - def test_lint_detects_configs_with_env_vars(self, env_var, config_change, expected_message): + def test_lint_detects_configs_with_env_vars( + self, + env_var, + config_change, + expected_message, + stdout_capture, + ): with mock.patch.dict(os.environ, {env_var: "some_value"}): with mock.patch("airflow.configuration.conf.has_option", return_value=True): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -471,9 +475,9 @@ def test_lint_detects_configs_with_env_vars(self, env_var, config_change, expect assert expected_message in normalized_output assert config_change.suggestion in normalized_output - def test_lint_detects_invalid_config(self): + def test_lint_detects_invalid_config(self, stdout_capture): with mock.patch.dict(os.environ, {"AIRFLOW__CORE__PARALLELISM": "0"}): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() @@ -485,9 +489,9 @@ def test_lint_detects_invalid_config(self): in normalized_output ) - def test_lint_detects_invalid_config_negative(self): + def test_lint_detects_invalid_config_negative(self, stdout_capture): with mock.patch.dict(os.environ, {"AIRFLOW__CORE__PARALLELISM": "42"}): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: config_command.lint_config(cli_parser.get_parser().parse_args(["config", "lint"])) output = temp_stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/commands/test_connection_command.py b/airflow-core/tests/unit/cli/commands/test_connection_command.py index 1ebb415aa0f82..6db7f3881226b 100644 --- a/airflow-core/tests/unit/cli/commands/test_connection_command.py +++ b/airflow-core/tests/unit/cli/commands/test_connection_command.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import io import json import os import re @@ -24,7 +23,6 @@ import warnings from contextlib import redirect_stdout from io import StringIO -from unittest import mock import pytest @@ -53,13 +51,12 @@ class TestCliGetConnection: def setup_method(self): clear_db_connections(add_default_connections_back=True) - def test_cli_connection_get(self): - with redirect_stdout(StringIO()) as stdout: + def test_cli_connection_get(self, stdout_capture): + with stdout_capture as capture: connection_command.connections_get( self.parser.parse_args(["connections", "get", "google_cloud_default", "--output", "json"]) ) - stdout = stdout.getvalue() - assert "google-cloud-platform:///default" in stdout + assert "google-cloud-platform:///default" in capture.getvalue() def test_cli_connection_get_invalid(self): with pytest.raises(SystemExit, match=re.escape("Connection not found.")): @@ -93,7 +90,6 @@ def test_cli_connections_list_as_json(self): connection_command.connections_list(args) print(stdout.getvalue()) stdout = stdout.getvalue() - for conn_id, conn_type in self.EXPECTED_CONS: assert conn_type in stdout assert conn_id in stdout @@ -102,9 +98,9 @@ def test_cli_connections_filter_conn_id(self): args = self.parser.parse_args( ["connections", "list", "--output", "json", "--conn-id", "http_default"] ) - with redirect_stdout(StringIO()) as stdout: + with redirect_stdout(StringIO()) as capture: connection_command.connections_list(args) - stdout = stdout.getvalue() + stdout = capture.getvalue() assert "http_default" in stdout @@ -149,39 +145,24 @@ def test_cli_connections_export_should_return_error_for_invalid_export_format(se with pytest.raises(SystemExit, match=r"Unsupported file format"): connection_command.connections_export(args) - @mock.patch.object(connection_command, "create_session") - def test_cli_connections_export_should_raise_error_if_create_session_fails( - self, mock_create_session, tmp_path - ): + def test_cli_connections_export_should_raise_error_if_create_session_fails(self, mocker, tmp_path): output_filepath = tmp_path / "connections.json" - - def my_side_effect(): - raise Exception("dummy exception") - - mock_create_session.side_effect = my_side_effect + mocker.patch.object(connection_command, "create_session", side_effect=Exception("dummy exception")) args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) with pytest.raises(Exception, match=r"dummy exception"): connection_command.connections_export(args) - @mock.patch.object(connection_command, "create_session") - def test_cli_connections_export_should_raise_error_if_fetching_connections_fails( - self, mock_session, tmp_path - ): + def test_cli_connections_export_should_raise_error_if_fetching_connections_fails(self, mocker, tmp_path): output_filepath = tmp_path / "connections.json" - - def my_side_effect(_): - raise Exception("dummy exception") - - mock_session.return_value.__enter__.return_value.scalars.side_effect = my_side_effect + mock_session = mocker.patch.object(connection_command, "create_session") + mock_session.return_value.__enter__.return_value.scalars.side_effect = Exception("dummy exception") args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) with pytest.raises(Exception, match=r"dummy exception"): connection_command.connections_export(args) - @mock.patch.object(connection_command, "create_session") - def test_cli_connections_export_should_not_raise_error_if_connections_is_empty( - self, mock_session, tmp_path - ): + def test_cli_connections_export_should_not_raise_error_if_connections_is_empty(self, mocker, tmp_path): output_filepath = tmp_path / "connections.json" + mock_session = mocker.patch.object(connection_command, "create_session") mock_session.return_value.__enter__.return_value.query.return_value.all.return_value = [] args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) connection_command.connections_export(args) @@ -335,7 +316,7 @@ def test_cli_connections_export_should_force_export_as_specified_format(self, tm assert json.loads(output_filepath.read_text()) == expected_connections def test_cli_connections_export_should_work_when_stdout_is_not_a_real_fd(self, tmp_path): - class FakeFileStringIO(io.StringIO): + class FakeFileStringIO(StringIO): """ Buffer the contents of a StringIO to make them accessible after close @@ -598,7 +579,7 @@ def setup_method(self): ], ) @pytest.mark.execution_timeout(120) - def test_cli_connection_add(self, cmd, expected_output, expected_conn, session): + def test_cli_connection_add(self, cmd, expected_output, expected_conn, session, stdout_capture): if "invalid-uri-test" in cmd: with pytest.raises(SystemExit) as exc_info: connection_command.connections_add(self.parser.parse_args(cmd)) @@ -606,7 +587,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn, session): assert str(exc_info.value) == expected_output return - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: connection_command.connections_add(self.parser.parse_args(cmd)) stdout = stdout.getvalue() @@ -711,7 +692,7 @@ class TestCliDeleteConnections: def setup_method(self): clear_db_connections(add_default_connections_back=False) - def test_cli_delete_connections(self, session): + def test_cli_delete_connections(self, session, stdout_capture): merge_conn( Connection( conn_id="new1", @@ -725,12 +706,11 @@ def test_cli_delete_connections(self, session): session=session, ) # Delete connections - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "new1"])) - stdout = stdout.getvalue() # Check deletion stdout - assert "Successfully deleted connection with `conn_id`=new1" in stdout + assert "Successfully deleted connection with `conn_id`=new1" in stdout.getvalue() # Check deletions result = session.query(Connection).filter(Connection.conn_id == "new1").first() @@ -749,19 +729,15 @@ class TestCliImportConnections: def setup_method(self): clear_db_connections(add_default_connections_back=False) - @mock.patch("os.path.exists") - def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists): - mock_exists.return_value = False + def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mocker): + mocker.patch("os.path.exists", return_value=False) filepath = "/does/not/exist.json" with pytest.raises(SystemExit, match=r"Missing connections file."): connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath])) @pytest.mark.parametrize("filepath", ["sample.jso", "sample.environ"]) - @mock.patch("os.path.exists") - def test_cli_connections_import_should_return_error_if_file_format_is_invalid( - self, mock_exists, filepath - ): - mock_exists.return_value = True + def test_cli_connections_import_should_return_error_if_file_format_is_invalid(self, filepath, mocker): + mocker.patch("os.path.exists", return_value=True) with pytest.raises( AirflowException, match=( @@ -771,11 +747,7 @@ def test_cli_connections_import_should_return_error_if_file_format_is_invalid( ): connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath])) - @mock.patch("airflow.secrets.local_filesystem._parse_secret_file") - @mock.patch("os.path.exists") - def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file): - mock_exists.return_value = True - + def test_cli_connections_import_should_load_connections(self, mocker): # Sample connections to import expected_connections = { "new0": { @@ -798,18 +770,39 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_ "schema": "airflow", "extra": '{"spam": "egg"}', }, + # Add new3 if the test expects an error about it + "new3": { + "conn_type": "sqlite", + "description": "new3 description", + "host": "host", + }, } - # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env - mock_parse_secret_file.return_value = expected_connections + # First, create new3 to trigger the "already exists" error + with create_session() as session: + session.add(Connection(conn_id="new3", conn_type="sqlite")) + session.commit() + + # We're not testing the behavior of _parse_secret_file + mocker.patch("airflow.secrets.local_filesystem._parse_secret_file", return_value=expected_connections) + mocker.patch("os.path.exists", return_value=True) + mock_print = mocker.patch("airflow.cli.commands.connection_command.print") connection_command.connections_import( self.parser.parse_args(["connections", "import", "sample.json"]) ) - # Verify that the imported connections match the expected, sample connections + # Check all print calls to find the error message + print_calls = [str(call) for call in mock_print.call_args_list] + assert any("Could not import connection new3" in call for call in print_calls), ( + f"Expected error message not found. Print calls: {print_calls}" + ) + + # Verify connections (exclude new3 since it should fail) + expected_imported = {k: v for k, v in expected_connections.items() if k != "new3"} + with create_session() as session: - current_conns = session.query(Connection).all() + current_conns = session.query(Connection).filter(Connection.conn_id.in_(["new0", "new1"])).all() comparable_attrs = [ "conn_id", @@ -827,15 +820,10 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_ current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs} for current_conn in current_conns } - assert expected_connections == current_conns_as_dicts - - @mock.patch("airflow.secrets.local_filesystem._parse_secret_file") - @mock.patch("os.path.exists") - def test_cli_connections_import_should_not_overwrite_existing_connections( - self, mock_exists, mock_parse_secret_file, session - ): - mock_exists.return_value = True + assert expected_imported == current_conns_as_dicts + def test_cli_connections_import_should_not_overwrite_existing_connections(self, session, mocker): + mocker.patch("os.path.exists", return_value=True) # Add a pre-existing connection "new3" merge_conn( Connection( @@ -875,14 +863,12 @@ def test_cli_connections_import_should_not_overwrite_existing_connections( } # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env - mock_parse_secret_file.return_value = expected_connections - - with redirect_stdout(StringIO()) as stdout: - connection_command.connections_import( - self.parser.parse_args(["connections", "import", "sample.json"]) - ) - - assert "Could not import connection new3: connection already exists." in stdout.getvalue() + mocker.patch("airflow.secrets.local_filesystem._parse_secret_file", return_value=expected_connections) + mock_print = mocker.patch("airflow.cli.commands.connection_command.print") + connection_command.connections_import( + self.parser.parse_args(["connections", "import", "sample.json"]) + ) + assert "Could not import connection new3: connection already exists." in mock_print.call_args[0][0] # Verify that the imported connections match the expected, sample connections current_conns = session.query(Connection).all() @@ -908,13 +894,33 @@ def test_cli_connections_import_should_not_overwrite_existing_connections( # The existing connection's description should not have changed assert current_conns_as_dicts["new3"]["description"] == "original description" - @mock.patch("airflow.secrets.local_filesystem._parse_secret_file") - @mock.patch("os.path.exists") - def test_cli_connections_import_should_overwrite_existing_connections( - self, mock_exists, mock_parse_secret_file, session - ): - mock_exists.return_value = True - + def test_cli_connections_import_should_overwrite_existing_connections(self, mocker, session): + mocker.patch("os.path.exists", return_value=True) + mocker.patch( + "airflow.secrets.local_filesystem._parse_secret_file", + return_value={ + "new2": { + "conn_type": "postgres", + "description": "new2 description", + "host": "host", + "login": "airflow", + "password": "password", + "port": 5432, + "schema": "airflow", + "extra": '{"foo": "bar"}', + }, + "new3": { + "conn_type": "mysql", + "description": "updated description", + "host": "host", + "login": "airflow", + "password": "new password", + "port": 3306, + "schema": "airflow", + "extra": '{"spam": "egg"}', + }, + }, + ) # Add a pre-existing connection "new3" merge_conn( Connection( @@ -928,44 +934,15 @@ def test_cli_connections_import_should_overwrite_existing_connections( ), session=session, ) - - # Sample connections to import, including a collision with "new3" - expected_connections = { - "new2": { - "conn_type": "postgres", - "description": "new2 description", - "host": "host", - "login": "airflow", - "password": "password", - "port": 5432, - "schema": "airflow", - "extra": '{"foo": "bar"}', - }, - "new3": { - "conn_type": "mysql", - "description": "updated description", - "host": "host", - "login": "airflow", - "password": "new password", - "port": 3306, - "schema": "airflow", - "extra": '{"spam": "egg"}', - }, - } - - # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env - mock_parse_secret_file.return_value = expected_connections - - with redirect_stdout(StringIO()) as stdout: - connection_command.connections_import( - self.parser.parse_args(["connections", "import", "sample.json", "--overwrite"]) - ) - - assert "Could not import connection new3: connection already exists." not in stdout.getvalue() - + mock_print = mocker.patch("airflow.cli.commands.connection_command.print") + connection_command.connections_import( + self.parser.parse_args(["connections", "import", "sample.json", "--overwrite"]) + ) + assert ( + "Could not import connection new3: connection already exists." not in mock_print.call_args[0][0] + ) # Verify that the imported connections match the expected, sample connections current_conns = session.query(Connection).all() - comparable_attrs = [ "conn_id", "conn_type", @@ -977,15 +954,33 @@ def test_cli_connections_import_should_overwrite_existing_connections( "schema", "extra", ] - current_conns_as_dicts = { current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs} for current_conn in current_conns } - assert current_conns_as_dicts["new2"] == expected_connections["new2"] - + assert current_conns_as_dicts["new2"] == { + "conn_type": "postgres", + "description": "new2 description", + "host": "host", + "login": "airflow", + "password": "password", + "port": 5432, + "schema": "airflow", + "extra": '{"foo": "bar"}', + "conn_id": "new2", + } # The existing connection should have been overwritten - assert current_conns_as_dicts["new3"] == expected_connections["new3"] + assert current_conns_as_dicts["new3"] == { + "conn_type": "mysql", + "description": "updated description", + "host": "host", + "login": "airflow", + "password": "new password", + "port": 3306, + "schema": "airflow", + "extra": '{"spam": "egg"}', + "conn_id": "new3", + } class TestCliTestConnections: @@ -994,38 +989,32 @@ class TestCliTestConnections: def setup_class(self): clear_db_connections() - @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) - @mock.patch("airflow.providers.http.hooks.http.HttpHook.test_connection") - def test_cli_connections_test_success(self, mock_test_conn): - """Check that successful connection test result is displayed properly.""" + def test_cli_connections_test_success(self, mocker, stdout_capture): + mocker.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + mock_test_conn = mocker.patch("airflow.providers.http.hooks.http.HttpHook.test_connection") conn_id = "http_default" mock_test_conn.return_value = True, None - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: connection_command.connections_test(self.parser.parse_args(["connections", "test", conn_id])) - assert "Connection success!" in stdout.getvalue() - @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) - @mock.patch("airflow.providers.http.hooks.http.HttpHook.test_connection") - def test_cli_connections_test_fail(self, mock_test_conn): - """Check that failed connection test result is displayed properly.""" + def test_cli_connections_test_fail(self, mocker, stdout_capture): + mocker.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + mock_test_conn = mocker.patch("airflow.providers.http.hooks.http.HttpHook.test_connection") conn_id = "http_default" mock_test_conn.return_value = False, "Failed." - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: connection_command.connections_test(self.parser.parse_args(["connections", "test", conn_id])) - assert "Connection failed!\nFailed.\n\n" in stdout.getvalue() - @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) - def test_cli_connections_test_missing_conn(self): - """Check a connection test on a non-existent connection raises a "Connection not found" message.""" - with redirect_stdout(StringIO()) as stdout, pytest.raises(SystemExit): + def test_cli_connections_test_missing_conn(self, mocker, stdout_capture): + mocker.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + with stdout_capture as stdout, pytest.raises(SystemExit): connection_command.connections_test(self.parser.parse_args(["connections", "test", "missing"])) assert "Connection not found.\n\n" in stdout.getvalue() - def test_cli_connections_test_disabled_by_default(self): - """Check that test connection functionality is disabled by default.""" - with redirect_stdout(StringIO()) as stdout, pytest.raises(SystemExit): + def test_cli_connections_test_disabled_by_default(self, stdout_capture): + with stdout_capture as stdout, pytest.raises(SystemExit): connection_command.connections_test(self.parser.parse_args(["connections", "test", "missing"])) assert ( "Testing connections is disabled in Airflow configuration. Contact your deployment admin to " @@ -1034,8 +1023,10 @@ def test_cli_connections_test_disabled_by_default(self): class TestCliCreateDefaultConnection: - @mock.patch("airflow.cli.commands.connection_command.db_create_default_connections") - def test_cli_create_default_connections(self, mock_db_create_default_connections): + def test_cli_create_default_connections(self, mocker): + mock_db_create_default_connections = mocker.patch( + "airflow.cli.commands.connection_command.db_create_default_connections" + ) create_default_connection_fnc = dict( (db_command.name, db_command.func) for db_command in cli_config.CONNECTIONS_COMMANDS )["create-default-connections"] diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py index e13cc3a410e73..a3420cdf4af7b 100644 --- a/airflow-core/tests/unit/cli/commands/test_dag_command.py +++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py @@ -18,12 +18,10 @@ from __future__ import annotations import argparse -import contextlib import json import logging import os from datetime import datetime, timedelta -from io import StringIO from unittest import mock from unittest.mock import MagicMock @@ -92,16 +90,16 @@ def setup_method(self): def teardown_method(self): clear_db_import_errors() - def test_show_dag_dependencies_print(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_show_dag_dependencies_print(self, stdout_capture): + with stdout_capture as temp_stdout: dag_command.dag_dependencies_show(self.parser.parse_args(["dags", "show-dependencies"])) out = temp_stdout.getvalue() assert "digraph" in out assert "graph [rankdir=LR]" in out @mock.patch("airflow.cli.commands.dag_command.render_dag_dependencies") - def test_show_dag_dependencies_save(self, mock_render_dag_dependencies): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_show_dag_dependencies_save(self, mock_render_dag_dependencies, stdout_capture): + with stdout_capture as temp_stdout: dag_command.dag_dependencies_show( self.parser.parse_args(["dags", "show-dependencies", "--save", "output.png"]) ) @@ -111,8 +109,8 @@ def test_show_dag_dependencies_save(self, mock_render_dag_dependencies): ) assert "File output.png saved" in out - def test_show_dag_print(self): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_show_dag_print(self, stdout_capture): + with stdout_capture as temp_stdout: dag_command.dag_show(self.parser.parse_args(["dags", "show", "example_bash_operator"])) out = temp_stdout.getvalue() assert "label=example_bash_operator" in out @@ -120,8 +118,8 @@ def test_show_dag_print(self): assert "runme_2 -> run_after_loop" in out @mock.patch("airflow.cli.commands.dag_command.render_dag") - def test_show_dag_save(self, mock_render_dag): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + def test_show_dag_save(self, mock_render_dag, stdout_capture): + with stdout_capture as temp_stdout: dag_command.dag_show( self.parser.parse_args(["dags", "show", "example_bash_operator", "--save", "awesome.png"]) ) @@ -133,13 +131,13 @@ def test_show_dag_save(self, mock_render_dag): @mock.patch("airflow.cli.commands.dag_command.subprocess.Popen") @mock.patch("airflow.cli.commands.dag_command.render_dag") - def test_show_dag_imgcat(self, mock_render_dag, mock_popen): + def test_show_dag_imgcat(self, mock_render_dag, mock_popen, stdout_capture): mock_render_dag.return_value.pipe.return_value = b"DOT_DATA" mock_proc = mock.MagicMock() mock_proc.returncode = 0 mock_proc.communicate.return_value = (b"OUT", b"ERR") mock_popen.return_value.__enter__.return_value = mock_proc - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_show( self.parser.parse_args(["dags", "show", "example_bash_operator", "--imgcat"]) ) @@ -149,7 +147,7 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen): assert "OUT" in out assert "ERR" in out - def test_next_execution(self, tmp_path): + def test_next_execution(self, tmp_path, stdout_capture): dag_test_list = [ ("future_schedule_daily", "timedelta(days=5)", "'0 0 * * *'", "True"), ("future_schedule_every_4_hours", "timedelta(days=5)", "timedelta(hours=4)", "True"), @@ -206,14 +204,14 @@ def test_next_execution(self, tmp_path): for dag_id in expected_output: # Test num-executions = 1 (default) args = self.parser.parse_args(["dags", "next-execution", dag_id]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_next_execution(args) out = temp_stdout.getvalue() assert expected_output[dag_id][0] in out # Test num-executions = 2 args = self.parser.parse_args(["dags", "next-execution", dag_id, "--num-executions", "2"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_next_execution(args) out = temp_stdout.getvalue() assert expected_output[dag_id][1] in out @@ -223,9 +221,9 @@ def test_next_execution(self, tmp_path): parse_and_sync_to_db(os.devnull, include_examples=True) @conf_vars({("core", "load_examples"): "true"}) - def test_cli_report(self): + def test_cli_report(self, stdout_capture): args = self.parser.parse_args(["dags", "report", "--output", "json"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_report(args) out = temp_stdout.getvalue() @@ -233,9 +231,9 @@ def test_cli_report(self): assert "example_complex" in out @conf_vars({("core", "load_examples"): "true"}) - def test_cli_get_dag_details(self): + def test_cli_get_dag_details(self, stdout_capture): args = self.parser.parse_args(["dags", "details", "example_complex", "--output", "yaml"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_details(args) out = temp_stdout.getvalue() @@ -250,22 +248,22 @@ def test_cli_get_dag_details(self): assert value in out @conf_vars({("core", "load_examples"): "true"}) - def test_cli_list_dags(self): + def test_cli_list_dags(self, stdout_capture): args = self.parser.parse_args(["dags", "list", "--output", "json"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_list_dags(args) out = temp_stdout.getvalue() dag_list = json.loads(out) - for key in ["dag_id", "fileloc", "owners", "is_paused"]: + for key in ["dag_id", "fileloc", "owners", "is_paused"]: # "bundle_name", "bundle_version"? assert key in dag_list[0] assert any("airflow/example_dags/example_complex.py" in d["fileloc"] for d in dag_list) @conf_vars({("core", "load_examples"): "true"}) - def test_cli_list_local_dags(self): + def test_cli_list_local_dags(self, stdout_capture): # Clear the database clear_db_dags() args = self.parser.parse_args(["dags", "list", "--output", "json", "--local"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_list_dags(args) out = temp_stdout.getvalue() dag_list = json.loads(out) @@ -276,7 +274,7 @@ def test_cli_list_local_dags(self): parse_and_sync_to_db(os.devnull, include_examples=True) @conf_vars({("core", "load_examples"): "false"}) - def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle): + def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle, stdout_capture): # Clear the database clear_db_dags() path_to_parse = TEST_DAGS_FOLDER / "test_example_bash_operator.py" @@ -284,7 +282,7 @@ def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle ["dags", "list", "--output", "json", "--local", "--bundle-name", "testing"] ) with configure_testing_dag_bundle(path_to_parse): - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_list_dags(args) out = temp_stdout.getvalue() dag_list = json.loads(out) @@ -297,11 +295,11 @@ def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle parse_and_sync_to_db(os.devnull, include_examples=True) @conf_vars({("core", "load_examples"): "true"}) - def test_cli_list_dags_custom_cols(self): + def test_cli_list_dags_custom_cols(self, stdout_capture): args = self.parser.parse_args( ["dags", "list", "--output", "json", "--columns", "dag_id,last_parsed_time"] ) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_list_dags(args) out = temp_stdout.getvalue() dag_list = json.loads(out) @@ -311,29 +309,33 @@ def test_cli_list_dags_custom_cols(self): assert key not in dag_list[0] @conf_vars({("core", "load_examples"): "true"}) - def test_cli_list_dags_invalid_cols(self): + def test_cli_list_dags_invalid_cols(self, stderr_capture): args = self.parser.parse_args(["dags", "list", "--output", "json", "--columns", "dag_id,invalid_col"]) - with contextlib.redirect_stderr(StringIO()) as temp_stderr: + with stderr_capture as temp_stderr: dag_command.dag_list_dags(args) out = temp_stderr.getvalue() assert "Ignoring the following invalid columns: ['invalid_col']" in out @conf_vars({("core", "load_examples"): "false"}) - def test_cli_list_dags_prints_import_errors(self, configure_testing_dag_bundle, get_test_dag): + def test_cli_list_dags_prints_import_errors( + self, configure_testing_dag_bundle, get_test_dag, stderr_capture + ): path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py" get_test_dag("test_invalid_cron") args = self.parser.parse_args(["dags", "list", "--output", "yaml", "--bundle-name", "testing"]) with configure_testing_dag_bundle(path_to_parse): - with contextlib.redirect_stderr(StringIO()) as temp_stderr: + with stderr_capture as temp_stderr: dag_command.dag_list_dags(args) out = temp_stderr.getvalue() assert "Failed to load all files." in out @conf_vars({("core", "load_examples"): "false"}) - def test_cli_list_dags_prints_local_import_errors(self, configure_testing_dag_bundle, get_test_dag): + def test_cli_list_dags_prints_local_import_errors( + self, configure_testing_dag_bundle, get_test_dag, stderr_capture + ): # Clear the database clear_db_dags() path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py" @@ -344,7 +346,7 @@ def test_cli_list_dags_prints_local_import_errors(self, configure_testing_dag_bu ) with configure_testing_dag_bundle(path_to_parse): - with contextlib.redirect_stderr(StringIO()) as temp_stderr: + with stderr_capture as temp_stderr: dag_command.dag_list_dags(args) out = temp_stderr.getvalue() @@ -354,10 +356,10 @@ def test_cli_list_dags_prints_local_import_errors(self, configure_testing_dag_bu @conf_vars({("core", "load_examples"): "true"}) @mock.patch("airflow.models.DagModel.get_dagmodel") - def test_list_dags_none_get_dagmodel(self, mock_get_dagmodel): + def test_list_dags_none_get_dagmodel(self, mock_get_dagmodel, stdout_capture): mock_get_dagmodel.return_value = None args = self.parser.parse_args(["dags", "list", "--output", "json"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_list_dags(args) out = temp_stdout.getvalue() dag_list = json.loads(out) @@ -468,33 +470,33 @@ def test_pause_regex_yes(self, mock_yesno): mock_yesno.assert_not_called() dag_command.dag_unpause(args) - def test_pause_non_existing_dag_do_not_error(self): + def test_pause_non_existing_dag_do_not_error(self, stdout_capture): args = self.parser.parse_args(["dags", "pause", "non_existing_dag"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_pause(args) - out = temp_stdout.getvalue().strip().splitlines()[-1] + out = temp_stdout.splitlines()[-1] assert out == "No unpaused DAGs were found" - def test_unpause_non_existing_dag_do_not_error(self): + def test_unpause_non_existing_dag_do_not_error(self, stdout_capture): args = self.parser.parse_args(["dags", "unpause", "non_existing_dag"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_unpause(args) - out = temp_stdout.getvalue().strip().splitlines()[-1] + out = temp_stdout.splitlines()[-1] assert out == "No paused DAGs were found" - def test_unpause_already_unpaused_dag_do_not_error(self): + def test_unpause_already_unpaused_dag_do_not_error(self, stdout_capture): args = self.parser.parse_args(["dags", "unpause", "example_bash_operator", "--yes"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_unpause(args) - out = temp_stdout.getvalue().strip().splitlines()[-1] + out = temp_stdout.splitlines()[-1] assert out == "No paused DAGs were found" - def test_pausing_already_paused_dag_do_not_error(self): + def test_pausing_already_paused_dag_do_not_error(self, stdout_capture): args = self.parser.parse_args(["dags", "pause", "example_bash_operator", "--yes"]) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_pause(args) dag_command.dag_pause(args) - out = temp_stdout.getvalue().strip().splitlines()[-1] + out = temp_stdout.splitlines()[-1] assert out == "No unpaused DAGs were found" def test_trigger_dag(self): @@ -560,7 +562,7 @@ def test_trigger_dag_invalid_conf(self): ), ) - def test_trigger_dag_output_as_json(self): + def test_trigger_dag_output_as_json(self, stdout_capture): args = self.parser.parse_args( [ "dags", @@ -573,7 +575,7 @@ def test_trigger_dag_output_as_json(self): "--output=json", ] ) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: dag_command.dag_trigger(args) # get the last line from the logs ignoring all logging lines out = temp_stdout.getvalue().strip().splitlines()[-1] @@ -731,13 +733,13 @@ def test_dag_test_conf(self, mock_get_dag): @mock.patch("airflow.cli.commands.dag_command.render_dag", return_value=MagicMock(source="SOURCE")) @mock.patch("airflow.cli.commands.dag_command.get_dag") - def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag): + def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag, stdout_capture): mock_get_dag.return_value.test.return_value.run_id = "__test_dag_test_show_dag_fake_dag_run_run_id__" cli_args = self.parser.parse_args( ["dags", "test", "example_bash_operator", DEFAULT_DATE.isoformat(), "--show-dagrun"] ) - with contextlib.redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: dag_command.dag_test(cli_args) output = stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/commands/test_info_command.py b/airflow-core/tests/unit/cli/commands/test_info_command.py index 1db61aa6346e2..6563b97979908 100644 --- a/airflow-core/tests/unit/cli/commands/test_info_command.py +++ b/airflow-core/tests/unit/cli/commands/test_info_command.py @@ -16,11 +16,9 @@ # under the License. from __future__ import annotations -import contextlib import importlib import logging import os -from io import StringIO from unittest import mock import httpx @@ -150,8 +148,8 @@ def test_tools_info(self): ("database", "sql_alchemy_conn"): "postgresql+psycopg2://postgres:airflow@postgres/airflow", } ) - def test_show_info(self): - with contextlib.redirect_stdout(StringIO()) as stdout: + def test_show_info(self, stdout_capture): + with stdout_capture as stdout: info_command.show_info(self.parser.parse_args(["info"])) output = stdout.getvalue() @@ -164,8 +162,8 @@ def test_show_info(self): ("database", "sql_alchemy_conn"): "postgresql+psycopg2://postgres:airflow@postgres/airflow", } ) - def test_show_info_anonymize(self): - with contextlib.redirect_stdout(StringIO()) as stdout: + def test_show_info_anonymize(self, stdout_capture): + with stdout_capture as stdout: info_command.show_info(self.parser.parse_args(["info", "--anonymize"])) output = stdout.getvalue() @@ -184,7 +182,7 @@ class TestInfoCommandMockHttpx: ("database", "sql_alchemy_conn"): "postgresql+psycopg2://postgres:airflow@postgres/airflow", } ) - def test_show_info_anonymize_fileio(self, setup_parser, cleanup_providers_manager): + def test_show_info_anonymize_fileio(self, setup_parser, cleanup_providers_manager, stdout_capture): with mock.patch("airflow.cli.commands.info_command.httpx.post") as post: post.return_value = httpx.Response( status_code=200, @@ -195,6 +193,6 @@ def test_show_info_anonymize_fileio(self, setup_parser, cleanup_providers_manage "expiry": "14 days", }, ) - with contextlib.redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: info_command.show_info(setup_parser.parse_args(["info", "--file-io", "--anonymize"])) assert "https://file.io/TEST" in stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/commands/test_jobs_command.py b/airflow-core/tests/unit/cli/commands/test_jobs_command.py index fa575e7917fd5..7219b8b45094d 100644 --- a/airflow-core/tests/unit/cli/commands/test_jobs_command.py +++ b/airflow-core/tests/unit/cli/commands/test_jobs_command.py @@ -16,9 +16,6 @@ # under the License. from __future__ import annotations -import contextlib -from io import StringIO - import pytest from airflow.cli import cli_parser @@ -45,7 +42,7 @@ def setup_method(self) -> None: def teardown_method(self) -> None: clear_db_jobs() - def test_should_report_success_for_one_working_scheduler(self): + def test_should_report_success_for_one_working_scheduler(self, stdout_capture): with create_session() as session: self.scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=self.scheduler_job) @@ -54,11 +51,11 @@ def test_should_report_success_for_one_working_scheduler(self): session.commit() self.scheduler_job.heartbeat(heartbeat_callback=self.job_runner.heartbeat_callback) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: jobs_command.check(self.parser.parse_args(["jobs", "check", "--job-type", "SchedulerJob"])) assert "Found one alive job." in temp_stdout.getvalue() - def test_should_report_success_for_one_working_scheduler_with_hostname(self): + def test_should_report_success_for_one_working_scheduler_with_hostname(self, stdout_capture): with create_session() as session: self.scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=self.scheduler_job) @@ -68,7 +65,7 @@ def test_should_report_success_for_one_working_scheduler_with_hostname(self): session.commit() self.scheduler_job.heartbeat(heartbeat_callback=self.job_runner.heartbeat_callback) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: jobs_command.check( self.parser.parse_args( ["jobs", "check", "--job-type", "SchedulerJob", "--hostname", "HOSTNAME"] @@ -76,7 +73,7 @@ def test_should_report_success_for_one_working_scheduler_with_hostname(self): ) assert "Found one alive job." in temp_stdout.getvalue() - def test_should_report_success_for_ha_schedulers(self): + def test_should_report_success_for_ha_schedulers(self, stdout_capture): scheduler_jobs = [] job_runners = [] with create_session() as session: @@ -89,7 +86,7 @@ def test_should_report_success_for_ha_schedulers(self): job_runners.append(job_runner) session.commit() scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback) - with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with stdout_capture as temp_stdout: jobs_command.check( self.parser.parse_args( ["jobs", "check", "--job-type", "SchedulerJob", "--limit", "100", "--allow-multiple"] diff --git a/airflow-core/tests/unit/cli/commands/test_legacy_commands.py b/airflow-core/tests/unit/cli/commands/test_legacy_commands.py index e4c39b9c68e55..e945296363696 100644 --- a/airflow-core/tests/unit/cli/commands/test_legacy_commands.py +++ b/airflow-core/tests/unit/cli/commands/test_legacy_commands.py @@ -16,9 +16,7 @@ # under the License. from __future__ import annotations -import contextlib from argparse import ArgumentError -from io import StringIO from unittest.mock import MagicMock import pytest @@ -33,8 +31,8 @@ class TestCliDeprecatedCommandsValue: def setup_class(cls): cls.parser = cli_parser.get_parser() - def test_should_display_value(self): - with pytest.raises(SystemExit) as ctx, contextlib.redirect_stderr(StringIO()) as temp_stderr: + def test_should_display_value(self, stderr_capture): + with pytest.raises(SystemExit) as ctx, stderr_capture as temp_stderr: config_command.get_value(self.parser.parse_args(["webserver"])) assert ctx.value.code == 2 diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py b/airflow-core/tests/unit/cli/commands/test_pool_command.py index 67cc39dec4cab..98248abd5a669 100644 --- a/airflow-core/tests/unit/cli/commands/test_pool_command.py +++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py @@ -18,8 +18,6 @@ from __future__ import annotations import json -from contextlib import redirect_stdout -from io import StringIO import pytest @@ -54,9 +52,9 @@ def _cleanup(session=None): add_default_pool_if_not_exists() session.close() - def test_pool_list(self): + def test_pool_list(self, stdout_capture): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: pool_command.pool_list(self.parser.parse_args(["pools", "list"])) assert "foo" in stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/commands/test_variable_command.py b/airflow-core/tests/unit/cli/commands/test_variable_command.py index bfb459d99ec15..053ffa5b48f3a 100644 --- a/airflow-core/tests/unit/cli/commands/test_variable_command.py +++ b/airflow-core/tests/unit/cli/commands/test_variable_command.py @@ -19,8 +19,6 @@ import json import os -from contextlib import redirect_stdout -from io import StringIO import pytest from sqlalchemy import select @@ -71,15 +69,15 @@ def test_variables_set_with_description(self): with pytest.raises(KeyError): Variable.get("foo1") - def test_variables_get(self): + def test_variables_get(self, stdout_capture): Variable.set("foo", {"foo": "bar"}, serialize_json=True) - with redirect_stdout(StringIO()) as stdout: + with stdout_capture as stdout: variable_command.variables_get(self.parser.parse_args(["variables", "get", "foo"])) assert stdout.getvalue() == '{\n "foo": "bar"\n}\n' - def test_get_variable_default_value(self): - with redirect_stdout(StringIO()) as stdout: + def test_get_variable_default_value(self, stdout_capture): + with stdout_capture as stdout: variable_command.variables_get( self.parser.parse_args(["variables", "get", "baz", "--default", "bar"]) ) @@ -192,7 +190,7 @@ def test_variables_isolation(self, tmp_path): assert path1.read_text() == path2.read_text() def test_variables_import_and_export_with_description(self, tmp_path): - """Test variables_import with file-description parameted""" + """Test variables_import with file-description parameter""" variables_types_file = tmp_path / "variables_types.json" variable_command.variables_set( self.parser.parse_args(["variables", "set", "foo", "bar", "--description", "Foo var description"]) diff --git a/airflow-core/tests/unit/cli/commands/test_version_command.py b/airflow-core/tests/unit/cli/commands/test_version_command.py index bcf2e303c1fd3..8c3ece02175e8 100644 --- a/airflow-core/tests/unit/cli/commands/test_version_command.py +++ b/airflow-core/tests/unit/cli/commands/test_version_command.py @@ -16,9 +16,6 @@ # under the License. from __future__ import annotations -from contextlib import redirect_stdout -from io import StringIO - import airflow.cli.commands.version_command from airflow.cli import cli_parser from airflow.version import version @@ -29,7 +26,7 @@ class TestCliVersion: def setup_class(cls): cls.parser = cli_parser.get_parser() - def test_cli_version(self): - with redirect_stdout(StringIO()) as stdout: + def test_cli_version(self, stdout_capture): + with stdout_capture as stdout: airflow.cli.commands.version_command.version(self.parser.parse_args(["version"])) assert version in stdout.getvalue() diff --git a/airflow-core/tests/unit/cli/conftest.py b/airflow-core/tests/unit/cli/conftest.py index 5c46e0aa01bb0..c7230d00e534c 100644 --- a/airflow-core/tests/unit/cli/conftest.py +++ b/airflow-core/tests/unit/cli/conftest.py @@ -27,6 +27,12 @@ from airflow.providers.cncf.kubernetes.executors import kubernetes_executor from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.stream_capture_manager import ( + CombinedCaptureManager, + StderrCaptureManager, + StdoutCaptureManager, + StreamCaptureManager, +) # Create custom executors here because conftest is imported first custom_executor_module = type(sys)("custom_executor") @@ -58,3 +64,31 @@ def parser(): from airflow.cli import cli_parser return cli_parser.get_parser() + + +@pytest.fixture +def stdout_capture(): + """Fixture that captures stdout only.""" + return StdoutCaptureManager() + + +@pytest.fixture +def stderr_capture(): + """Fixture that captures stderr only.""" + return StderrCaptureManager() + + +@pytest.fixture +def stream_capture(): + """Fixture that returns a configurable stream capture manager.""" + + def _capture(stdout=True, stderr=False): + return StreamCaptureManager(capture_stdout=stdout, capture_stderr=stderr) + + return _capture + + +@pytest.fixture +def combined_capture(): + """Fixture that captures both stdout and stderr.""" + return CombinedCaptureManager() diff --git a/devel-common/src/tests_common/test_utils/stream_capture_manager.py b/devel-common/src/tests_common/test_utils/stream_capture_manager.py new file mode 100644 index 0000000000000..74646ea2fa57a --- /dev/null +++ b/devel-common/src/tests_common/test_utils/stream_capture_manager.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +class StreamCaptureManager: + """Context manager class for capturing stdout and/or stderr while isolating from Logger output.""" + + def __init__(self, capture_stdout=True, capture_stderr=False): + from io import StringIO + + self.capture_stdout = capture_stdout + self.capture_stderr = capture_stderr + + self._stdout_buffer = StringIO() if capture_stdout else None + self._stderr_buffer = StringIO() if capture_stderr else None + + self.original_handlers = [] + self._stdout_final = "" + self._stderr_final = "" + self._in_context = False + + # Store original streams + self._original_stdout = None + self._original_stderr = None + + @property + def stdout(self) -> str: + """Get captured stdout content.""" + if self._in_context and self._stdout_buffer and not self._stdout_buffer.closed: + return self._stdout_buffer.getvalue() + return self._stdout_final + + @property + def stderr(self) -> str: + """Get captured stderr content.""" + if self._in_context and self._stderr_buffer and not self._stderr_buffer.closed: + return self._stderr_buffer.getvalue() + return self._stderr_final + + def getvalue(self) -> str: + """Get captured content. For backward compatibility, returns stdout by default.""" + return self.stdout if self.capture_stdout else self.stderr + + def get_combined(self) -> str: + """Get combined stdout and stderr content.""" + parts = [] + if self.capture_stdout: + parts.append(self.stdout) + if self.capture_stderr: + parts.append(self.stderr) + return "".join(parts) + + def splitlines(self) -> list[str]: + """Split captured content into lines.""" + content = self.getvalue() + if not content: + return [""] # Return list with empty string to avoid IndexError + return content.splitlines() + + def __enter__(self): + import logging + import sys + from contextlib import redirect_stderr, redirect_stdout + + self._in_context = True + + # Setup logging isolation + root_logger = logging.getLogger() + self.original_handlers = list(root_logger.handlers) + + # Remove stream handlers that would interfere with capture + handlers_to_remove = [] + for handler in self.original_handlers: + if isinstance(handler, logging.StreamHandler): + if self.capture_stdout and handler.stream == sys.stdout: + handlers_to_remove.append(handler) + elif self.capture_stderr and handler.stream == sys.stderr: + handlers_to_remove.append(handler) + + for handler in handlers_to_remove: + root_logger.removeHandler(handler) + + # Set up context managers for redirection + self._context_managers = [] + + if self.capture_stdout: + self._stdout_redirect = redirect_stdout(self._stdout_buffer) + self._context_managers.append(self._stdout_redirect) + + if self.capture_stderr: + self._stderr_redirect = redirect_stderr(self._stderr_buffer) + self._context_managers.append(self._stderr_redirect) + + # Enter all context managers + for cm in self._context_managers: + cm.__enter__() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + import logging + import sys + + self._in_context = False + + # Capture content BEFORE cleaning up + if self._stdout_buffer: + try: + self._stdout_final = self._stdout_buffer.getvalue() + except (ValueError, AttributeError): + self._stdout_final = "" + + if self._stderr_buffer: + try: + self._stderr_final = self._stderr_buffer.getvalue() + except (ValueError, AttributeError): + self._stderr_final = "" + + # Exit all context managers in reverse order + for cm in reversed(self._context_managers): + try: + cm.__exit__(exc_type, exc_val, exc_tb) + except Exception: + pass # Don't let cleanup failures mask the real error + + # Restore logging handlers + root_logger = logging.getLogger() + for handler in self.original_handlers: + if isinstance(handler, logging.StreamHandler): + if (self.capture_stdout and handler.stream == sys.stdout) or ( + self.capture_stderr and handler.stream == sys.stderr + ): + if handler not in root_logger.handlers: + root_logger.addHandler(handler) + + +# Convenience classes +class StdoutCaptureManager(StreamCaptureManager): + """Convenience class for stdout-only capture.""" + + def __init__(self): + super().__init__(capture_stdout=True, capture_stderr=False) + + +class StderrCaptureManager(StreamCaptureManager): + """Convenience class for stderr-only capture.""" + + def __init__(self): + super().__init__(capture_stdout=False, capture_stderr=True) + + +class CombinedCaptureManager(StreamCaptureManager): + """Convenience class for capturing both stdout and stderr.""" + + def __init__(self): + super().__init__(capture_stdout=True, capture_stderr=True) diff --git a/devel-common/tests/unit/tests_common/conftest.py b/devel-common/tests/unit/tests_common/conftest.py new file mode 100644 index 0000000000000..01fc944104d8e --- /dev/null +++ b/devel-common/tests/unit/tests_common/conftest.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from tests_common.test_utils.stream_capture_manager import ( + CombinedCaptureManager, + StderrCaptureManager, + StdoutCaptureManager, + StreamCaptureManager, +) + + +@pytest.fixture +def stdout_capture(): + """Fixture that captures stdout only.""" + return StdoutCaptureManager() + + +@pytest.fixture +def stderr_capture(): + """Fixture that captures stderr only.""" + return StderrCaptureManager() + + +@pytest.fixture +def stream_capture(): + """Fixture that returns a configurable stream capture manager.""" + + def _capture(stdout=True, stderr=False): + return StreamCaptureManager(capture_stdout=stdout, capture_stderr=stderr) + + return _capture + + +@pytest.fixture +def combined_capture(): + """Fixture that captures both stdout and stderr.""" + return CombinedCaptureManager() diff --git a/devel-common/tests/unit/tests_common/test_utils/test_stream_capture_manager.py b/devel-common/tests/unit/tests_common/test_utils/test_stream_capture_manager.py new file mode 100644 index 0000000000000..880734727bb68 --- /dev/null +++ b/devel-common/tests/unit/tests_common/test_utils/test_stream_capture_manager.py @@ -0,0 +1,477 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for the StreamCaptureManager class used in Airflow CLI tests.""" + +from __future__ import annotations + +import logging +import sys + +import pytest + + +def test_stdout_only(stdout_capture): + """Test capturing stdout only.""" + with stdout_capture as capture: + print("Hello stdout") + print("Error message", file=sys.stderr) + + # Access during context + assert "Hello stdout" in capture.getvalue() + assert "Error message" not in capture.getvalue() + + +def test_stderr_only(stderr_capture): + """Test capturing stderr only.""" + with stderr_capture as capture: + print("Hello stdout") + print("Error message", file=sys.stderr) + + assert "Error message" in capture.getvalue() + assert "Hello stdout" not in capture.getvalue() + + +def test_combined(combined_capture): + """Test capturing both streams.""" + with combined_capture as capture: + print("Hello stdout") + print("Error message", file=sys.stderr) + + assert "Hello stdout" in capture.stdout + assert "Error message" in capture.stderr + assert "Hello stdout" in capture.get_combined() + assert "Error message" in capture.get_combined() + + +def test_configurable(stream_capture): + """Test with configurable capture.""" + # Capture both + with stream_capture(stdout=True, stderr=True) as capture: + print("stdout message") + print("stderr message", file=sys.stderr) + + assert "stdout message" in capture.stdout + assert "stderr message" in capture.stderr + + # Capture stderr only + with stream_capture(stdout=False, stderr=True) as capture: + print("stdout message") + print("stderr message", file=sys.stderr) + + assert capture.stdout == "" + assert "stderr message" in capture.stderr + + +# ============== Tests for Logging Isolation ============== + + +def test_stdout_logging_isolation(stdout_capture): + """Test that logging to stdout is isolated from captured output.""" + # Set up a logger that writes to stdout + logger = logging.getLogger("test_stdout_logger") + logger.setLevel(logging.INFO) + + # Create handler that writes to stdout + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s")) + logger.addHandler(stdout_handler) + + try: + with stdout_capture as capture: + # Regular print should be captured + print("Regular print to stdout") + + # Logging should NOT be captured (isolated) + logger.info("This is a log message to stdout") + logger.warning("This is a warning to stdout") + + # Another regular print + print("Another regular print") + + output = capture.getvalue() + + # Regular prints should be in output + assert "Regular print to stdout" in output + assert "Another regular print" in output + + # Log messages should NOT be in output + assert "This is a log message to stdout" not in output + assert "This is a warning to stdout" not in output + assert "INFO" not in output + assert "WARNING" not in output + finally: + # Clean up + logger.removeHandler(stdout_handler) + + +def test_stderr_logging_isolation(stderr_capture): + """Test that logging to stderr is isolated from captured output.""" + # Set up a logger that writes to stderr + logger = logging.getLogger("test_stderr_logger") + logger.setLevel(logging.INFO) + + # Create handler that writes to stderr + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s")) + logger.addHandler(stderr_handler) + + try: + with stderr_capture as capture: + # Regular print to stderr should be captured + print("Regular print to stderr", file=sys.stderr) + + # Logging should NOT be captured (isolated) + logger.error("This is an error log to stderr") + logger.critical("This is a critical log to stderr") + + # Another regular print + print("Another stderr print", file=sys.stderr) + + output = capture.getvalue() + + # Regular prints should be in output + assert "Regular print to stderr" in output + assert "Another stderr print" in output + + # Log messages should NOT be in output + assert "This is an error log to stderr" not in output + assert "This is a critical log to stderr" not in output + assert "ERROR" not in output + assert "CRITICAL" not in output + finally: + # Clean up + logger.removeHandler(stderr_handler) + + +def test_combined_logging_isolation(combined_capture): + """Test that logging is isolated when capturing both stdout and stderr.""" + # Set up loggers for both streams + stdout_logger = logging.getLogger("test_combined_stdout") + stderr_logger = logging.getLogger("test_combined_stderr") + + stdout_logger.setLevel(logging.INFO) + stderr_logger.setLevel(logging.INFO) + + stdout_handler = logging.StreamHandler(sys.stdout) + stderr_handler = logging.StreamHandler(sys.stderr) + + stdout_handler.setFormatter(logging.Formatter("[STDOUT LOG] %(message)s")) + stderr_handler.setFormatter(logging.Formatter("[STDERR LOG] %(message)s")) + + stdout_logger.addHandler(stdout_handler) + stderr_logger.addHandler(stderr_handler) + + try: + with combined_capture as capture: + # Regular prints + print("Regular stdout") + print("Regular stderr", file=sys.stderr) + + # Logging (should be isolated) + stdout_logger.info("Log to stdout") + stderr_logger.error("Log to stderr") + + # Check stdout + assert "Regular stdout" in capture.stdout + assert "Log to stdout" not in capture.stdout + assert "[STDOUT LOG]" not in capture.stdout + + # Check stderr + assert "Regular stderr" in capture.stderr + assert "Log to stderr" not in capture.stderr + assert "[STDERR LOG]" not in capture.stderr + + # Combined should have regular prints but not logs + combined = capture.get_combined() + assert "Regular stdout" in combined + assert "Regular stderr" in combined + assert "Log to stdout" not in combined + assert "Log to stderr" not in combined + finally: + # Clean up + stdout_logger.removeHandler(stdout_handler) + stderr_logger.removeHandler(stderr_handler) + + +def test_root_logger_isolation(stdout_capture): + """Test that root logger messages are isolated from captured output.""" + # Configure root logger to output to stdout + root_logger = logging.getLogger() + original_level = root_logger.level + root_logger.setLevel(logging.DEBUG) + + # Add a stdout handler to root logger + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) + root_logger.addHandler(handler) + + try: + with stdout_capture as capture: + # Regular print + print("Before logging") + + # Various log levels using root logger explicitly + root_logger.debug("Debug message") + root_logger.info("Info message") + root_logger.warning("Warning message") + root_logger.error("Error message") + + # Regular print + print("After logging") + + output = capture.getvalue() + + # Only regular prints should be captured + assert "Before logging" in output + assert "After logging" in output + + # No log messages should be captured + assert "Debug message" not in output + assert "Info message" not in output + assert "Warning message" not in output + assert "Error message" not in output + + # No log formatting should be present + assert "DEBUG" not in output + assert "INFO" not in output + assert "WARNING" not in output + assert "ERROR" not in output + assert "%(asctime)s" not in output + finally: + # Clean up + root_logger.removeHandler(handler) + root_logger.setLevel(original_level) + + +def test_mixed_output_ordering(stdout_capture): + """Test that the order of regular prints is preserved when logging is mixed in.""" + logger = logging.getLogger("test_ordering") + logger.setLevel(logging.INFO) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("LOG: %(message)s")) + logger.addHandler(handler) + + try: + with stdout_capture as capture: + print("1. First print") + logger.info("Should not appear 1") + print("2. Second print") + logger.info("Should not appear 2") + print("3. Third print") + + output = capture.getvalue() + lines = output.strip().split("\n") + finally: + logger.removeHandler(handler) + + # Should have exactly 3 lines (only the prints) + assert len(lines) == 3 + assert lines[0] == "1. First print" + assert lines[1] == "2. Second print" + assert lines[2] == "3. Third print" + + # No log messages + assert "Should not appear" not in output + assert "LOG:" not in output + + +def test_handler_restoration(stdout_capture): + """Test that logging handlers are properly restored after capture.""" + root_logger = logging.getLogger() + + # Add a test handler to root logger to ensure we have something to test + test_root_handler = logging.StreamHandler(sys.stdout) + test_root_handler.setFormatter(logging.Formatter("ROOT: %(message)s")) + root_logger.addHandler(test_root_handler) + + # Also create a non-root logger with its own handler + logger = logging.getLogger("test_restoration") + logger.setLevel(logging.INFO) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("TEST: %(message)s")) + logger.addHandler(handler) + + try: + # Record initial state + initial_root_handlers = list(root_logger.handlers) + assert test_root_handler in initial_root_handlers + + # Use the capture context + with stdout_capture: + print("Inside capture") + logger.info("Log inside capture") + + # During capture, our test handler should be removed from root + current_root_handlers = list(root_logger.handlers) + assert test_root_handler not in current_root_handlers, ( + "Test handler should be removed during capture" + ) + + # The non-root logger's handler should still exist + assert handler in logger.handlers + + # After capture, root logger handlers should be restored + final_root_handlers = list(root_logger.handlers) + assert test_root_handler in final_root_handlers, "Test handler should be restored after capture" + assert len(final_root_handlers) == len(initial_root_handlers), ( + f"Handler count mismatch. Initial: {len(initial_root_handlers)}, Final: {len(final_root_handlers)}" + ) + + finally: + # Clean up + root_logger.removeHandler(test_root_handler) + logger.removeHandler(handler) + + +def test_multiple_loggers_isolation(stream_capture): + """Test isolation works with multiple loggers writing to different streams.""" + # Create multiple loggers + app_logger = logging.getLogger("app") + db_logger = logging.getLogger("database") + api_logger = logging.getLogger("api") + + # Set up handlers + app_handler = logging.StreamHandler(sys.stdout) + db_handler = logging.StreamHandler(sys.stderr) + api_handler = logging.StreamHandler(sys.stdout) + + app_handler.setFormatter(logging.Formatter("[APP] %(message)s")) + db_handler.setFormatter(logging.Formatter("[DB] %(message)s")) + api_handler.setFormatter(logging.Formatter("[API] %(message)s")) + + app_logger.addHandler(app_handler) + db_logger.addHandler(db_handler) + api_logger.addHandler(api_handler) + + app_logger.setLevel(logging.INFO) + db_logger.setLevel(logging.INFO) + api_logger.setLevel(logging.INFO) + + try: + with stream_capture(stdout=True, stderr=True) as capture: + # Regular prints + print("Starting application") + print("Database connection error", file=sys.stderr) + + # Logs (should be isolated) + app_logger.info("App initialized") + db_logger.error("Connection failed") + api_logger.info("API server started") + + # More regular prints + print("Application ready") + print("Check error log", file=sys.stderr) + + # Verify stdout + assert "Starting application" in capture.stdout + assert "Application ready" in capture.stdout + assert "[APP]" not in capture.stdout + assert "[API]" not in capture.stdout + assert "App initialized" not in capture.stdout + assert "API server started" not in capture.stdout + + # Verify stderr + assert "Database connection error" in capture.stderr + assert "Check error log" in capture.stderr + assert "[DB]" not in capture.stderr + assert "Connection failed" not in capture.stderr + finally: + # Clean up + app_logger.removeHandler(app_handler) + db_logger.removeHandler(db_handler) + api_logger.removeHandler(api_handler) + + +def test_exception_during_capture_preserves_isolation(stdout_capture): + """Test that logging isolation is maintained even when an exception occurs.""" + logger = logging.getLogger("test_exception") + logger.setLevel(logging.INFO) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("LOG: %(message)s")) + logger.addHandler(handler) + + # Test setup and isolation check before exception + captured_output = None + + try: + with stdout_capture as capture: + print("Before exception") + logger.info("Log before exception") + + # Check isolation before exception + captured_output = capture.getvalue() + assert "Before exception" in captured_output + assert "Log before exception" not in captured_output + + # This will raise an exception + raise ValueError("Test exception") + except ValueError: + # Expected exception + pass + + # Verify captured output was correct + assert captured_output is not None + assert "Before exception" in captured_output + assert "Log before exception" not in captured_output + + # Verify handlers are restored after exception + root_logger = logging.getLogger() + # Just verify no crash when accessing handlers + restored_handlers = list(root_logger.handlers) + assert isinstance(restored_handlers, list) + + # Clean up + logger.removeHandler(handler) + + +def test_exception_during_capture_with_pytest_raises(stdout_capture): + """Test exception handling with proper pytest.raises usage.""" + logger = logging.getLogger("test_exception_pytest") + logger.setLevel(logging.INFO) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("LOG: %(message)s")) + logger.addHandler(handler) + + try: + # Capture and verify before the exception + with stdout_capture as capture: + print("Before exception") + logger.info("Log before exception") + + # Verify isolation works + output = capture.getvalue() + assert "Before exception" in output + assert "Log before exception" not in output + + # Now test that exception in capture context still works + with pytest.raises(ValueError): + with stdout_capture: + raise ValueError("Test exception") + + # Verify we can still use capture after exception + with stdout_capture as capture: + print("After exception test") + output = capture.getvalue() + assert "After exception test" in output + + finally: + logger.removeHandler(handler) diff --git a/providers/standard/src/airflow/providers/standard/triggers/temporal.py b/providers/standard/src/airflow/providers/standard/triggers/temporal.py index 5f7451a0e84f9..599a219dbee39 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/temporal.py +++ b/providers/standard/src/airflow/providers/standard/triggers/temporal.py @@ -24,7 +24,11 @@ import pendulum from airflow.triggers.base import BaseTrigger, TaskSuccessEvent, TriggerEvent -from airflow.utils import timezone + +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] class DateTimeTrigger(BaseTrigger): diff --git a/pyproject.toml b/pyproject.toml index 2a34bd7197111..5090965fa8937 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -748,6 +748,7 @@ testing = ["dev", "providers.tests", "tests_common", "tests", "system", "unit", "dev/perf/*" = ["TID253"] "dev/check_files.py" = ["S101"] "dev/breeze/tests/*" = ["TID253", "S101", "TRY002"] +"devel-common/tests/*" = ["S101"] "airflow-core/tests/*" = ["D", "TID253", "S101", "TRY002"] "docker-tests/*" = ["D", "TID253", "S101", "TRY002"] "task-sdk-tests/*" = ["D", "TID253", "S101", "TRY002"]