diff --git a/src/clusterfuzz/_internal/base/tasks/__init__.py b/src/clusterfuzz/_internal/base/tasks/__init__.py index 72cff5ea41..ed3a3b988c 100644 --- a/src/clusterfuzz/_internal/base/tasks/__init__.py +++ b/src/clusterfuzz/_internal/base/tasks/__init__.py @@ -585,11 +585,52 @@ def lease(self, _event=None): # pylint: disable=arguments-differ track_task_end() +def _filter_task_for_os_mismatch(message, queue) -> bool: + """Filters a Pub/Sub message if its OS version does not match the bot's OS. + + This function checks the `base_os_version` attribute in the incoming message + against the bot's `BASE_OS_VERSION` environment variable. This handles cases + where a message is misrouted or received from a legacy subscription without + OS-specific filters. + + If an OS version mismatch is detected, the function logs a warning and + acknowledges (`ack()`) the message. Acknowledging the message permanently + removes it from the current subscription, effectively skipping it for this + bot. This assumes the message was also correctly delivered to another, + properly filtered subscription for processing. + + Args: + message: The `pubsub.Message` object to check. + queue: The name of the queue from which the message was pulled. + + Returns: + True if the message had a mismatch and was acknowledged; False otherwise. + """ + base_os_version = environment.get_value('BASE_OS_VERSION') + message_base_os_version = message.attributes.get('base_os_version') + + if not (message_base_os_version and base_os_version and + message_base_os_version != base_os_version): + return False + + logs.warning( + 'Skipping task for different OS.', + queue=queue, + message_os_version=message_base_os_version, + base_os_version=base_os_version) + message.ack() + return True + + def get_task_from_message(message, queue=None, can_defer=True, task_cls=None) -> Optional[PubSubTask]: """Returns a task constructed from the first of |messages| if possible.""" if message is None: return None + + if _filter_task_for_os_mismatch(message, queue): + return None + try: task = initialize_task(message, task_cls=task_cls) if task is None: diff --git a/src/clusterfuzz/_internal/tests/core/base/tasks/tasks_test.py b/src/clusterfuzz/_internal/tests/core/base/tasks/tasks_test.py index 59c95271db..85427ec19b 100644 --- a/src/clusterfuzz/_internal/tests/core/base/tasks/tasks_test.py +++ b/src/clusterfuzz/_internal/tests/core/base/tasks/tasks_test.py @@ -255,48 +255,96 @@ def test_get_machine_template_for_high_end_linux_queue(self): class GetTaskFromMessageTest(unittest.TestCase): """Tests for get_task_from_message.""" + def setUp(self): + self.mock_message = mock.MagicMock() + self.mock_task = mock.Mock(defer=mock.Mock(return_value=False)) + self.mock_task.set_queue.return_value = self.mock_task + + self.initialize_task_patcher = mock.patch( + 'clusterfuzz._internal.base.tasks.initialize_task', + return_value=self.mock_task) + self.mock_initialize_task = self.initialize_task_patcher.start() + + self.env_patcher = mock.patch( + 'clusterfuzz._internal.system.environment.get_value') + self.mock_env_get = self.env_patcher.start() + self.mock_env_get.return_value = None + + def tearDown(self): + self.initialize_task_patcher.stop() + self.env_patcher.stop() + def test_no_message(self): - self.assertEqual(tasks.get_task_from_message(None), None) + """Test that no task is returned when the message is None.""" + self.assertIsNone(tasks.get_task_from_message(None)) def test_success(self): - mock_task = mock.Mock(defer=mock.Mock(return_value=False)) - mock_task.set_queue.return_value = mock_task - with mock.patch( - 'clusterfuzz._internal.base.tasks.initialize_task', - return_value=mock_task): - self.assertEqual(tasks.get_task_from_message(mock.Mock()), mock_task) + """Test successful task creation from a message.""" + self.assertEqual( + tasks.get_task_from_message(self.mock_message), self.mock_task) def test_key_error(self): - mock_message = mock.Mock() - with mock.patch( - 'clusterfuzz._internal.base.tasks.initialize_task', - side_effect=KeyError): - self.assertEqual(tasks.get_task_from_message(mock_message), None) - mock_message.ack.assert_called_with() + """Test that a message is acked and skipped on a KeyError.""" + self.mock_initialize_task.side_effect = KeyError + self.assertIsNone(tasks.get_task_from_message(self.mock_message)) + self.mock_message.ack.assert_called_with() def test_defer(self): - mock_task = mock.Mock(defer=mock.Mock(return_value=True)) - with mock.patch( - 'clusterfuzz._internal.base.tasks.initialize_task', - return_value=mock_task): - self.assertEqual(tasks.get_task_from_message(mock.Mock()), None) + """Test that a task is deferred if its ETA is in the future.""" + self.mock_task.defer.return_value = True + self.assertIsNone(tasks.get_task_from_message(self.mock_message)) def test_set_queue(self): """Tests the set_queue method of a task.""" mock_queue = mock.Mock() - mock_task = mock.Mock() + task = tasks.get_task_from_message(self.mock_message, queue=mock_queue) + task.set_queue.assert_called_with(mock_queue) - mock_task.configure_mock( - queue=mock_queue, - set_queue=mock.Mock(return_value=mock_task), - defer=mock.Mock(return_value=False)) + @mock.patch('clusterfuzz._internal.metrics.logs.warning') + def test_os_mismatch(self, mock_log_warning): + """Test that a message is skipped and acked if OS versions mismatch.""" + self.mock_env_get.return_value = 'ubuntu-24-04' + self.mock_message.attributes = {'base_os_version': 'ubuntu-22-04'} - with mock.patch( - 'clusterfuzz._internal.base.tasks.initialize_task', - return_value=mock_task): - task = tasks.get_task_from_message(mock.Mock()) + result = tasks.get_task_from_message(self.mock_message) + + self.assertIsNone(result) + self.mock_message.ack.assert_called_once() + mock_log_warning.assert_called_with( + 'Skipping task for different OS.', + queue=None, + message_os_version='ubuntu-22-04', + base_os_version='ubuntu-24-04') + + def test_os_match(self): + """Test that a message is processed if OS versions match.""" + self.mock_env_get.return_value = 'ubuntu-24-04' + self.mock_message.attributes = {'base_os_version': 'ubuntu-24-04'} + + result = tasks.get_task_from_message(self.mock_message) + + self.assertEqual(result, self.mock_task) + self.mock_message.ack.assert_not_called() + + def test_bot_has_os_message_does_not(self): + """Test that a message is processed if the bot has an OS but the message does not.""" + self.mock_env_get.return_value = 'ubuntu-24-04' + self.mock_message.attributes = {} + + result = tasks.get_task_from_message(self.mock_message) + + self.assertEqual(result, self.mock_task) + self.mock_message.ack.assert_not_called() + + def test_bot_has_no_os_message_does(self): + """Test that a message is processed if the message has an OS but the bot does not.""" + self.mock_env_get.return_value = None + self.mock_message.attributes = {'base_os_version': 'ubuntu-24-04'} + + result = tasks.get_task_from_message(self.mock_message) - self.assertEqual(task.queue, mock_queue) + self.assertEqual(result, self.mock_task) + self.mock_message.ack.assert_not_called() @test_utils.with_cloud_emulators('datastore')