Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add name parameter to spawn #385

Merged
merged 13 commits into from
Jan 27, 2023
28 changes: 26 additions & 2 deletions aiojobs/_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@


class Job(Generic[_T]):
def __init__(self, coro: Coroutine[object, object, _T], scheduler: Scheduler):
def __init__(
self,
coro: Coroutine[object, object, _T],
scheduler: Scheduler,
name: Optional[str] = None,
):
self._coro = coro
self._scheduler: Optional[Scheduler] = scheduler
self._name = name
loop = asyncio.get_running_loop()
self._started = loop.create_future()

Expand Down Expand Up @@ -50,6 +56,21 @@ def pending(self) -> bool:
def closed(self) -> bool:
return self._closed

def get_name(self) -> Optional[str]:
"""Get the task name.

See https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.get_name.
Returns None if no name was set on the Job object and job has not yet started.
"""
if sys.version_info >= (3, 8) and self._task:
return self._task.get_name()
return self._name

def set_name(self, name: str) -> None:
Fixed Show fixed Hide fixed
self._name = name
if sys.version_info >= (3, 8) and self._task is not None:
self._task.set_name(name)

async def _do_wait(self, timeout: Optional[float]) -> _T:
async with async_timeout.timeout(timeout):
# TODO: add a test for waiting for a pending coro
Expand Down Expand Up @@ -118,7 +139,10 @@ async def _close(self, timeout: Optional[float]) -> None:

def _start(self) -> None:
assert self._task is None
self._task = asyncio.create_task(self._coro)
if sys.version_info >= (3, 8):
self._task = asyncio.create_task(self._coro, name=self._name)
else:
self._task = asyncio.create_task(self._coro)
self._task.add_done_callback(self._done_callback)
self._started.set_result(None)

Expand Down
6 changes: 4 additions & 2 deletions aiojobs/_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ def pending_count(self) -> int:
def closed(self) -> bool:
return self._closed

async def spawn(self, coro: Coroutine[object, object, _T]) -> Job[_T]:
async def spawn(
self, coro: Coroutine[object, object, _T], name: Optional[str] = None
) -> Job[_T]:
if self._closed:
raise RuntimeError("Scheduling a new job after closing")
job = Job(coro, self)
job = Job(coro, self, name=name)
should_start = self._limit is None or self.active_count < self._limit
if should_start:
job._start()
Expand Down
38 changes: 38 additions & 0 deletions tests/test_job.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import sys
from contextlib import suppress
from typing import Awaitable, Callable, NoReturn
from unittest import mock

import pytest

from aiojobs._job import Job
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
from aiojobs._scheduler import Scheduler

_MakeScheduler = Callable[..., Awaitable[Scheduler]]
Expand Down Expand Up @@ -287,3 +289,39 @@ async def coro() -> NoReturn:
await scheduler.spawn(coro())
await scheduler.close()
handler.assert_called_once()


async def test_get_job_name(scheduler: Scheduler) -> None:
async def coro() -> None:
"""Dummy function."""

job = await scheduler.spawn(coro(), name="test_job_name")
assert job.get_name() == "test_job_name"
if sys.version_info >= (3, 8):
assert job._task is not None
assert job._task.get_name() == "test_job_name"


async def test_get_default_job_name(scheduler: Scheduler) -> None:
async def coro() -> None:
"""Dummy function."""

job = await scheduler.spawn(coro())
if sys.version_info >= (3, 8):
job_name = job.get_name()
assert job_name is not None
assert job_name.startswith("Task-")
else:
assert job.get_name() is None


async def test_set_job_name(scheduler: Scheduler) -> None:
async def coro() -> None:
"""Dummy function."""

job = await scheduler.spawn(coro(), name="original_name")
job.set_name("changed_name")
assert job.get_name() == "changed_name"
if sys.version_info >= (3, 8):
assert job._task is not None
assert job._task.get_name() == "changed_name"