Skip to content

Commit

Permalink
Implement curio backend (#168)
Browse files Browse the repository at this point in the history
* Implemented curio backend (#94)

* Fixing PR remarks (#94)

* Fixing PR remarks. Mention curio in the same context when the pair of asyncio and trio are mentioned (#94)

* Fixing PR remarks. Made pytest.mark.curio completely isolated from conftest.py (for now it's easier to add new backend with custom marks) (#94)

* PR review. Updated tests/marks/curio.py (removed unnecessary fixture)

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>

* Added "curio" test mark programmatically (#94)

* Fixed PR remarks (#94)
Added timeout handling to Semaphore::acquire and tried to avoid private API usage in SocketStream::get_http_version, also changed is_connection_dropped behaviour

* Fixed PR remarks (#94)
Rewrote _wrap_ssl_client using ssl.SSLContext::wrap_socket

* PR review (#94)

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>

* PR review (#94)

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
  • Loading branch information
cdeler and florimondmanca authored Sep 5, 2020
1 parent ccea853 commit 2a1c6ef
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 3 deletions.
4 changes: 4 additions & 0 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend:
from .trio import TrioBackend

self._backend_implementation = TrioBackend()
elif backend == "curio":
from .curio import CurioBackend

self._backend_implementation = CurioBackend()
else: # pragma: nocover
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
return self._backend_implementation
Expand Down
202 changes: 202 additions & 0 deletions httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import select
from ssl import SSLContext, SSLSocket
from typing import Optional

import curio
import curio.io

from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._utils import get_logger
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

logger = get_logger(__name__)

ONE_DAY_IN_SECONDS = float(60 * 60 * 24)


def convert_timeout(value: Optional[float]) -> float:
return value if value is not None else ONE_DAY_IN_SECONDS


class Lock(AsyncLock):
def __init__(self) -> None:
self._lock = curio.Lock()

async def acquire(self) -> None:
await self._lock.acquire()

async def release(self) -> None:
await self._lock.release()


class Semaphore(AsyncSemaphore):
def __init__(self, max_value: int, exc_class: type) -> None:
self.max_value = max_value
self.exc_class = exc_class

@property
def semaphore(self) -> curio.Semaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = curio.Semaphore(value=self.max_value)
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
timeout = convert_timeout(timeout)

try:
return await curio.timeout_after(timeout, self.semaphore.acquire())
except curio.TaskTimeout:
raise self.exc_class()

async def release(self) -> None:
await self.semaphore.release()


class SocketStream(AsyncSocketStream):
def __init__(self, socket: curio.io.Socket) -> None:
self.read_lock = curio.Lock()
self.write_lock = curio.Lock()
self.socket = socket
self.stream = socket.as_stream()

def get_http_version(self) -> str:
if hasattr(self.socket, "_socket"):
raw_socket = self.socket._socket

if isinstance(raw_socket, SSLSocket):
ident = raw_socket.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"

return "HTTP/1.1"

async def start_tls(
self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict
) -> "AsyncSocketStream":
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}

with map_exceptions(exc_map):
wrapped_sock = curio.io.Socket(
ssl_context.wrap_socket(
self.socket._socket,
do_handshake_on_connect=False,
server_hostname=hostname.decode("ascii"),
)
)

await curio.timeout_after(
connect_timeout,
wrapped_sock.do_handshake(),
)

return SocketStream(wrapped_sock)

async def read(self, n: int, timeout: TimeoutDict) -> bytes:
read_timeout = convert_timeout(timeout.get("read"))
exc_map = {
curio.TaskTimeout: ReadTimeout,
curio.CurioError: ReadError,
OSError: ReadError,
}

with map_exceptions(exc_map):
async with self.read_lock:
return await curio.timeout_after(read_timeout, self.stream.read(n))

async def write(self, data: bytes, timeout: TimeoutDict) -> None:
write_timeout = convert_timeout(timeout.get("write"))
exc_map = {
curio.TaskTimeout: WriteTimeout,
curio.CurioError: WriteError,
OSError: WriteError,
}

with map_exceptions(exc_map):
async with self.write_lock:
await curio.timeout_after(write_timeout, self.stream.write(data))

async def aclose(self) -> None:
await self.stream.close()
await self.socket.close()

def is_connection_dropped(self) -> bool:
rready, _, _ = select.select([self.socket.fileno()], [], [], 0)

return bool(rready)


class CurioBackend(AsyncBackend):
async def open_tcp_stream(
self,
hostname: bytes,
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
*,
local_address: Optional[str],
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout,
curio.open_connection(hostname, port, **kwargs),
)

return SocketStream(sock)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout, curio.open_unix_connection(path, **kwargs)
)

return SocketStream(sock)

def create_lock(self) -> AsyncLock:
return Lock()

def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
return Semaphore(max_value, exc_class)

async def time(self) -> float:
return await curio.clock()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Optionals
trio
trio-typing
curio

# Docs
mkautodoc
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_packages(package):
"Topic :: Internet :: WWW/HTTP",
"Framework :: AsyncIO",
"Framework :: Trio",
"Framework :: Curio",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
Expand Down
26 changes: 23 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,42 @@

from httpcore._types import URL

from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call

PROXY_HOST = "127.0.0.1"
PROXY_PORT = 8080


def pytest_configure(config):
config.addinivalue_line(
"markers",
"curio: mark the test as a coroutine, it will be run using a Curio kernel.",
)


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
curio_pytest_pycollect_makeitem(collector, name, obj)


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
yield from curio_pytest_pyfunc_call(pyfuncitem)


@pytest.fixture(
params=[
pytest.param("asyncio", marks=pytest.mark.asyncio),
pytest.param("trio", marks=pytest.mark.trio),
pytest.param("curio", marks=pytest.mark.curio),
]
)
def async_environment(request: typing.Any) -> str:
"""
Mark a test function to be run on both asyncio and trio.
Mark a test function to be run on asyncio, trio and curio.
Equivalent to having a pair of tests, each respectively marked with
'@pytest.mark.asyncio' and '@pytest.mark.trio'.
Equivalent to having three tests, each respectively marked with
'@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'.
Intended usage:
Expand Down
Empty file added tests/marks/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions tests/marks/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import functools
import inspect

import curio
import curio.debug
import curio.meta
import curio.monitor
import pytest


def _is_coroutine(obj):
"""Check to see if an object is really a coroutine."""
return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)


@pytest.mark.tryfirst
def curio_pytest_pycollect_makeitem(collector, name, obj):
"""A pytest hook to collect coroutines in a test module."""
if collector.funcnamefilter(name) and _is_coroutine(obj):
item = pytest.Function.from_parent(collector, name=name)
if "curio" in item.keywords:
return list(collector._genfunctions(name, obj)) # pragma: nocover


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def curio_pytest_pyfunc_call(pyfuncitem):
"""Run curio marked test functions in a Curio kernel
instead of a normal function call."""
if pyfuncitem.get_closest_marker("curio"):
pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj)
yield


def wrap_in_sync(func):
"""Return a sync wrapper around an async function executing it in a Kernel."""

@functools.wraps(func)
def inner(**kwargs):
coro = func(**kwargs)
curio.Kernel().run(coro, shutdown=True)

return inner
1 change: 1 addition & 0 deletions unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
('__aiter__', '__iter__'),
('@pytest.mark.asyncio', ''),
('@pytest.mark.trio', ''),
('@pytest.mark.curio', ''),
('@pytest.mark.usefixtures.*', ''),
]
COMPILED_SUBS = [
Expand Down

0 comments on commit 2a1c6ef

Please sign in to comment.