From 9aefc275cc216fba652522ff974aa0ba98bb1f76 Mon Sep 17 00:00:00 2001 From: LuYi Date: Wed, 11 Feb 2026 04:33:32 -0500 Subject: [PATCH] Fix over_rollout (#500) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 弈路 --- tests/explorer/scheduler_test.py | 15 ++++++++++++--- trinity/explorer/scheduler.py | 28 ++++++++++++++++++---------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 9eb996166a..7b475c8b01 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -390,7 +390,7 @@ async def test_get_results(self): statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) end_time = time.time() - self.assertLessEqual(end_time - start_time, 5) + self.assertLessEqual(end_time - start_time, 15) # sync wait for runner restart self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 2) @@ -420,7 +420,7 @@ async def test_get_results(self): _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) self.assertEqual(len(exps), 0) - # test clear_timeout_tasks + # test _cleanup_batch_and_restart_runners: part I, no clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) statuses, exps = await scheduler.get_results( @@ -433,7 +433,16 @@ async def test_get_results(self): ) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) - _, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1) + # test _cleanup_batch_and_restart_runners: part II, clear + tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=3) + statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 3) + statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + self.assertEqual(len(statuses), 0) + self.assertEqual(len(exps), 0) + _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) self.assertEqual(len(exps), 0) await scheduler.stop() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index e2ac18e12a..e6d7bb46e4 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -518,6 +518,17 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float: avg_time_per_task * self.config.explorer.dynamic_timeout.ratio, ) + async def _cleanup_batch_and_restart_runners(self, batch_id: Union[int, str]) -> None: + """Clear timeout tasks for a batch and restart associated runners.""" + self._clear_timeout_tasks(batch_id=batch_id) + runners_to_restart = [ + runner_id + for runner_id, task in list(self.busy_runners.items()) + if task.batch_id == batch_id + ] + if runners_to_restart: + await asyncio.gather(*[self._restart_runner(rid) for rid in runners_to_restart]) + async def get_results( self, batch_id: Union[int, str], @@ -550,11 +561,15 @@ async def get_results( completed_count = len(self.completed_tasks.get(batch_id, [])) if completed_count >= min_num: min_threshold_reached_time = min_threshold_reached_time or time.time() - if (completed_count >= scheduled_num) or ( + if completed_count >= scheduled_num: + break + if ( time.time() - min_threshold_reached_time >= self.config.explorer.over_rollout.wait_after_min ): - break + if clear_timeout_tasks: + await self._cleanup_batch_and_restart_runners(batch_id) + break await asyncio.sleep(0.1) if time.time() - start_time > timeout: @@ -562,14 +577,7 @@ async def get_results( f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds" ) if clear_timeout_tasks: - self._clear_timeout_tasks(batch_id=batch_id) - runners_to_restart = [] - for runner_id, task in list(self.busy_runners.items()): - if task.batch_id == batch_id: - runners_to_restart.append(runner_id) - asyncio.gather( - *[self._restart_runner(runner_id) for runner_id in runners_to_restart] - ) + await self._cleanup_batch_and_restart_runners(batch_id) statuses = [] experiences = []