Skip to content

Commit

Permalink
Add support for eager tasks
Browse files Browse the repository at this point in the history
python 3.12 supports eager tasks

reading:
https://docs.python.org/3/library/asyncio-task.html#eager-task-factory
python/cpython#97696

There are lots of places were we are unlikely to suspend, but we might
suspend so creating a task makes sense
  • Loading branch information
bdraco committed Feb 25, 2024
1 parent b3e1019 commit 364babc
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 9 deletions.
12 changes: 9 additions & 3 deletions homeassistant/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
import contextlib
from datetime import timedelta
import logging
Expand Down Expand Up @@ -51,6 +50,7 @@
async_set_domains_to_be_loaded,
async_setup_component,
)
from .util.async_ import create_eager_task
from .util.logging import async_activate_log_queue_handler
from .util.package import async_get_user_site, is_virtual_env

Expand Down Expand Up @@ -665,7 +665,7 @@ async def _async_resolve_domains_to_setup(
to_get = old_to_resolve

manifest_deps: set[str] = set()
resolve_dependencies_tasks: list[Coroutine[Any, Any, bool]] = []
resolve_dependencies_tasks: list[asyncio.Task[bool]] = []
integrations_to_process: list[loader.Integration] = []

for domain, itg in (await loader.async_get_integrations(hass, to_get)).items():
Expand All @@ -677,7 +677,13 @@ async def _async_resolve_domains_to_setup(
manifest_deps.update(itg.after_dependencies)
needed_requirements.update(itg.requirements)
if not itg.all_dependencies_resolved:
resolve_dependencies_tasks.append(itg.resolve_dependencies())
resolve_dependencies_tasks.append(
create_eager_task(
itg.resolve_dependencies(),
name=f"resolve dependencies {domain}",
loop=hass.loop,
)
)

if unseen_deps := manifest_deps - integration_cache.keys():
# If there are dependencies, try to preload all
Expand Down
26 changes: 24 additions & 2 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .loader import async_suggest_report_issue
from .setup import DATA_SETUP_DONE, async_process_deps_reqs, async_setup_component
from .util import uuid as uuid_util
from .util.async_ import create_eager_task
from .util.decorator import Registry

if TYPE_CHECKING:
Expand Down Expand Up @@ -929,6 +930,27 @@ def async_create_task(

return task

@callback
def async_create_eager_task(
self,
hass: HomeAssistant,
target: Coroutine[Any, Any, _R],
name: str | None = None,
) -> asyncio.Task[_R]:
"""Create an eager task from within the event loop.
This method must be run in the event loop.
target: target to call.
"""
task = hass.async_create_eager_task(
target, f"{name} {self.title} {self.domain} {self.entry_id}"
)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

@callback
def async_create_background_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str
Expand Down Expand Up @@ -1682,7 +1704,7 @@ async def async_forward_entry_setups(
"""Forward the setup of an entry to platforms."""
await asyncio.gather(
*(
asyncio.create_task(
create_eager_task(
self.async_forward_entry_setup(entry, platform),
name=f"config entry forward setup {entry.title} {entry.domain} {entry.entry_id} {platform}",
)
Expand Down Expand Up @@ -1718,7 +1740,7 @@ async def async_unload_platforms(
return all(
await asyncio.gather(
*(
asyncio.create_task(
create_eager_task(
self.async_forward_entry_unload(entry, platform),
name=f"config entry forward unload {entry.title} {entry.domain} {entry.entry_id} {platform}",
)
Expand Down
17 changes: 17 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from .util import dt as dt_util, location
from .util.async_ import (
cancelling,
create_eager_task,
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
Expand Down Expand Up @@ -636,6 +637,22 @@ def async_create_task(
task.add_done_callback(self._tasks.remove)
return task

@callback
def async_create_eager_task(
self, target: Coroutine[Any, Any, _R], name: str | None = None
) -> asyncio.Task[_R]:
"""Create an eager task from within the event loop.
This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.
target: target to call.
"""
task = create_eager_task(target, name=name, loop=self.loop)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
return task

@callback
def async_create_background_task(
self,
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _async_schedule_add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async."""
task = self.hass.async_create_task(
task = self.hass.async_create_eager_task(
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}",
)
Expand All @@ -480,7 +480,7 @@ def _async_schedule_add_entities_for_entry(
) -> None:
"""Schedule adding entities for a single platform async and track the task."""
assert self.config_entry
task = self.config_entry.async_create_task(
task = self.config_entry.async_create_eager_task(
self.hass,
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities_for_entry {self.domain}.{self.platform_name}",
Expand Down
34 changes: 32 additions & 2 deletions homeassistant/util/async_.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Asyncio utilities."""
from __future__ import annotations

from asyncio import Future, Semaphore, gather, get_running_loop
from asyncio.events import AbstractEventLoop
from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop
from collections.abc import Awaitable, Callable
import concurrent.futures
from contextlib import suppress
import functools
import logging
import sys
import threading
from traceback import extract_stack
from typing import Any, ParamSpec, TypeVar, TypeVarTuple
Expand All @@ -23,6 +23,36 @@
_P = ParamSpec("_P")
_Ts = TypeVarTuple("_Ts")

if sys.version_info >= (3, 12, 0):

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
eager_start=True, # type: ignore[call-arg]
)
else:

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
)


def cancelling(task: Future[Any]) -> bool:
"""Return True if task is cancelling."""
Expand Down

0 comments on commit 364babc

Please sign in to comment.