Skip to content

Commit

Permalink
Merge pull request #90 from Miksus/fix/max_process_count
Browse files Browse the repository at this point in the history
FIX: Async/threaded tasks limit max_process_count
  • Loading branch information
Miksus authored Sep 3, 2022
2 parents 35de0a7 + b2b0ff2 commit 724ba4b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
5 changes: 4 additions & 1 deletion rocketry/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ async def startup(self):
def has_free_processors(self) -> bool:
"""Whether the Scheduler has free processors to
allocate more tasks."""
return self.n_alive <= self.session.config.max_process_count
return self.count_process_tasks_alive() < self.session.config.max_process_count

def count_process_tasks_alive(self):
return sum(task.is_alive_as_process() for task in self.tasks)

@property
def n_alive(self) -> int:
Expand Down
10 changes: 9 additions & 1 deletion rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class Config:
_thread_error: Exception = PrivateAttr(default=None)
_lock: Optional[threading.Lock] = PrivateAttr(default_factory=threading.Lock)
_async_task: Optional[asyncio.Task] = PrivateAttr(default=None)
_main_alive: bool = PrivateAttr(default=False)

_mark_running = False

Expand Down Expand Up @@ -374,6 +375,7 @@ async def start_async(self, params:Union[dict, Parameters]=None, **kwargs):
self.log_running()
async_task = asyncio.create_task(self._run_as_async(params=params, direct_params=direct_params, execution=execution, **kwargs))
if execution == "main":
self._main_alive = True
await async_task
else:
self._async_task = async_task
Expand Down Expand Up @@ -410,6 +412,9 @@ async def start_async(self, params:Union[dict, Parameters]=None, **kwargs):
self.log_running()
self.log_failure()
raise TaskSetupError("Task failed before logging") from exc
finally:
# Clean up
self._main_alive = False

def __bool__(self):
return self.is_runnable()
Expand Down Expand Up @@ -808,7 +813,10 @@ def get_default_name(self, **kwargs):

def is_alive(self) -> bool:
"""Whether the task is alive: check if the task has a live process or thread."""
return self.is_alive_as_async() or self.is_alive_as_thread() or self.is_alive_as_process()
return self.is_alive_as_main() or self.is_alive_as_async() or self.is_alive_as_thread() or self.is_alive_as_process()

def is_alive_as_main(self) -> bool:
return self._main_alive

def is_alive_as_async(self) -> bool:
return self._async_task is not None and not self._async_task.done()
Expand Down
54 changes: 54 additions & 0 deletions rocketry/test/schedule/process/test_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@

import asyncio
import multiprocessing
import time
from rocketry.args.builtin import TerminationFlag
from rocketry.conditions.scheduler import SchedulerCycles

from rocketry.core import Scheduler
from rocketry.tasks import FuncTask
from rocketry.time import TimeDelta
from rocketry.conds import true
from rocketry.conditions import SchedulerStarted, TaskStarted, AlwaysTrue

def run_succeeding():
pass

def run_succeeding_slow():
time.sleep(20)

def run_creating_child():

proc = multiprocessing.Process(target=run_succeeding, daemon=True)
Expand All @@ -27,3 +35,49 @@ def test_creating_child(session):
assert 1 == logger.filter_by(action="run").count()
assert 1 == logger.filter_by(action="success").count()
assert 0 == logger.filter_by(action="fail").count()

def test_limited_processes(session):

def run_thread(flag=TerminationFlag()):
while not flag.is_set():
...

async def run_async():
while True:
await asyncio.sleep(0)

def do_post_check():
sched = session.scheduler
assert sched.n_alive == 5 # 2 processes, 1 thread, 1 async and this
assert not sched.has_free_processors()

assert task_threaded.is_alive()
assert task_threaded.is_running
assert task_async.is_alive()
assert task_async.is_running

assert task1.is_alive()
assert task2.is_alive()
assert not task3.is_alive()

assert task1.is_running
assert task2.is_running
assert not task3.is_running

task_threaded = FuncTask(run_thread, name="threaded", priority=4, start_cond=true, execution="thread", permanent=True, session=session)
task_async = FuncTask(run_async, name="async", priority=4, start_cond=true, execution="async", permanent=True, session=session)
post_check = FuncTask(do_post_check, name="post_check", on_shutdown=True, execution="main", session=session)

task1 = FuncTask(run_succeeding_slow, name="task_1", priority=3, start_cond=true, execution="process", session=session)
task2 = FuncTask(run_succeeding_slow, name="task_2", priority=2, start_cond=true, execution="process", session=session)
task3 = FuncTask(run_succeeding_slow, name="task_3", priority=1, start_cond=true, execution="process", session=session)

session.config.max_process_count = 2
session.config.instant_shutdown = True
session.config.shut_cond = SchedulerCycles() >= 3

session.start()

outcome = post_check.logger.filter_by().all()[-1]
assert outcome.action == "success", outcome.exc_text

0 comments on commit 724ba4b

Please sign in to comment.