diff --git a/locust/dispatch.py b/locust/dispatch.py index afde008c8a..e1b8a16da5 100644 --- a/locust/dispatch.py +++ b/locust/dispatch.py @@ -311,16 +311,28 @@ def _add_users_on_workers(self) -> dict[str, dict[str, int]]: return self._users_on_workers def _remove_users_from_workers(self) -> dict[str, dict[str, int]]: - """Remove users from the workers until the target number of users is reached for the current dispatch iteration + """Remove users from the workers until the target number of users is reached for the current dispatch iteration. :return: The users that we want to run on the workers """ current_user_count_target = max( self._current_user_count - self._user_count_per_dispatch_iteration, self._target_user_count ) + + # These are the user classes that are valid for removal. + user_class_names = [user_class.__name__ for user_class in self._user_classes] + while True: + # Iterate right - left over _active users and pop user that matches class we want to remove + index_to_pop = None + for i in range(len(self._active_users) - 1, -1, -1): + if self._active_users[i][1] in user_class_names: + index_to_pop = i + break + if index_to_pop is None: + return self._users_on_workers try: - worker_node, user = self._active_users.pop() + worker_node, user = self._active_users.pop(index_to_pop) except IndexError: return self._users_on_workers self._users_on_workers[worker_node.id][user] -= 1 diff --git a/locust/test/test_dispatch.py b/locust/test/test_dispatch.py index 81fb948e02..2d85598111 100644 --- a/locust/test/test_dispatch.py +++ b/locust/test/test_dispatch.py @@ -4168,3 +4168,48 @@ def _user_count(d: dict[str, dict[str, int]]) -> int: def _user_count_on_worker(d: dict[str, dict[str, int]], worker_node_id: str) -> int: return sum(d[worker_node_id].values()) + + +class TestSpawnDespwSpecificUserClasses(unittest.TestCase): + def test_add_then_remove_specific_user_classes(self): + """ + Adds multiple users classes and then individualy removes them, checking only specified user classes are removed. + """ + + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + user_classes = [User1, User2, User3] + worker_node1 = WorkerNode("1") + + sleep_time = 0.2 # Speed-up test + + users_dispatcher = UsersDispatcher(worker_nodes=[worker_node1], user_classes=user_classes) + + # Add equal spread of Users 1, 2, 3 + users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=9) + users_dispatcher._wait_between_dispatch = sleep_time + + self.assertDictEqual( + next(users_dispatcher), + { + "1": {"User1": 3, "User2": 3, "User3": 3}, + }, + ) + + # Now remove All instances of User2 + users_dispatcher.new_dispatch(target_user_count=6, spawn_rate=3, user_classes=[User2]) + + self.assertDictEqual( + next(users_dispatcher), + { + "1": {"User1": 3, "User2": 0, "User3": 3}, + }, + ) + self.assertRaises(StopIteration, lambda: next(users_dispatcher))