From 185375672a60fd70a385eb238a271c8102942785 Mon Sep 17 00:00:00 2001 From: cdeler Date: Mon, 24 Aug 2020 13:42:04 +0300 Subject: [PATCH 01/10] Implemented curio backend (#94) --- httpcore/_backends/auto.py | 4 + httpcore/_backends/curio.py | 189 ++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + setup.cfg | 1 + tests/conftest.py | 53 +++++++++- 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 httpcore/_backends/curio.py diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index 41ee82eb..a11a3293 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend: from .trio import TrioBackend self._backend_implementation = TrioBackend() + elif backend == "curio": + from .curio import CurioBackend + + self._backend_implementation = CurioBackend() else: # pragma: nocover raise RuntimeError(f"Unsupported concurrency backend {backend!r}") return self._backend_implementation diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py new file mode 100644 index 00000000..a8e333d3 --- /dev/null +++ b/httpcore/_backends/curio.py @@ -0,0 +1,189 @@ +from ssl import SSLContext +from typing import Optional + +import curio +import curio.io +from curio.network import _wrap_ssl_client + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._types import TimeoutDict +from .._utils import get_logger +from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream + +logger = get_logger("curio_backend") + +one_day_in_seconds = 60 * 60 * 24 + + +def convert_timeout(value: Optional[float]) -> int: + return int(value) if value is not None else one_day_in_seconds + + +class Lock(AsyncLock): + def __init__(self) -> None: + self._lock = curio.Lock() + + async def acquire(self) -> None: + await self._lock.acquire() + + async def release(self) -> None: + await self._lock.release() + + +class Semaphore(AsyncSemaphore): + def __init__(self, max_value: int, exc_class: type) -> None: + self.max_value = max_value + self.exc_class = exc_class + + @property + def semaphore(self) -> curio.Semaphore: + if not hasattr(self, "_semaphore"): + self._semaphore = curio.Semaphore(value=self.max_value) + return self._semaphore + + async def acquire(self, timeout: float = None) -> None: + await self.semaphore.acquire() + + async def release(self) -> None: + await self.semaphore.release() + + +class SocketStream(AsyncSocketStream): + def __init__(self, socket: curio.io.Socket) -> None: + self.read_lock = curio.Lock() + self.write_lock = curio.Lock() + self.socket = socket + + def get_http_version(self) -> str: + if hasattr(self.socket._socket, "_sslobj"): + ident = self.socket._socket._sslobj.selected_alpn_protocol() + else: + ident = "http/1.1" + return "HTTP/2" if ident == "h2" else "HTTP/1.1" + + async def start_tls( + self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict + ) -> "AsyncSocketStream": + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + + with map_exceptions(exc_map): + wrapped_sock = await curio.timeout_after( + connect_timeout, + _wrap_ssl_client( + self.socket, + ssl=ssl_context, + server_hostname=hostname, + alpn_protocols=["h2", "http/1.1"], + ), + ) + + return SocketStream(wrapped_sock) + + async def read(self, n: int, timeout: TimeoutDict) -> bytes: + read_timeout = convert_timeout(timeout.get("read")) + exc_map = { + curio.TaskTimeout: ReadTimeout, + curio.CurioError: ReadError, + OSError: ReadError, + } + + with map_exceptions(exc_map): + async with self.read_lock: + socket_stream = self.socket.as_stream() + + return await curio.timeout_after(read_timeout, socket_stream.read(n)) + + async def write(self, data: bytes, timeout: TimeoutDict) -> None: + write_timeout = convert_timeout(timeout.get("write")) + exc_map = { + curio.TaskTimeout: WriteTimeout, + curio.CurioError: WriteError, + OSError: WriteError, + } + + with map_exceptions(exc_map): + async with self.write_lock: + socket_stream = self.socket.as_stream() + await curio.timeout_after(write_timeout, socket_stream.write(data)) + + async def aclose(self) -> None: + await self.socket.close() + + def is_connection_dropped(self) -> bool: + return self.socket._closed + + +class CurioBackend(AsyncBackend): + async def open_tcp_stream( + self, + hostname: bytes, + port: int, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + *, + local_address: Optional[str], + ) -> AsyncSocketStream: + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + host = hostname.decode("ascii") + kwargs = ( + {} if not ssl_context else {"ssl": ssl_context, "server_hostname": host} + ) + + with map_exceptions(exc_map): + sock: curio.io.Socket = await curio.timeout_after( + connect_timeout, curio.open_connection(hostname, port, **kwargs) + ) + + return SocketStream(sock) + + async def open_uds_stream( + self, + path: str, + hostname: bytes, + ssl_context: Optional[SSLContext], + timeout: TimeoutDict, + ) -> AsyncSocketStream: + connect_timeout = convert_timeout(timeout.get("connect")) + exc_map = { + curio.TaskTimeout: ConnectTimeout, + curio.CurioError: ConnectError, + OSError: ConnectError, + } + host = hostname.decode("ascii") + kwargs = ( + {} if not ssl_context else {"ssl": ssl_context, "server_hostname": host} + ) + + with map_exceptions(exc_map): + sock: curio.io.Socket = await curio.timeout_after( + connect_timeout, curio.open_unix_connection(path, **kwargs) + ) + + return SocketStream(sock) + + def create_lock(self) -> AsyncLock: + return Lock() + + def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: + return Semaphore(max_value, exc_class) + + async def time(self) -> float: + return float(await curio.clock()) diff --git a/requirements.txt b/requirements.txt index b1176f52..7be6603c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ # Optionals trio trio-typing +curio # Docs mkautodoc diff --git a/setup.cfg b/setup.cfg index b3d6a535..fdc2387f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,3 +23,4 @@ skip = httpcore/_sync/,tests/sync_tests/ addopts = --cov-report= --cov=httpcore --cov=tests -rxXs markers = copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. our test setup + curio: mark the test as a coroutine, it will be run using a Curio kernel. diff --git a/tests/conftest.py b/tests/conftest.py index 32f87635..703a26ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,16 @@ import asyncio import contextlib +import functools +import inspect import os import ssl import threading import time import typing +import curio +import curio.debug +import curio.meta import pytest import trustme import uvicorn @@ -18,15 +23,61 @@ PROXY_PORT = 8080 +def _is_coroutine(obj): + """Check to see if an object is really a coroutine.""" + return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj) + + +@pytest.mark.tryfirst +def pytest_pycollect_makeitem(collector, name, obj): + """A pytest hook to collect coroutines in a test module.""" + if collector.funcnamefilter(name) and _is_coroutine(obj): + item = pytest.Function.from_parent(collector, name=name) + if "curio" in item.keywords: + return list(collector._genfunctions(name, obj)) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + """Run curio marked test functions in a Curio kernel instead of a normal function call. + """ + if pyfuncitem.get_closest_marker("curio"): + pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) + yield + + +def wrap_in_sync(func): + """Return a sync wrapper around an async function executing it in a Kernel.""" + + @functools.wraps(func) + def inner(**kwargs): + coro = func(**kwargs) + curio.Kernel().run(coro, shutdown=True) + + return inner + + +# Fixture for explicitly running in Kernel instance. +@pytest.fixture(scope="session") +def kernel(request): + """Provide a Curio Kernel object for running co-routines.""" + k = curio.Kernel(debug=[curio.debug.longblock, curio.debug.logcrash]) + m = curio.monitor.Monitor(k) + request.addfinalizer(lambda: k.run(shutdown=True)) + request.addfinalizer(m.close) + return k + + @pytest.fixture( params=[ pytest.param("asyncio", marks=pytest.mark.asyncio), pytest.param("trio", marks=pytest.mark.trio), + pytest.param("curio", marks=pytest.mark.curio), ] ) def async_environment(request: typing.Any) -> str: """ - Mark a test function to be run on both asyncio and trio. + Mark a test function to be run on asyncio, trio and curio. Equivalent to having a pair of tests, each respectively marked with '@pytest.mark.asyncio' and '@pytest.mark.trio'. From 8738aeec841ac4e7969d20fca0acf5a75a22b1d6 Mon Sep 17 00:00:00 2001 From: cdeler Date: Tue, 25 Aug 2020 13:04:36 +0300 Subject: [PATCH 02/10] Fixing PR remarks (#94) --- httpcore/_backends/curio.py | 13 +++++---- setup.cfg | 1 - tests/conftest.py | 54 +++---------------------------------- tests/marks/__init__.py | 0 tests/marks/curio.py | 53 ++++++++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 58 deletions(-) create mode 100644 tests/marks/__init__.py create mode 100644 tests/marks/curio.py diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index a8e333d3..c959ddbb 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -61,9 +61,10 @@ def __init__(self, socket: curio.io.Socket) -> None: self.read_lock = curio.Lock() self.write_lock = curio.Lock() self.socket = socket + self.stream = socket.as_stream() def get_http_version(self) -> str: - if hasattr(self.socket._socket, "_sslobj"): + if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"): ident = self.socket._socket._sslobj.selected_alpn_protocol() else: ident = "http/1.1" @@ -102,9 +103,7 @@ async def read(self, n: int, timeout: TimeoutDict) -> bytes: with map_exceptions(exc_map): async with self.read_lock: - socket_stream = self.socket.as_stream() - - return await curio.timeout_after(read_timeout, socket_stream.read(n)) + return await curio.timeout_after(read_timeout, self.stream.read(n)) async def write(self, data: bytes, timeout: TimeoutDict) -> None: write_timeout = convert_timeout(timeout.get("write")) @@ -116,11 +115,11 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None: with map_exceptions(exc_map): async with self.write_lock: - socket_stream = self.socket.as_stream() - await curio.timeout_after(write_timeout, socket_stream.write(data)) + await curio.timeout_after(write_timeout, self.stream.write(data)) async def aclose(self) -> None: - await self.socket.close() + # we dont need to close the self.socket, since it's closed by stream closing + await self.stream.close() def is_connection_dropped(self) -> bool: return self.socket._closed diff --git a/setup.cfg b/setup.cfg index fdc2387f..b3d6a535 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,3 @@ skip = httpcore/_sync/,tests/sync_tests/ addopts = --cov-report= --cov=httpcore --cov=tests -rxXs markers = copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. our test setup - curio: mark the test as a coroutine, it will be run using a Curio kernel. diff --git a/tests/conftest.py b/tests/conftest.py index 703a26ab..f21f4a7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,11 @@ import asyncio import contextlib -import functools -import inspect import os import ssl import threading import time import typing -import curio -import curio.debug -import curio.meta import pytest import trustme import uvicorn @@ -19,55 +14,14 @@ from httpcore._types import URL +from .marks.curio import kernel # noqa: F401 +from .marks.curio import pytest_pycollect_makeitem # noqa: F401 +from .marks.curio import pytest_pyfunc_call # noqa: F401 + PROXY_HOST = "127.0.0.1" PROXY_PORT = 8080 -def _is_coroutine(obj): - """Check to see if an object is really a coroutine.""" - return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj) - - -@pytest.mark.tryfirst -def pytest_pycollect_makeitem(collector, name, obj): - """A pytest hook to collect coroutines in a test module.""" - if collector.funcnamefilter(name) and _is_coroutine(obj): - item = pytest.Function.from_parent(collector, name=name) - if "curio" in item.keywords: - return list(collector._genfunctions(name, obj)) - - -@pytest.hookimpl(tryfirst=True, hookwrapper=True) -def pytest_pyfunc_call(pyfuncitem): - """Run curio marked test functions in a Curio kernel instead of a normal function call. - """ - if pyfuncitem.get_closest_marker("curio"): - pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) - yield - - -def wrap_in_sync(func): - """Return a sync wrapper around an async function executing it in a Kernel.""" - - @functools.wraps(func) - def inner(**kwargs): - coro = func(**kwargs) - curio.Kernel().run(coro, shutdown=True) - - return inner - - -# Fixture for explicitly running in Kernel instance. -@pytest.fixture(scope="session") -def kernel(request): - """Provide a Curio Kernel object for running co-routines.""" - k = curio.Kernel(debug=[curio.debug.longblock, curio.debug.logcrash]) - m = curio.monitor.Monitor(k) - request.addfinalizer(lambda: k.run(shutdown=True)) - request.addfinalizer(m.close) - return k - - @pytest.fixture( params=[ pytest.param("asyncio", marks=pytest.mark.asyncio), diff --git a/tests/marks/__init__.py b/tests/marks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/marks/curio.py b/tests/marks/curio.py new file mode 100644 index 00000000..33befac7 --- /dev/null +++ b/tests/marks/curio.py @@ -0,0 +1,53 @@ +import functools +import inspect + +import curio +import curio.debug +import curio.meta +import curio.monitor +import pytest + + +def _is_coroutine(obj): + """Check to see if an object is really a coroutine.""" + return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj) + + +@pytest.mark.tryfirst +def pytest_pycollect_makeitem(collector, name, obj): + """A pytest hook to collect coroutines in a test module.""" + if collector.funcnamefilter(name) and _is_coroutine(obj): + item = pytest.Function.from_parent(collector, name=name) + if "curio" in item.keywords: + return list(collector._genfunctions(name, obj)) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + """Run curio marked test functions in a Curio kernel instead of a normal function call. + """ + if pyfuncitem.get_closest_marker("curio"): + pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) + yield + + +def wrap_in_sync(func): + """Return a sync wrapper around an async function executing it in a Kernel.""" + + @functools.wraps(func) + def inner(**kwargs): + coro = func(**kwargs) + curio.Kernel().run(coro, shutdown=True) + + return inner + + +# Fixture for explicitly running in Kernel instance. +@pytest.fixture(scope="session") +def kernel(request): + """Provide a Curio Kernel object for running co-routines.""" + k = curio.Kernel(debug=[curio.debug.longblock, curio.debug.logcrash]) + m = curio.monitor.Monitor(k) + request.addfinalizer(lambda: k.run(shutdown=True)) + request.addfinalizer(m.close) + return k From bcfb95084fa3e652046f9e8484087445e43592fc Mon Sep 17 00:00:00 2001 From: cdeler Date: Tue, 25 Aug 2020 15:59:22 +0300 Subject: [PATCH 03/10] Fixing PR remarks. Mention curio in the same context when the pair of asyncio and trio are mentioned (#94) --- setup.py | 1 + tests/conftest.py | 4 ++-- unasync.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index f91c4d95..b0b4dfd8 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,7 @@ def get_packages(package): "Topic :: Internet :: WWW/HTTP", "Framework :: AsyncIO", "Framework :: Trio", + "Framework :: Curio", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", diff --git a/tests/conftest.py b/tests/conftest.py index f21f4a7a..84837b76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,8 +33,8 @@ def async_environment(request: typing.Any) -> str: """ Mark a test function to be run on asyncio, trio and curio. - Equivalent to having a pair of tests, each respectively marked with - '@pytest.mark.asyncio' and '@pytest.mark.trio'. + Equivalent to having three tests, each respectively marked with + '@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'. Intended usage: diff --git a/unasync.py b/unasync.py index 84d4d367..b2ad647a 100755 --- a/unasync.py +++ b/unasync.py @@ -20,6 +20,7 @@ ('__aiter__', '__iter__'), ('@pytest.mark.asyncio', ''), ('@pytest.mark.trio', ''), + ('@pytest.mark.curio', ''), ('@pytest.mark.usefixtures.*', ''), ] COMPILED_SUBS = [ From 46bdc252c357c429cc8e7f1df0f2d5c9ffc49d7f Mon Sep 17 00:00:00 2001 From: cdeler Date: Tue, 25 Aug 2020 16:23:34 +0300 Subject: [PATCH 04/10] Fixing PR remarks. Made pytest.mark.curio completely isolated from conftest.py (for now it's easier to add new backend with custom marks) (#94) --- tests/conftest.py | 15 ++++++++++++--- tests/marks/curio.py | 6 +++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 84837b76..d4a7b727 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,14 +14,23 @@ from httpcore._types import URL -from .marks.curio import kernel # noqa: F401 -from .marks.curio import pytest_pycollect_makeitem # noqa: F401 -from .marks.curio import pytest_pyfunc_call # noqa: F401 +from .marks.curio import curio_kernel_fixture # noqa: F401 +from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call PROXY_HOST = "127.0.0.1" PROXY_PORT = 8080 +@pytest.mark.tryfirst +def pytest_pycollect_makeitem(collector, name, obj): + curio_pytest_pycollect_makeitem(collector, name, obj) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + yield from curio_pytest_pyfunc_call(pyfuncitem) + + @pytest.fixture( params=[ pytest.param("asyncio", marks=pytest.mark.asyncio), diff --git a/tests/marks/curio.py b/tests/marks/curio.py index 33befac7..9fba6b5e 100644 --- a/tests/marks/curio.py +++ b/tests/marks/curio.py @@ -14,7 +14,7 @@ def _is_coroutine(obj): @pytest.mark.tryfirst -def pytest_pycollect_makeitem(collector, name, obj): +def curio_pytest_pycollect_makeitem(collector, name, obj): """A pytest hook to collect coroutines in a test module.""" if collector.funcnamefilter(name) and _is_coroutine(obj): item = pytest.Function.from_parent(collector, name=name) @@ -23,7 +23,7 @@ def pytest_pycollect_makeitem(collector, name, obj): @pytest.hookimpl(tryfirst=True, hookwrapper=True) -def pytest_pyfunc_call(pyfuncitem): +def curio_pytest_pyfunc_call(pyfuncitem): """Run curio marked test functions in a Curio kernel instead of a normal function call. """ if pyfuncitem.get_closest_marker("curio"): @@ -44,7 +44,7 @@ def inner(**kwargs): # Fixture for explicitly running in Kernel instance. @pytest.fixture(scope="session") -def kernel(request): +def curio_kernel_fixture(request): """Provide a Curio Kernel object for running co-routines.""" k = curio.Kernel(debug=[curio.debug.longblock, curio.debug.logcrash]) m = curio.monitor.Monitor(k) From 5604bf860059b32cd77790a5c844592b7378e8f0 Mon Sep 17 00:00:00 2001 From: cdeler Date: Wed, 26 Aug 2020 12:10:28 +0300 Subject: [PATCH 05/10] PR review. Updated tests/marks/curio.py (removed unnecessary fixture) Co-authored-by: Florimond Manca --- tests/conftest.py | 1 - tests/marks/curio.py | 15 ++------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d4a7b727..d9869268 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,6 @@ from httpcore._types import URL -from .marks.curio import curio_kernel_fixture # noqa: F401 from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call PROXY_HOST = "127.0.0.1" diff --git a/tests/marks/curio.py b/tests/marks/curio.py index 9fba6b5e..7504ee8f 100644 --- a/tests/marks/curio.py +++ b/tests/marks/curio.py @@ -24,8 +24,8 @@ def curio_pytest_pycollect_makeitem(collector, name, obj): @pytest.hookimpl(tryfirst=True, hookwrapper=True) def curio_pytest_pyfunc_call(pyfuncitem): - """Run curio marked test functions in a Curio kernel instead of a normal function call. - """ + """Run curio marked test functions in a Curio kernel + instead of a normal function call.""" if pyfuncitem.get_closest_marker("curio"): pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) yield @@ -40,14 +40,3 @@ def inner(**kwargs): curio.Kernel().run(coro, shutdown=True) return inner - - -# Fixture for explicitly running in Kernel instance. -@pytest.fixture(scope="session") -def curio_kernel_fixture(request): - """Provide a Curio Kernel object for running co-routines.""" - k = curio.Kernel(debug=[curio.debug.longblock, curio.debug.logcrash]) - m = curio.monitor.Monitor(k) - request.addfinalizer(lambda: k.run(shutdown=True)) - request.addfinalizer(m.close) - return k From 659bd94e51392d0015e963c2d6596a84422f3161 Mon Sep 17 00:00:00 2001 From: cdeler Date: Wed, 26 Aug 2020 14:06:40 +0300 Subject: [PATCH 06/10] Added "curio" test mark programmatically (#94) --- tests/conftest.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index d9869268..0cef05a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,14 @@ PROXY_PORT = 8080 +def pytest_configure(config): + # register an additional marker + config.addinivalue_line( + "markers", + "curio: mark the test as a coroutine, it will be run using a Curio kernel.", + ) + + @pytest.mark.tryfirst def pytest_pycollect_makeitem(collector, name, obj): curio_pytest_pycollect_makeitem(collector, name, obj) From dc306d00e8235616ac0117179328937a79806a4b Mon Sep 17 00:00:00 2001 From: cdeler Date: Mon, 31 Aug 2020 19:55:55 +0300 Subject: [PATCH 07/10] Fixed PR remarks (#94) Added timeout handling to Semaphore::acquire and tried to avoid private API usage in SocketStream::get_http_version, also changed is_connection_dropped behaviour --- httpcore/_backends/curio.py | 33 ++++++++++++++++++++++++--------- tests/marks/curio.py | 2 +- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index c959ddbb..ef9a11c1 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -1,5 +1,7 @@ -from ssl import SSLContext -from typing import Optional +import select +import socket +from ssl import SSLContext, SSLSocket +from typing import Dict, Optional, Type, Union import curio import curio.io @@ -10,6 +12,7 @@ ConnectTimeout, ReadError, ReadTimeout, + TimeoutException, WriteError, WriteTimeout, map_exceptions, @@ -50,7 +53,13 @@ def semaphore(self) -> curio.Semaphore: return self._semaphore async def acquire(self, timeout: float = None) -> None: - await self.semaphore.acquire() + exc_map: Dict[Type[Exception], Type[Exception]] = { + curio.TaskTimeout: TimeoutException, + } + acquire_timeout: int = convert_timeout(timeout) + + with map_exceptions(exc_map): + return await curio.timeout_after(acquire_timeout, self.semaphore.acquire()) async def release(self) -> None: await self.semaphore.release() @@ -64,10 +73,14 @@ def __init__(self, socket: curio.io.Socket) -> None: self.stream = socket.as_stream() def get_http_version(self) -> str: - if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"): - ident = self.socket._socket._sslobj.selected_alpn_protocol() - else: - ident = "http/1.1" + ident: Optional[str] = "http/1.1" + + if hasattr(self.socket, "_socket"): + raw_socket: Union[SSLSocket, socket.socket] = self.socket._socket + + if isinstance(raw_socket, SSLSocket): + ident = raw_socket.selected_alpn_protocol() + return "HTTP/2" if ident == "h2" else "HTTP/1.1" async def start_tls( @@ -118,11 +131,13 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None: await curio.timeout_after(write_timeout, self.stream.write(data)) async def aclose(self) -> None: - # we dont need to close the self.socket, since it's closed by stream closing await self.stream.close() + await self.socket.close() def is_connection_dropped(self) -> bool: - return self.socket._closed + rready, _, _ = select.select([self.socket.fileno()], [], [], 0) + + return bool(rready) class CurioBackend(AsyncBackend): diff --git a/tests/marks/curio.py b/tests/marks/curio.py index 7504ee8f..616b2766 100644 --- a/tests/marks/curio.py +++ b/tests/marks/curio.py @@ -19,7 +19,7 @@ def curio_pytest_pycollect_makeitem(collector, name, obj): if collector.funcnamefilter(name) and _is_coroutine(obj): item = pytest.Function.from_parent(collector, name=name) if "curio" in item.keywords: - return list(collector._genfunctions(name, obj)) + return list(collector._genfunctions(name, obj)) # pragma: nocover @pytest.hookimpl(tryfirst=True, hookwrapper=True) From 1f009dac83d4a935eeb3a10ff4e9c38e5d0827e2 Mon Sep 17 00:00:00 2001 From: cdeler Date: Tue, 1 Sep 2020 19:03:03 +0300 Subject: [PATCH 08/10] Fixed PR remarks (#94) Rewrote _wrap_ssl_client using ssl.SSLContext::wrap_socket --- httpcore/_backends/curio.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index ef9a11c1..1aa91b24 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -5,7 +5,6 @@ import curio import curio.io -from curio.network import _wrap_ssl_client from .._exceptions import ( ConnectError, @@ -26,6 +25,22 @@ one_day_in_seconds = 60 * 60 * 24 +async def wrap_ssl_client( + sock: curio.io.Socket, + ssl_context: SSLContext, + server_hostname: bytes, +) -> curio.io.Socket: + kwargs = { + "server_hostname": server_hostname, + "do_handshake_on_connect": sock._socket.gettimeout() != 0.0, + } + + socket = curio.io.Socket(ssl_context.wrap_socket(sock._socket, **kwargs)) + await socket.do_handshake() + + return socket + + def convert_timeout(value: Optional[float]) -> int: return int(value) if value is not None else one_day_in_seconds @@ -96,12 +111,7 @@ async def start_tls( with map_exceptions(exc_map): wrapped_sock = await curio.timeout_after( connect_timeout, - _wrap_ssl_client( - self.socket, - ssl=ssl_context, - server_hostname=hostname, - alpn_protocols=["h2", "http/1.1"], - ), + wrap_ssl_client(self.socket, ssl_context, hostname), ) return SocketStream(wrapped_sock) @@ -163,7 +173,8 @@ async def open_tcp_stream( with map_exceptions(exc_map): sock: curio.io.Socket = await curio.timeout_after( - connect_timeout, curio.open_connection(hostname, port, **kwargs) + connect_timeout, + curio.open_connection(hostname, port, **kwargs), ) return SocketStream(sock) From 0895314a57269c12db1a4a40ecdacffccd450e1b Mon Sep 17 00:00:00 2001 From: cdeler Date: Wed, 2 Sep 2020 12:49:50 +0300 Subject: [PATCH 09/10] PR review (#94) Co-authored-by: Florimond Manca --- httpcore/_backends/curio.py | 60 +++++++++++++++---------------------- tests/conftest.py | 1 - 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index 1aa91b24..fcf2f763 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -1,7 +1,6 @@ import select -import socket from ssl import SSLContext, SSLSocket -from typing import Dict, Optional, Type, Union +from typing import Optional import curio import curio.io @@ -11,7 +10,6 @@ ConnectTimeout, ReadError, ReadTimeout, - TimeoutException, WriteError, WriteTimeout, map_exceptions, @@ -22,27 +20,11 @@ logger = get_logger("curio_backend") -one_day_in_seconds = 60 * 60 * 24 +ONE_DAY_IN_SECONDS = float(60 * 60 * 24) -async def wrap_ssl_client( - sock: curio.io.Socket, - ssl_context: SSLContext, - server_hostname: bytes, -) -> curio.io.Socket: - kwargs = { - "server_hostname": server_hostname, - "do_handshake_on_connect": sock._socket.gettimeout() != 0.0, - } - - socket = curio.io.Socket(ssl_context.wrap_socket(sock._socket, **kwargs)) - await socket.do_handshake() - - return socket - - -def convert_timeout(value: Optional[float]) -> int: - return int(value) if value is not None else one_day_in_seconds +def convert_timeout(value: Optional[float]) -> float: + return value if value is not None else ONE_DAY_IN_SECONDS class Lock(AsyncLock): @@ -68,13 +50,12 @@ def semaphore(self) -> curio.Semaphore: return self._semaphore async def acquire(self, timeout: float = None) -> None: - exc_map: Dict[Type[Exception], Type[Exception]] = { - curio.TaskTimeout: TimeoutException, - } - acquire_timeout: int = convert_timeout(timeout) + timeout = convert_timeout(timeout) - with map_exceptions(exc_map): - return await curio.timeout_after(acquire_timeout, self.semaphore.acquire()) + try: + return await curio.timeout_after(timeout, self.semaphore.acquire()) + except curio.TaskTimeout: + raise self.exc_class() async def release(self) -> None: await self.semaphore.release() @@ -88,15 +69,14 @@ def __init__(self, socket: curio.io.Socket) -> None: self.stream = socket.as_stream() def get_http_version(self) -> str: - ident: Optional[str] = "http/1.1" - if hasattr(self.socket, "_socket"): - raw_socket: Union[SSLSocket, socket.socket] = self.socket._socket + raw_socket = self.socket._socket if isinstance(raw_socket, SSLSocket): ident = raw_socket.selected_alpn_protocol() + return "HTTP/2" if ident == "h2" else "HTTP/1.1" - return "HTTP/2" if ident == "h2" else "HTTP/1.1" + return "HTTP/1.1" async def start_tls( self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict @@ -109,9 +89,17 @@ async def start_tls( } with map_exceptions(exc_map): - wrapped_sock = await curio.timeout_after( + wrapped_sock = curio.io.Socket( + ssl_context.wrap_socket( + self.socket._socket, + do_handshake_on_connect=False, + server_hostname=hostname.decode("ascii"), + ) + ) + + await curio.timeout_after( connect_timeout, - wrap_ssl_client(self.socket, ssl_context, hostname), + wrapped_sock.do_handshake(), ) return SocketStream(wrapped_sock) @@ -168,7 +156,7 @@ async def open_tcp_stream( } host = hostname.decode("ascii") kwargs = ( - {} if not ssl_context else {"ssl": ssl_context, "server_hostname": host} + {} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host} ) with map_exceptions(exc_map): @@ -194,7 +182,7 @@ async def open_uds_stream( } host = hostname.decode("ascii") kwargs = ( - {} if not ssl_context else {"ssl": ssl_context, "server_hostname": host} + {} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host} ) with map_exceptions(exc_map): diff --git a/tests/conftest.py b/tests/conftest.py index 0cef05a4..35aaf9ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,6 @@ def pytest_configure(config): - # register an additional marker config.addinivalue_line( "markers", "curio: mark the test as a coroutine, it will be run using a Curio kernel.", From 38c24c178848f16aa241a367f38b9744f079df3f Mon Sep 17 00:00:00 2001 From: cdeler Date: Sat, 5 Sep 2020 09:47:14 +0300 Subject: [PATCH 10/10] PR review (#94) Co-authored-by: Florimond Manca --- httpcore/_backends/curio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index fcf2f763..8cae3be8 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -18,7 +18,7 @@ from .._utils import get_logger from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream -logger = get_logger("curio_backend") +logger = get_logger(__name__) ONE_DAY_IN_SECONDS = float(60 * 60 * 24) @@ -199,4 +199,4 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: return Semaphore(max_value, exc_class) async def time(self) -> float: - return float(await curio.clock()) + return await curio.clock()