Skip to content

Commit

Permalink
Merge pull request #2477 from harahu/run-types
Browse files Browse the repository at this point in the history
Add some low-effort type annotations
  • Loading branch information
harahu authored Jan 17, 2023
2 parents af3d7d8 + fc4ed29 commit d61b050
Showing 1 changed file with 86 additions and 57 deletions.
143 changes: 86 additions & 57 deletions trio/_core/_run.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,61 @@
from __future__ import annotations

import enum
import functools
import gc
import itertools
import random
import select
import sys
import threading
import gc
import warnings
from collections import deque
from collections.abc import Callable
from contextlib import contextmanager
import warnings
import enum

from contextvars import copy_context
from heapq import heapify, heappop, heappush
from math import inf
from time import perf_counter
from typing import Callable, TYPE_CHECKING

from sniffio import current_async_library_cvar
from typing import TYPE_CHECKING, Any, NoReturn, TypeVar

import attr
from heapq import heapify, heappop, heappush
from sortedcontainers import SortedDict
from outcome import Error, Outcome, Value, capture
from sniffio import current_async_library_cvar
from sortedcontainers import SortedDict

# An unfortunate name collision here with trio._util.Final
from typing_extensions import Final as FinalT

from .. import _core
from .._util import Final, NoPublicConstructor, coroutine_or_error
from ._asyncgens import AsyncGenerators
from ._entry_queue import EntryQueue, TrioToken
from ._exceptions import TrioInternalError, RunFinishedError, Cancelled
from ._ki import (
LOCALS_KEY_KI_PROTECTION_ENABLED,
KIManager,
enable_ki_protection,
)
from ._exceptions import Cancelled, RunFinishedError, TrioInternalError
from ._instrumentation import Instruments
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection
from ._multierror import MultiError, concat_tb
from ._thread_cache import start_thread_soon
from ._traps import (
Abort,
wait_task_rescheduled,
cancel_shielded_checkpoint,
CancelShieldedCheckpoint,
PermanentlyDetachCoroutineObject,
WaitTaskRescheduled,
cancel_shielded_checkpoint,
wait_task_rescheduled,
)
from ._asyncgens import AsyncGenerators
from ._thread_cache import start_thread_soon
from ._instrumentation import Instruments
from .. import _core
from .._util import Final, NoPublicConstructor, coroutine_or_error

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup

DEADLINE_HEAP_MIN_PRUNE_THRESHOLD = 1000
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000

_NO_SEND = object()
_NO_SEND: FinalT = object()

FnT = TypeVar("FnT", bound="Callable[..., Any]")

# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn):
def _public(fn: FnT) -> FnT:
return fn


Expand All @@ -63,50 +64,71 @@ def _public(fn):
# variable to True, and registers the Random instance _r for Hypothesis
# to manage for each test case, which together should make Trio's task
# scheduling loop deterministic. We have a test for that, of course.
_ALLOW_DETERMINISTIC_SCHEDULING = False
_ALLOW_DETERMINISTIC_SCHEDULING: FinalT = False
_r = random.Random()


# On CPython, Context.run() is implemented in C and doesn't show up in
# tracebacks. On PyPy, it is implemented in Python and adds 1 frame to tracebacks.
def _count_context_run_tb_frames():
def function_with_unique_name_xyzzy():
1 / 0
def _count_context_run_tb_frames() -> int:
"""Count implementation dependent traceback frames from Context.run()
On CPython, Context.run() is implemented in C and doesn't show up in
tracebacks. On PyPy, it is implemented in Python and adds 1 frame to
tracebacks.
Returns:
int: Traceback frame count
"""

def function_with_unique_name_xyzzy() -> NoReturn:
try:
1 / 0
except ZeroDivisionError:
raise
else: # pragma: no cover
raise TrioInternalError(
"A ZeroDivisionError should have been raised, but it wasn't."
)

ctx = copy_context()
try:
ctx.run(function_with_unique_name_xyzzy)
except ZeroDivisionError as exc:
tb = exc.__traceback__
# Skip the frame where we caught it
tb = tb.tb_next
tb = tb.tb_next # type: ignore[union-attr]
count = 0
while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy":
tb = tb.tb_next
while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": # type: ignore[union-attr]
tb = tb.tb_next # type: ignore[union-attr]
count += 1
return count
else: # pragma: no cover
raise TrioInternalError(
f"The purpose of {function_with_unique_name_xyzzy.__name__} is "
"to raise a ZeroDivisionError, but it didn't."
)


CONTEXT_RUN_TB_FRAMES = _count_context_run_tb_frames()
CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames()


@attr.s(frozen=True, slots=True)
class SystemClock:
# Add a large random offset to our clock to ensure that if people
# accidentally call time.perf_counter() directly or start comparing clocks
# between different runs, then they'll notice the bug quickly:
offset = attr.ib(factory=lambda: _r.uniform(10000, 200000))
offset: float = attr.ib(factory=lambda: _r.uniform(10000, 200000))

