Skip to content
/ psycopg Public
forked from psycopg/psycopg

Commit

Permalink
Rewrite waiting.wait_{,conn_}async() using anyio
Browse files Browse the repository at this point in the history
We need to build a socket object in these functions in order to use
anyio.wait_socket_{readable,writable}(). See discussions at
agronholm/anyio#386
Perhaps we should maintain a cache of these for performance?
  • Loading branch information
dlax committed Nov 10, 2021
1 parent 93257ef commit 99a5bf6
Showing 1 changed file with 65 additions and 50 deletions.
115 changes: 65 additions & 50 deletions psycopg/psycopg/waiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

import select
import selectors
import socket
from enum import IntEnum
from typing import Optional
from asyncio import get_event_loop, wait_for, Event, TimeoutError
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE

import anyio

from . import errors as e
from .abc import PQGen, PQGenConn, RV

Expand Down Expand Up @@ -107,46 +109,49 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
:param fileno: the file descriptor to wait on.
:return: whatever *gen* returns on completion.
Behave like in `wait()`, but exposing an `asyncio` interface.
Behave like in `wait()`, but exposing an `async` interface.
"""
# Use an event to block and restart after the fd state changes.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
ready: Ready
s: Wait
ready: Ready

def wakeup(state: Ready) -> None:
try:
sock = socket.fromfd(fileno, socket.AF_INET, socket.SOCK_STREAM)
except OSError:
# TODO: set a meaningful error message
raise e.OperationalError

async def readable(ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready = state
ready = Ready.R
ev.set()

async def writable(ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready = Ready.W
ev.set()

try:
s = next(gen)
while 1:
ev.clear()
if s == Wait.R:
loop.add_reader(fileno, wakeup, Ready.R)
await ev.wait()
loop.remove_reader(fileno)
elif s == Wait.W:
loop.add_writer(fileno, wakeup, Ready.W)
ev = anyio.Event()
async with anyio.create_task_group() as tg:
if s == Wait.R:
tg.start_soon(readable, ev)
if s == Wait.W:
tg.start_soon(writable, ev)
await ev.wait()
loop.remove_writer(fileno)
elif s == Wait.RW:
loop.add_reader(fileno, wakeup, Ready.R)
loop.add_writer(fileno, wakeup, Ready.W)
await ev.wait()
loop.remove_reader(fileno)
loop.remove_writer(fileno)
else:
raise e.InternalError("bad poll status: %s")

s = gen.send(ready)

except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv

finally:
sock.close()


async def wait_conn_async(
gen: PQGenConn[RV], timeout: Optional[float] = None
Expand All @@ -163,40 +168,50 @@ async def wait_conn_async(
Behave like in `wait()`, but take the fileno to wait from the generator
itself, which might change during processing.
"""
# Use an event to block and restart after the fd state changes.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
ready: Ready
s: Wait
ready: Ready

async def readable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready = Ready.R
ev.set()

def wakeup(state: Ready) -> None:
async def writable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready = state
ready = Ready.W
ev.set()

timeout = timeout or None
try:
fileno, s = next(gen)

while 1:
ev.clear()
if s == Wait.R:
loop.add_reader(fileno, wakeup, Ready.R)
await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
elif s == Wait.W:
loop.add_writer(fileno, wakeup, Ready.W)
await wait_for(ev.wait(), timeout)
loop.remove_writer(fileno)
elif s == Wait.RW:
loop.add_reader(fileno, wakeup, Ready.R)
loop.add_writer(fileno, wakeup, Ready.W)
await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
loop.remove_writer(fileno)
else:
raise e.InternalError("bad poll status: %s")
fileno, s = gen.send(ready)
try:
sock = socket.fromfd(
fileno, socket.AF_INET, socket.SOCK_STREAM
)
except OSError:
# TODO: set a meaningful error message
raise e.OperationalError
with sock:
ev = anyio.Event()
async with anyio.create_task_group() as tg:
if s == Wait.R:
tg.start_soon(readable, sock, ev)
elif s == Wait.W:
tg.start_soon(writable, sock, ev)
else:
raise e.InternalError(f"bad poll status: {s}")

try:
with anyio.fail_after(timeout):
await ev.wait()
except TimeoutError:
raise e.OperationalError("timeout expired")

fileno, s = gen.send(ready)

except TimeoutError:
raise e.OperationalError("timeout expired")
Expand Down

0 comments on commit 99a5bf6

Please sign in to comment.