Skip to content

Commit

Permalink
Improving naming and event handling for process tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcschrg committed Mar 11, 2024
1 parent 9da486b commit f642e1c
Showing 1 changed file with 52 additions and 20 deletions.
72 changes: 52 additions & 20 deletions mango/util/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,33 @@
import concurrent.futures
import datetime
from abc import abstractmethod
from multiprocessing import Manager
from multiprocessing import Manager, Event
from typing import Any, List, Tuple
from dataclasses import dataclass
from multiprocessing.synchronize import Event as MultiprocessingEvent

from dateutil.rrule import rrule

from mango.util.clock import AsyncioClock, Clock, ExternalClock
from asyncio import Future


@dataclass
class ScheduledProcessControl:
run_task_event: MultiprocessingEvent
kill_process_event: MultiprocessingEvent

def kill_process(self):
self.kill_process_event.set()

def init_process(self):
self.kill_process_event.clear()

def resume_task(self):
self.run_task_event.set()

def suspend_task(self):
self.run_task_event.clear()


class Suspendable:
Expand Down Expand Up @@ -437,7 +458,9 @@ def __init__(
Tuple[ScheduledTask, asyncio.Future, Suspendable, Any]
] = []
self.clock = clock if clock is not None else AsyncioClock()
self._scheduled_process_tasks = []
self._scheduled_process_tasks: List[
Tuple[ScheduledProcessTask, Future, ScheduledProcessControl, Any]
] = []
self._process_pool_exec = concurrent.futures.ProcessPoolExecutor(
max_workers=num_process_parallel, initializer=_create_asyncio_context
)
Expand All @@ -446,10 +469,14 @@ def __init__(
self._observable = observable

@staticmethod
def _run_task_in_p_context(task, suspend_event, kill_event):
def _run_task_in_p_context(
task, scheduled_process_control: ScheduledProcessControl
):
try:
coro = Suspendable(
task.run(), ext_contr_event=suspend_event, kill_event=kill_event
task.run(),
ext_contr_event=scheduled_process_control.run_task_event,
kill_event=scheduled_process_control.kill_process_event,
)

return asyncio.get_event_loop().run_until_complete(coro)
Expand Down Expand Up @@ -643,22 +670,27 @@ def schedule_process_task(self, task: ScheduledProcessTask, src=None):
loop = asyncio.get_running_loop()
if self._manager is None:
self._manager = Manager()
event = self._manager.Event()
kill_event = self._manager.Event()
kill_event.clear()
event.set()

scheduled_process_control = ScheduledProcessControl(
run_task_event=self._manager.Event(),
kill_process_event=self._manager.Event(),
)
scheduled_process_control.init_process()
scheduled_process_control.resume_task()

l_task = asyncio.ensure_future(
loop.run_in_executor(
self._process_pool_exec,
Scheduler._run_task_in_p_context,
task,
event,
kill_event,
scheduled_process_control,
)
)
l_task.add_done_callback(self._remove_process_task)
l_task.add_done_callback(task.on_stop)
self._scheduled_process_tasks.append((task, l_task, (event, kill_event), src))
self._scheduled_process_tasks.append(
(task, l_task, scheduled_process_control, src)
)
return l_task

def schedule_timestamp_process_task(
Expand Down Expand Up @@ -786,9 +818,9 @@ def suspend(self, given_src):
for _, _, coro, src in self._scheduled_tasks:
if src == given_src and coro is not None:
coro.suspend()
for _, _, event, src in self._scheduled_process_tasks:
for _, _, scheduled_process_control, src in self._scheduled_process_tasks:
if src == given_src:
event[0].clear()
scheduled_process_control.suspend_task()

def resume(self, given_src):
"""Resume a set of tasks triggered by the given src object.
Expand All @@ -802,16 +834,16 @@ def resume(self, given_src):
for _, _, coro, src in self._scheduled_tasks:
if src == given_src and coro is not None:
coro.resume()
for _, _, event, src in self._scheduled_process_tasks:
for _, _, scheduled_process_control, src in self._scheduled_process_tasks:
if src == given_src:
event[0].set()
scheduled_process_control.resume_task()

def _remove_process_task(self, fut=asyncio.Future):
for i in range(len(self._scheduled_process_tasks)):
_, task, event, _ = self._scheduled_process_tasks[i]
_, task, scheduled_process_control, _ = self._scheduled_process_tasks[i]
if task == fut:
event[0].set()
event[1].set()
scheduled_process_control.resume_task()
scheduled_process_control.kill_process()
del self._scheduled_process_tasks[i]
break

Expand Down Expand Up @@ -876,8 +908,8 @@ async def shutdown(self):
Shutdown internal process executor pool.
"""
# resume all process so they can get shutdown
for _, _, event, _ in self._scheduled_process_tasks:
event[1].set()
for _, _, scheduled_process_control, _ in self._scheduled_process_tasks:
scheduled_process_control.kill_process()
for task, _, _, _ in self._scheduled_tasks:
task.close()
await self.stop()
Expand Down

0 comments on commit f642e1c

Please sign in to comment.