diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index 41ee82eb..a11a3293 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend: from .trio import TrioBackend self._backend_implementation = TrioBackend() + elif backend == "curio": + from .curio import CurioBackend + + self._backend_implementation = CurioBackend() else: # pragma: nocover raise RuntimeError(f"Unsupported concurrency backend {backend!r}") return self._backend_implementation diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py new file mode 100644 index 00000000..8cae3be8 --- /dev/null +++ b/httpcore/_backends/curio.py @@ -0,0 +1,202 @@ +import select +from ssl import SSLContext, SSLSocket +from typing import Optional + +import curio +import curio.io + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._types import TimeoutDict +from .._utils import get_logger +from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream + +logger = get_logger(__name__) + +ONE_DAY_IN_SECONDS = float(60 * 60 * 24) + + +def convert_timeout(value: Optional[float]) -> float: + return value if value is not None else ONE_DAY_IN_SECONDS + + +class Lock(AsyncLock): + def __init__(self) -> None: + self._lock = curio.Lock() + + async def acquire(self) -> None: + await self._lock.acquire() + + async def release(self) -> None: + await self._lock.release() + + +class Semaphore(AsyncSemaphore): + def __init__(self, max_value: int, exc_class: type) -> None: + self.max_value = max_value + self.exc_class = exc_class + + @property + def semaphore(self) -> curio.Semaphore: + if not hasattr(self, "_semaphore"): + self._semaphore = curio.Semaphore(value=self.max_value) + return self._semaphore + + async def acquire(self, timeout: float = None) -> None: + timeout = convert_timeout(timeout) + + try: + return await curio.timeout_after(timeout, self.semaphore.acquire()) + except curio.TaskTimeout: + raise self.exc_class() + + async def release(self) -> None: + await self.semaphore.release() + + +class SocketStream(AsyncSocketStream): + def __init__(self, socket: curio.io.Socket) -> None: + self.read_lock = curio.Lock() + self.write_lock = curio.Lock() + self.socket = socket + self.stream = socket.as_stream() + + def get_http_version(self) -> str: + if hasattr(self.socket, "_socket"): + raw_socket = self.socket._socket + + if isinstance(raw_socket, SSLSocket): + ident = raw_socket.selected_alpn_protocol() + return "HTTP/2" if ident == "h2" else "HTTP/1.1" + + return "HTTP/1.1" + + async def start_tls( + self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict + ) -> "AsyncSocketStream": + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + + with map_exceptions(exc_map): + wrapped_sock = curio.io.Socket( + ssl_context.wrap_socket( + self.socket._socket, + do_handshake_on_connect=False, + server_hostname=hostname.decode("ascii"), + ) + ) + + await curio.timeout_after( + connect_timeout, + wrapped_sock.do_handshake(), + ) + + return SocketStream(wrapped_sock) + + async def read(self, n: int, timeout: TimeoutDict) -> bytes: + read_timeout = convert_timeout(timeout.get("read")) + exc_map = { + curio.TaskTimeout: ReadTimeout, + curio.CurioError: ReadError, + OSError: ReadError, + } + + with map_exceptions(exc_map): + async with self.read_lock: + return await curio.timeout_after(read_timeout, self.stream.read(n)) + + async def write(self, data: bytes, timeout: TimeoutDict) -> None: + write_timeout = convert_timeout(timeout.get("write")) + exc_map = { + curio.TaskTimeout: WriteTimeout, + curio.CurioError: WriteError, + OSError: WriteError, + } + + with map_exceptions(exc_map): + async with self.write_lock: + await curio.timeout_after(write_timeout, self.stream.write(data)) + + async def aclose(self) -> None: + await self.stream.close() + await self.socket.close() + + def is_connection_dropped(self) -> bool: + rready, _, _ = select.select([self.socket.fileno()], [], [], 0) + + return bool(rready) + + +class CurioBackend(AsyncBackend): + async def open_tcp_stream( + self, + hostname: bytes, + port: int, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + *, + local_address: Optional[str], + ) -> AsyncSocketStream: + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + host = hostname.decode("ascii") + kwargs = ( + {} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host} + ) + + with map_exceptions(exc_map): + sock: curio.io.Socket = await curio.timeout_after( + connect_timeout, + curio.open_connection(hostname, port, **kwargs), + ) + + return SocketStream(sock) + + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + host = hostname.decode("ascii") + kwargs = ( + {} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host} + ) + + with map_exceptions(exc_map): + sock: curio.io.Socket = await curio.timeout_after( + connect_timeout, curio.open_unix_connection(path, **kwargs) + ) + + return SocketStream(sock) + + def create_lock(self) -> AsyncLock: + return Lock() + + def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: + return Semaphore(max_value, exc_class) + + async def time(self) -> float: + return await curio.clock() diff --git a/requirements.txt b/requirements.txt index b1176f52..7be6603c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ # Optionals trio trio-typing +curio # Docs mkautodoc diff --git a/setup.py b/setup.py index f91c4d95..b0b4dfd8 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,7 @@ def get_packages(package): "Topic :: Internet :: WWW/HTTP", "Framework :: AsyncIO", "Framework :: Trio", + "Framework :: Curio", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", diff --git a/tests/conftest.py b/tests/conftest.py index 32f87635..35aaf9ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,22 +14,42 @@ from httpcore._types import URL +from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call + PROXY_HOST = "127.0.0.1" PROXY_PORT = 8080 +def pytest_configure(config): + config.addinivalue_line( + "markers", + "curio: mark the test as a coroutine, it will be run using a Curio kernel.", + ) + + +@pytest.mark.tryfirst +def pytest_pycollect_makeitem(collector, name, obj): + curio_pytest_pycollect_makeitem(collector, name, obj) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + yield from curio_pytest_pyfunc_call(pyfuncitem) + + @pytest.fixture( params=[ pytest.param("asyncio", marks=pytest.mark.asyncio), pytest.param("trio", marks=pytest.mark.trio), + pytest.param("curio", marks=pytest.mark.curio), ] ) def async_environment(request: typing.Any) -> str: """ - Mark a test function to be run on both asyncio and trio. + Mark a test function to be run on asyncio, trio and curio. - Equivalent to having a pair of tests, each respectively marked with - '@pytest.mark.asyncio' and '@pytest.mark.trio'. + Equivalent to having three tests, each respectively marked with + '@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'. Intended usage: diff --git a/tests/marks/__init__.py b/tests/marks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/marks/curio.py b/tests/marks/curio.py new file mode 100644 index 00000000..616b2766 --- /dev/null +++ b/tests/marks/curio.py @@ -0,0 +1,42 @@ +import functools +import inspect + +import curio +import curio.debug +import curio.meta +import curio.monitor +import pytest + + +def _is_coroutine(obj): + """Check to see if an object is really a coroutine.""" + return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj) + + +@pytest.mark.tryfirst +def curio_pytest_pycollect_makeitem(collector, name, obj): + """A pytest hook to collect coroutines in a test module.""" + if collector.funcnamefilter(name) and _is_coroutine(obj): + item = pytest.Function.from_parent(collector, name=name) + if "curio" in item.keywords: + return list(collector._genfunctions(name, obj)) # pragma: nocover + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def curio_pytest_pyfunc_call(pyfuncitem): + """Run curio marked test functions in a Curio kernel + instead of a normal function call.""" + if pyfuncitem.get_closest_marker("curio"): + pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) + yield + + +def wrap_in_sync(func): + """Return a sync wrapper around an async function executing it in a Kernel.""" + + @functools.wraps(func) + def inner(**kwargs): + coro = func(**kwargs) + curio.Kernel().run(coro, shutdown=True) + + return inner diff --git a/unasync.py b/unasync.py index 84d4d367..b2ad647a 100755 --- a/unasync.py +++ b/unasync.py @@ -20,6 +20,7 @@ ('__aiter__', '__iter__'), ('@pytest.mark.asyncio', ''), ('@pytest.mark.trio', ''), + ('@pytest.mark.curio', ''), ('@pytest.mark.usefixtures.*', ''), ] COMPILED_SUBS = [