Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement curio backend #168

Merged
merged 10 commits into from
Sep 5, 2020
4 changes: 4 additions & 0 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend:
from .trio import TrioBackend

self._backend_implementation = TrioBackend()
elif backend == "curio":
from .curio import CurioBackend

self._backend_implementation = CurioBackend()
else: # pragma: nocover
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
return self._backend_implementation
Expand Down
188 changes: 188 additions & 0 deletions httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from ssl import SSLContext
from typing import Optional

import curio
import curio.io
from curio.network import _wrap_ssl_client

from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._utils import get_logger
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

logger = get_logger("curio_backend")
cdeler marked this conversation as resolved.
Show resolved Hide resolved

one_day_in_seconds = 60 * 60 * 24
cdeler marked this conversation as resolved.
Show resolved Hide resolved


def convert_timeout(value: Optional[float]) -> int:
return int(value) if value is not None else one_day_in_seconds
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved


class Lock(AsyncLock):
def __init__(self) -> None:
self._lock = curio.Lock()

async def acquire(self) -> None:
await self._lock.acquire()

async def release(self) -> None:
await self._lock.release()


class Semaphore(AsyncSemaphore):
def __init__(self, max_value: int, exc_class: type) -> None:
self.max_value = max_value
self.exc_class = exc_class

@property
def semaphore(self) -> curio.Semaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = curio.Semaphore(value=self.max_value)
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
await self.semaphore.acquire()
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved

async def release(self) -> None:
await self.semaphore.release()


class SocketStream(AsyncSocketStream):
def __init__(self, socket: curio.io.Socket) -> None:
self.read_lock = curio.Lock()
self.write_lock = curio.Lock()
self.socket = socket
self.stream = socket.as_stream()

def get_http_version(self) -> str:
if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"):
ident = self.socket._socket._sslobj.selected_alpn_protocol()
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
else:
ident = "http/1.1"
return "HTTP/2" if ident == "h2" else "HTTP/1.1"

async def start_tls(
self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict
) -> "AsyncSocketStream":
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}

with map_exceptions(exc_map):
wrapped_sock = await curio.timeout_after(
connect_timeout,
_wrap_ssl_client(
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
self.socket,
ssl=ssl_context,
server_hostname=hostname,
alpn_protocols=["h2", "http/1.1"],
),
)

return SocketStream(wrapped_sock)

async def read(self, n: int, timeout: TimeoutDict) -> bytes:
read_timeout = convert_timeout(timeout.get("read"))
exc_map = {
curio.TaskTimeout: ReadTimeout,
curio.CurioError: ReadError,
OSError: ReadError,
}

with map_exceptions(exc_map):
async with self.read_lock:
return await curio.timeout_after(read_timeout, self.stream.read(n))

async def write(self, data: bytes, timeout: TimeoutDict) -> None:
write_timeout = convert_timeout(timeout.get("write"))
exc_map = {
curio.TaskTimeout: WriteTimeout,
curio.CurioError: WriteError,
OSError: WriteError,
}

with map_exceptions(exc_map):
async with self.write_lock:
await curio.timeout_after(write_timeout, self.stream.write(data))

async def aclose(self) -> None:
# we dont need to close the self.socket, since it's closed by stream closing
await self.stream.close()

def is_connection_dropped(self) -> bool:
return self.socket._closed
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved


class CurioBackend(AsyncBackend):
async def open_tcp_stream(
self,
hostname: bytes,
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
*,
local_address: Optional[str],
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if not ssl_context else {"ssl": ssl_context, "server_hostname": host}
cdeler marked this conversation as resolved.
Show resolved Hide resolved
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout, curio.open_connection(hostname, port, **kwargs)
)

return SocketStream(sock)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if not ssl_context else {"ssl": ssl_context, "server_hostname": host}
cdeler marked this conversation as resolved.
Show resolved Hide resolved
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout, curio.open_unix_connection(path, **kwargs)
)

return SocketStream(sock)

def create_lock(self) -> AsyncLock:
return Lock()

def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
return Semaphore(max_value, exc_class)

async def time(self) -> float:
return float(await curio.clock())
cdeler marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Optionals
trio
trio-typing
curio

# Docs
mkautodoc
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_packages(package):
"Topic :: Internet :: WWW/HTTP",
"Framework :: AsyncIO",
"Framework :: Trio",
"Framework :: Curio",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
Expand Down
27 changes: 24 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,43 @@

from httpcore._types import URL

from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call

PROXY_HOST = "127.0.0.1"
PROXY_PORT = 8080


def pytest_configure(config):
# register an additional marker
cdeler marked this conversation as resolved.
Show resolved Hide resolved
config.addinivalue_line(
"markers",
"curio: mark the test as a coroutine, it will be run using a Curio kernel.",
)


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
curio_pytest_pycollect_makeitem(collector, name, obj)


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
yield from curio_pytest_pyfunc_call(pyfuncitem)
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture(
params=[
pytest.param("asyncio", marks=pytest.mark.asyncio),
pytest.param("trio", marks=pytest.mark.trio),
pytest.param("curio", marks=pytest.mark.curio),
]
)
def async_environment(request: typing.Any) -> str:
"""
Mark a test function to be run on both asyncio and trio.
Mark a test function to be run on asyncio, trio and curio.
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved

Equivalent to having a pair of tests, each respectively marked with
'@pytest.mark.asyncio' and '@pytest.mark.trio'.
Equivalent to having three tests, each respectively marked with
'@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'.

Intended usage:

Expand Down
Empty file added tests/marks/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions tests/marks/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import functools
import inspect

import curio
import curio.debug
import curio.meta
import curio.monitor
import pytest


def _is_coroutine(obj):
"""Check to see if an object is really a coroutine."""
return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)


@pytest.mark.tryfirst
def curio_pytest_pycollect_makeitem(collector, name, obj):
"""A pytest hook to collect coroutines in a test module."""
if collector.funcnamefilter(name) and _is_coroutine(obj):
item = pytest.Function.from_parent(collector, name=name)
if "curio" in item.keywords:
return list(collector._genfunctions(name, obj))


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def curio_pytest_pyfunc_call(pyfuncitem):
"""Run curio marked test functions in a Curio kernel
instead of a normal function call."""
if pyfuncitem.get_closest_marker("curio"):
pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj)
yield


def wrap_in_sync(func):
"""Return a sync wrapper around an async function executing it in a Kernel."""

@functools.wraps(func)
def inner(**kwargs):
coro = func(**kwargs)
curio.Kernel().run(coro, shutdown=True)

return inner
1 change: 1 addition & 0 deletions unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
('__aiter__', '__iter__'),
('@pytest.mark.asyncio', ''),
('@pytest.mark.trio', ''),
('@pytest.mark.curio', ''),
('@pytest.mark.usefixtures.*', ''),
]
COMPILED_SUBS = [
Expand Down