diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py index 461f25884ac0d..b3d0365c087f7 100644 --- a/tests/cli/commands/test_celery_command.py +++ b/tests/cli/commands/test_celery_command.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os import unittest from argparse import Namespace from tempfile import NamedTemporaryFile @@ -143,15 +142,22 @@ def test_same_pid_file_is_used_in_start_and_stop( celery_command.stop_worker(stop_args) mock_read_pid_from_pidfile.assert_called_once_with(pid_file) + @mock.patch("airflow.cli.commands.celery_command.remove_existing_pidfile") @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile") @mock.patch("airflow.cli.commands.celery_command.worker_bin.worker") @mock.patch("airflow.cli.commands.celery_command.psutil.Process") + @mock.patch("airflow.cli.commands.celery_command.setup_locations") @conf_vars({("core", "executor"): "CeleryExecutor"}) def test_custom_pid_file_is_used_in_start_and_stop( - self, mock_celery_worker, mock_read_pid_from_pidfile, mock_process + self, + mock_setup_locations, + mock_process, + mock_celery_worker, + mock_read_pid_from_pidfile, + mock_remove_existing_pidfile, ): pid_file = "custom_test_pid_file" - + mock_setup_locations.return_value = (pid_file, None, None, None) # Call worker worker_args = self.parser.parse_args(['celery', 'worker', '--skip-serve-logs', '--pid', pid_file]) celery_command.worker(worker_args) @@ -161,25 +167,18 @@ def test_custom_pid_file_is_used_in_start_and_stop( assert 'pidfile' in kwargs assert kwargs['pidfile'] == pid_file assert not args - assert os.path.exists(pid_file) - - with open(pid_file) as pid_fd: - pid = "".join(pid_fd.readlines()) - - # Call stop - stop_args = self.parser.parse_args(['celery', 'stop', '--pid', pid_file]) - celery_command.stop_worker(stop_args) - run_mock = mock_celery_worker.return_value.run - assert run_mock.call_args - args, kwargs = run_mock.call_args - assert 'pidfile' in kwargs - assert kwargs['pidfile'] == pid_file - assert not args - - mock_read_pid_from_pidfile.assert_called_once_with(pid_file) - mock_process.assert_called_once_with(int(pid)) - mock_process.return_value.terminate.assert_called_once_with() - assert not os.path.exists(pid_file) + stop_args = self.parser.parse_args(['celery', 'stop', '--pid', pid_file]) + celery_command.stop_worker(stop_args) + run_mock = mock_celery_worker.return_value.run + assert run_mock.call_args + args, kwargs = run_mock.call_args + assert 'pidfile' in kwargs + assert kwargs['pidfile'] == pid_file + assert not args + + mock_read_pid_from_pidfile.assert_called_once_with(pid_file) + mock_process.return_value.terminate.assert_called() + mock_remove_existing_pidfile.assert_called_once_with(pid_file) @pytest.mark.backend("mysql", "postgres")