Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions tests/explorer/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ async def test_get_results(self):
statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3)
end_time = time.time()

self.assertLessEqual(end_time - start_time, 5)
self.assertLessEqual(end_time - start_time, 15) # sync wait for runner restart
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 2)

Expand Down Expand Up @@ -420,7 +420,7 @@ async def test_get_results(self):
_, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1)
self.assertEqual(len(exps), 0)

# test clear_timeout_tasks
# test _cleanup_batch_and_restart_runners: part I, no clear
tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3)
scheduler.schedule(tasks, batch_id=2)
statuses, exps = await scheduler.get_results(
Expand All @@ -433,7 +433,16 @@ async def test_get_results(self):
)
self.assertEqual(len(statuses), 1)
self.assertEqual(len(exps), 1)
_, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1)
# test _cleanup_batch_and_restart_runners: part II, clear
tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3)
scheduler.schedule(tasks, batch_id=3)
statuses, exps = await scheduler.get_results(batch_id=3, timeout=2)
self.assertEqual(len(statuses), 3)
self.assertEqual(len(exps), 3)
statuses, exps = await scheduler.get_results(batch_id=3, timeout=2)
self.assertEqual(len(statuses), 0)
self.assertEqual(len(exps), 0)
_, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1)
self.assertEqual(len(exps), 0)

await scheduler.stop()
Expand Down
28 changes: 18 additions & 10 deletions trinity/explorer/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,17 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float:
avg_time_per_task * self.config.explorer.dynamic_timeout.ratio,
)

async def _cleanup_batch_and_restart_runners(self, batch_id: Union[int, str]) -> None:
    """Clear a batch's timeout tasks and restart the runners still busy with it.

    Calls ``self._clear_timeout_tasks`` for ``batch_id``, then looks through
    ``self.busy_runners`` for entries whose task carries the same
    ``batch_id`` and awaits ``self._restart_runner`` for each of them.

    Args:
        batch_id: Identifier of the batch to clean up.

    Returns:
        None. The coroutine completes only after every matching runner's
        restart call has been awaited.
    """
    self._clear_timeout_tasks(batch_id=batch_id)

    # Iterate over a snapshot of the mapping rather than the live dict view.
    stale_runner_ids = []
    for runner_id, task in list(self.busy_runners.items()):
        if task.batch_id == batch_id:
            stale_runner_ids.append(runner_id)

    # Nothing is bound to this batch any more, so there is nothing to restart.
    if not stale_runner_ids:
        return

    # Build all restart awaitables first, then wait on them together.
    restarts = [self._restart_runner(runner_id) for runner_id in stale_runner_ids]
    await asyncio.gather(*restarts)

async def get_results(
self,
batch_id: Union[int, str],
Expand Down Expand Up @@ -550,26 +561,23 @@ async def get_results(
completed_count = len(self.completed_tasks.get(batch_id, []))
if completed_count >= min_num:
min_threshold_reached_time = min_threshold_reached_time or time.time()
if (completed_count >= scheduled_num) or (
if completed_count >= scheduled_num:
break
if (
time.time() - min_threshold_reached_time
>= self.config.explorer.over_rollout.wait_after_min
):
break
if clear_timeout_tasks:
await self._cleanup_batch_and_restart_runners(batch_id)
break
await asyncio.sleep(0.1)

if time.time() - start_time > timeout:
self.logger.error(
f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds"
)
if clear_timeout_tasks:
self._clear_timeout_tasks(batch_id=batch_id)
runners_to_restart = []
for runner_id, task in list(self.busy_runners.items()):
if task.batch_id == batch_id:
runners_to_restart.append(runner_id)
asyncio.gather(
*[self._restart_runner(runner_id) for runner_id in runners_to_restart]
)
await self._cleanup_batch_and_restart_runners(batch_id)

statuses = []
experiences = []
Expand Down