Skip to content
/ psycopg Public
forked from psycopg/psycopg

Commit

Permalink
Add waiting functions using anyio
Browse files Browse the repository at this point in the history
We need to build a socket object in these functions in order to use
anyio.wait_socket_{readable,writable}(). See discussions at
agronholm/anyio#386
Perhaps we should maintain a cache of these for performance?
  • Loading branch information
dlax committed Nov 18, 2021
1 parent 9dcfd07 commit 77663fe
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 10 deletions.
134 changes: 134 additions & 0 deletions psycopg/psycopg/waiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,140 @@ def wakeup(state: Ready) -> None:
return rv


try:
import anyio
except ImportError:
pass
else:
import socket

async def wait_anyio(gen: PQGen[RV], fileno: int) -> RV:
"""
Coroutine waiting for a generator to complete.
:param gen: a generator performing database operations and yielding
`Ready` values when it would block.
:param fileno: the file descriptor to wait on.
:return: whatever *gen* returns on completion.
Behave like in `wait()`, but exposing an `anyio` interface.
"""
s: Wait
ready: Ready

try:
sock = socket.fromfd(fileno, socket.AF_INET, socket.SOCK_STREAM)
except OSError:
# TODO: set a meaningful error message
raise e.OperationalError

async def readable(ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready |= Ready.R # type: ignore[assignment]
ev.set()

async def writable(ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready |= Ready.W # type: ignore[assignment]
ev.set()

try:
s = next(gen)
while 1:
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev = anyio.Event()
ready = 0 # type: ignore[assignment]
async with anyio.create_task_group() as tg:
if reader:
tg.start_soon(readable, ev)
if writer:
tg.start_soon(writable, ev)
await ev.wait()

s = gen.send(ready)

except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv

finally:
sock.close()

async def wait_conn_anyio(
gen: PQGenConn[RV], timeout: Optional[float] = None
) -> RV:
"""
Coroutine waiting for a connection generator to complete.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
to allow Ctrl-C. If zero or None, wait indefinitely.
:return: whatever *gen* returns on completion.
Behave like in `wait()`, but take the fileno to wait from the generator
itself, which might change during processing.
"""
s: Wait
ready: Ready

async def readable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready |= Ready.R # type: ignore[assignment]
ev.set()

async def writable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready |= Ready.W # type: ignore[assignment]
ev.set()

timeout = timeout or None
try:
fileno, s = next(gen)

while 1:
try:
sock = socket.fromfd(
fileno, socket.AF_INET, socket.SOCK_STREAM
)
except OSError:
# TODO: set a meaningful error message
raise e.OperationalError
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev = anyio.Event()
ready = 0 # type: ignore[assignment]
with sock:
async with anyio.create_task_group() as tg:
if reader:
tg.start_soon(readable, sock, ev)
if writer:
tg.start_soon(writable, sock, ev)
try:
with anyio.fail_after(timeout):
await ev.wait()
except TimeoutError:
raise e.OperationalError("timeout expired")

fileno, s = gen.send(ready)

except TimeoutError:
raise e.OperationalError("timeout expired")

except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv


def wait_epoll(
gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
) -> RV:
Expand Down
29 changes: 19 additions & 10 deletions tests/test_waiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,46 +107,55 @@ def test_wait_epoll_bad(pgconn):


@pytest.mark.parametrize("timeout", timeouts)
@pytest.mark.parametrize(
"waitfn", [waiting.wait_conn_asyncio, waiting.wait_conn_anyio]
)
@pytest.mark.asyncio
async def test_wait_conn_async(dsn, timeout):
async def test_wait_conn_async(dsn, timeout, waitfn):
gen = generators.connect(dsn)
conn = await waiting.wait_conn_asyncio(gen, **timeout)
conn = await waitfn(gen, **timeout)
assert conn.status == ConnStatus.OK


@pytest.mark.parametrize(
"waitfn", [waiting.wait_conn_asyncio, waiting.wait_conn_anyio]
)
@pytest.mark.asyncio
async def test_wait_conn_async_bad(dsn):
async def test_wait_conn_async_bad(dsn, waitfn):
gen = generators.connect("dbname=nosuchdb")
with pytest.raises(psycopg.OperationalError):
await waiting.wait_conn_asyncio(gen)
await waitfn(gen)


@pytest.mark.parametrize("waitfn", [waiting.wait_asyncio, waiting.wait_anyio])
@pytest.mark.asyncio
async def test_wait_async(pgconn):
async def test_wait_async(pgconn, waitfn):
pgconn.send_query(b"select 1")
gen = generators.execute(pgconn)
(res,) = await waiting.wait_asyncio(gen, pgconn.socket)
(res,) = await waitfn(gen, pgconn.socket)
assert res.status == ExecStatus.TUPLES_OK


@pytest.mark.asyncio
@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
@pytest.mark.parametrize("waitfn", [waiting.wait_asyncio, waiting.wait_anyio])
@skip_if_not_linux
async def test_wait_ready_async(wait, ready):
async def test_wait_ready_async(wait, ready, waitfn):
def gen():
r = yield wait
return r

with socket.socket() as s:
r = await waiting.wait_asyncio(gen(), s.fileno())
r = await waitfn(gen(), s.fileno())
assert r & ready


@pytest.mark.parametrize("waitfn", [waiting.wait_asyncio, waiting.wait_anyio])
@pytest.mark.asyncio
async def test_wait_async_bad(pgconn):
async def test_wait_async_bad(pgconn, waitfn):
pgconn.send_query(b"select 1")
gen = generators.execute(pgconn)
socket = pgconn.socket
pgconn.finish()
with pytest.raises(psycopg.OperationalError):
await waiting.wait_asyncio(gen, socket)
await waitfn(gen, socket)

0 comments on commit 77663fe

Please sign in to comment.