Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for eager tasks #111425

Merged
merged 9 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions homeassistant/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
import contextlib
from datetime import timedelta
import logging
Expand Down Expand Up @@ -51,6 +50,7 @@
async_set_domains_to_be_loaded,
async_setup_component,
)
from .util.async_ import create_eager_task
from .util.logging import async_activate_log_queue_handler
from .util.package import async_get_user_site, is_virtual_env

Expand Down Expand Up @@ -665,7 +665,7 @@ async def _async_resolve_domains_to_setup(
to_get = old_to_resolve

manifest_deps: set[str] = set()
resolve_dependencies_tasks: list[Coroutine[Any, Any, bool]] = []
resolve_dependencies_tasks: list[asyncio.Task[bool]] = []
integrations_to_process: list[loader.Integration] = []

for domain, itg in (await loader.async_get_integrations(hass, to_get)).items():
Expand All @@ -677,7 +677,13 @@ async def _async_resolve_domains_to_setup(
manifest_deps.update(itg.after_dependencies)
needed_requirements.update(itg.requirements)
if not itg.all_dependencies_resolved:
resolve_dependencies_tasks.append(itg.resolve_dependencies())
resolve_dependencies_tasks.append(
create_eager_task(
itg.resolve_dependencies(),
name=f"resolve dependencies {domain}",
loop=hass.loop,
)
)

if unseen_deps := manifest_deps - integration_cache.keys():
# If there are dependencies, try to preload all
Expand Down
26 changes: 24 additions & 2 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .loader import async_suggest_report_issue
from .setup import DATA_SETUP_DONE, async_process_deps_reqs, async_setup_component
from .util import uuid as uuid_util
from .util.async_ import create_eager_task
from .util.decorator import Registry

if TYPE_CHECKING:
Expand Down Expand Up @@ -929,6 +930,27 @@ def async_create_task(

return task

@callback
def async_create_eager_task(
self,
hass: HomeAssistant,
target: Coroutine[Any, Any, _R],
name: str | None = None,
) -> asyncio.Task[_R]:
"""Create an eager task from within the event loop.

This method must be run in the event loop.

target: target to call.
"""
task = hass.async_create_eager_task(
target, f"{name} {self.title} {self.domain} {self.entry_id}"
)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

@callback
def async_create_background_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R], name: str
Expand Down Expand Up @@ -1682,7 +1704,7 @@ async def async_forward_entry_setups(
"""Forward the setup of an entry to platforms."""
await asyncio.gather(
*(
asyncio.create_task(
create_eager_task(
self.async_forward_entry_setup(entry, platform),
name=f"config entry forward setup {entry.title} {entry.domain} {entry.entry_id} {platform}",
)
Expand Down Expand Up @@ -1718,7 +1740,7 @@ async def async_unload_platforms(
return all(
await asyncio.gather(
*(
asyncio.create_task(
create_eager_task(
self.async_forward_entry_unload(entry, platform),
name=f"config entry forward unload {entry.title} {entry.domain} {entry.entry_id} {platform}",
)
Expand Down
17 changes: 17 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from .util import dt as dt_util, location
from .util.async_ import (
cancelling,
create_eager_task,
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
Expand Down Expand Up @@ -636,6 +637,22 @@ def async_create_task(
task.add_done_callback(self._tasks.remove)
return task

@callback
def async_create_eager_task(
bdraco marked this conversation as resolved.
Show resolved Hide resolved
self, target: Coroutine[Any, Any, _R], name: str | None = None
) -> asyncio.Task[_R]:
"""Create an eager task from within the event loop.

This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.

target: target to call.
"""
task = create_eager_task(target, name=name, loop=self.loop)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
return task

@callback
def async_create_background_task(
self,
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _async_schedule_add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async."""
task = self.hass.async_create_task(
task = self.hass.async_create_eager_task(
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}",
)
Expand All @@ -480,7 +480,7 @@ def _async_schedule_add_entities_for_entry(
) -> None:
"""Schedule adding entities for a single platform async and track the task."""
assert self.config_entry
task = self.config_entry.async_create_task(
task = self.config_entry.async_create_eager_task(
self.hass,
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities_for_entry {self.domain}.{self.platform_name}",
Expand Down
34 changes: 32 additions & 2 deletions homeassistant/util/async_.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Asyncio utilities."""
from __future__ import annotations

from asyncio import Future, Semaphore, gather, get_running_loop
from asyncio.events import AbstractEventLoop
from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop
from collections.abc import Awaitable, Callable
import concurrent.futures
from contextlib import suppress
import functools
import logging
import sys
import threading
from traceback import extract_stack
from typing import Any, ParamSpec, TypeVar, TypeVarTuple
Expand All @@ -23,6 +23,36 @@
_P = ParamSpec("_P")
_Ts = TypeVarTuple("_Ts")

if sys.version_info >= (3, 12, 0):

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
eager_start=True, # type: ignore[call-arg]
)
else:

def create_eager_task(
coro: Awaitable[_T],
*,
name: str | None = None,
loop: AbstractEventLoop | None = None,
) -> Task[_T]:
"""Create a task from a coroutine and schedule it to run immediately."""
return Task(
coro,
loop=loop or get_running_loop(),
name=name,
)


def cancelling(task: Future[Any]) -> bool:
"""Return True if task is cancelling."""
Expand Down
Loading