def start_clock(self):
def start_clock(self) -> None:
pass

# In cPython 3, on every platform except Windows, perf_counter is
# exactly the same as time.monotonic; and on Windows, it uses
# QueryPerformanceCounter instead of GetTickCount64.
def current_time(self):
def current_time(self) -> float:
return self.offset + perf_counter()

def deadline_to_sleep_time(self, deadline):
def deadline_to_sleep_time(self, deadline: float) -> float:
return deadline - self.current_time()


Expand Down Expand Up @@ -1119,7 +1141,7 @@ class Task(metaclass=NoPublicConstructor):
name = attr.ib()
# PEP 567 contextvars context
context = attr.ib()
_counter = attr.ib(init=False, factory=itertools.count().__next__)
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)

# Invariant:
# - for unscheduled tasks, _next_send_fn and _next_send are both None
Expand Down Expand Up @@ -1293,7 +1315,7 @@ class RunContext(threading.local):
task: Task


GLOBAL_RUN_CONTEXT = RunContext()
GLOBAL_RUN_CONTEXT: FinalT = RunContext()


@attr.s(frozen=True)
Expand Down Expand Up @@ -1380,7 +1402,7 @@ class Runner:
# Run-local values, see _local.py
_locals = attr.ib(factory=dict)

runq = attr.ib(factory=deque)
runq: deque[Task] = attr.ib(factory=deque)
tasks = attr.ib(factory=set)

deadlines = attr.ib(factory=Deadlines)
Expand Down Expand Up @@ -1957,8 +1979,8 @@ def run(
*args,
clock=None,
instruments=(),
restrict_keyboard_interrupt_to_checkpoints=False,
strict_exception_groups=False,
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
):
"""Run a Trio-flavored async function, and return the result.
Expand Down Expand Up @@ -2063,11 +2085,11 @@ def start_guest_run(
run_sync_soon_threadsafe,
done_callback,
run_sync_soon_not_threadsafe=None,
host_uses_signal_set_wakeup_fd=False,
host_uses_signal_set_wakeup_fd: bool = False,
clock=None,
instruments=(),
restrict_keyboard_interrupt_to_checkpoints=False,
strict_exception_groups=False,
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
):
"""Start a "guest" run of Trio on top of some other "host" event loop.
Expand Down Expand Up @@ -2147,14 +2169,19 @@ def my_done_callback(run_outcome):

# 24 hours is arbitrary, but it avoids issues like people setting timeouts of
# 10**20 and then getting integer overflows in the underlying system calls.
_MAX_TIMEOUT = 24 * 60 * 60
_MAX_TIMEOUT: FinalT = 24 * 60 * 60


# Weird quirk: this is written as a generator in order to support "guest
# mode", where our core event loop gets unrolled into a series of callbacks on
# the host loop. If you're doing a regular trio.run then this gets run
# straight through.
def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
def unrolled_run(
runner: Runner,
async_fn,
args,
host_uses_signal_set_wakeup_fd: bool = False,
):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
__tracebackhide__ = True

Expand All @@ -2173,7 +2200,7 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
# here is our event loop:
while runner.tasks:
if runner.runq:
timeout = 0
timeout: float = 0
else:
deadline = runner.deadlines.next_deadline()
timeout = runner.clock.deadline_to_sleep_time(deadline)
Expand Down Expand Up @@ -2301,8 +2328,10 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
# frame we always remove, because it's this function
# catching it, and then in addition we remove however many
# more Context.run adds.
tb = task_exc.__traceback__.tb_next
for _ in range(CONTEXT_RUN_TB_FRAMES):
tb = task_exc.__traceback__
for _ in range(1 + CONTEXT_RUN_TB_FRAMES):
if tb is None:
break
tb = tb.tb_next
final_outcome = Error(task_exc.with_traceback(tb))
# Remove local refs so that e.g. cancelled coroutine locals
Expand Down Expand Up @@ -2397,7 +2426,7 @@ def started(self, value=None):
pass


TASK_STATUS_IGNORED = _TaskStatusIgnored()
TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored()


def current_task():
Expand Down Expand Up @@ -2493,16 +2522,16 @@ async def checkpoint_if_cancelled():


if sys.platform == "win32":
from ._io_windows import WindowsIOManager as TheIOManager
from ._generated_io_windows import *
from ._io_windows import WindowsIOManager as TheIOManager
elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")):
from ._io_epoll import EpollIOManager as TheIOManager
from ._generated_io_epoll import *
from ._io_epoll import EpollIOManager as TheIOManager
elif TYPE_CHECKING or hasattr(select, "kqueue"):
from ._io_kqueue import KqueueIOManager as TheIOManager
from ._generated_io_kqueue import *
from ._io_kqueue import KqueueIOManager as TheIOManager
else: # pragma: no cover
raise NotImplementedError("unsupported platform")

from ._generated_run import *
from ._generated_instrumentation import *
from ._generated_run import *

0 comments on commit d61b050

Please sign in to comment.