Skip to content

Commit

Permalink
Merge pull request #214 from njsmith/fix-ssl-test-threading
Browse files Browse the repository at this point in the history
The echo server thread in test_ssl had unnoticed errors
  • Loading branch information
njsmith authored Jun 15, 2017
2 parents a7c3e5f + e976598 commit 190ebd4
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 37 deletions.
9 changes: 6 additions & 3 deletions trio/_core/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest
import attr

from .tutil import check_sequence_matches
from .tutil import check_sequence_matches, gc_collect_harder
from ...testing import (
wait_all_tasks_blocked, Sequencer, assert_yields,
)
Expand Down Expand Up @@ -41,7 +41,8 @@ def ignore_coroutine_never_awaited_warnings():
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc.collect()
gc_collect_harder()


def test_basic():
async def trivial(x):
Expand Down Expand Up @@ -893,7 +894,7 @@ async def main():
# Because this crashes, various __del__ methods print complaints on
# stderr. Make sure that they get run now, so the output is attached to
# this test.
gc.collect()
gc_collect_harder()


def test_error_in_run_loop():
Expand Down Expand Up @@ -1500,6 +1501,8 @@ async def f(): # pragma: no cover
bad_call(len, [1, 2, 3])
assert "appears to be synchronous" in str(excinfo.value)

# Make sure no references are kept around to keep anything alive
del excinfo

def test_calling_asyncio_function_gives_nice_error():
async def misguided():
Expand Down
15 changes: 14 additions & 1 deletion trio/_core/tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@

import pytest

import gc

# See trio/tests/conftest.py for the other half of this
slow = pytest.mark.skipif(
not pytest.config.getoption("--run-slow", True),
reason="use --run-slow to run slow tests",
)

from ... import _core
def gc_collect_harder():
# In the test suite we sometimes want to call gc.collect() to make sure
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
# get collected before we continue, so their noise doesn't leak into
# unrelated tests.
#
# On PyPy, coroutine objects (for example) can survive at least 1 round of
# garbage collection, because executing their __del__ method to print the
# warning can cause them to be resurrected. So we call collect a few times
# to make sure.
for _ in range(4):
gc.collect()

# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 3] or [1, 2.2, 3]
Expand Down
90 changes: 57 additions & 33 deletions trio/tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
import socket as stdlib_socket
import ssl as stdlib_ssl
from contextlib import contextmanager
from functools import partial

from OpenSSL import SSL
from async_generator import async_generator, yield_

import trio
from .. import _core
from .. import _network
from .._streams import BrokenStreamError, ClosedStreamError
from .. import ssl as tssl
from .. import socket as tsocket
from .._util import UnLock
from .._util import UnLock, acontextmanager

from .._core.tests.tutil import slow

Expand Down Expand Up @@ -55,16 +57,37 @@

CLIENT_CTX = stdlib_ssl.create_default_context(cafile=CA)

# workaround for
# https://bitbucket.org/pypy/pypy/issues/2578/
# (fortunately only affects our test suite, not the actual ssl.py)
# bug is in 5.8.0-beta and at least some of the 5.9.0-alpha nightlies, but
# will hopefully be fixed soon
import sys
WORKAROUND_PYPY_BUG = False
if (hasattr(sys, "pypy_version_info")
and (sys.pypy_version_info < (5, 9)
or sys.pypy_version_info[:4] == (5, 9, 0, "alpha"))):
WORKAROUND_PYPY_BUG = True

# The blocking socket server.
def ssl_echo_serve_sync(sock, *, expect_fail=False):
try:
wrapped = SERVER_CTX.wrap_socket(sock, server_side=True)
wrapped = SERVER_CTX.wrap_socket(
sock, server_side=True, suppress_ragged_eofs=False)
wrapped.do_handshake()
while True:
data = wrapped.recv(4096)
if not data:
# graceful shutdown
wrapped.unwrap()
# other side has initiated a graceful shutdown; we try to
# respond in kind but it's legal for them to have already gone
# away.
exceptions = (BrokenPipeError,)
if WORKAROUND_PYPY_BUG:
exceptions += (stdlib_ssl.SSLEOFError,)
try:
wrapped.unwrap()
except exceptions:
pass
return
wrapped.sendall(data)
except Exception as exc:
Expand All @@ -74,37 +97,38 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False):
raise
else:
if expect_fail: # pragma: no cover
print("failed to fail?!")
raise RuntimeError("failed to fail?")


# Fixture that gives a raw socket connected to a trio-test-1 echo server
# (running in a thread). Useful for testing making connections with different
# SSLContexts.
@contextmanager
def ssl_echo_server_raw(**kwargs):
#
# This way of writing it is pretty janky, with the nursery hidden inside the
# fixture and no proper parental supervision. Don't copy this code; it was
# written this way before we knew better.
@acontextmanager
@async_generator
async def ssl_echo_server_raw(**kwargs):
a, b = stdlib_socket.socketpair()
with a, b:
t = threading.Thread(
target=ssl_echo_serve_sync,
args=(b,),
kwargs=kwargs,
)
t.start()

yield _network.SocketStream(tsocket.from_stdlib_socket(a))

# exiting the context manager closes the sockets, which should force the
# thread to shut down (possibly with an error)
t.join()
async with trio.open_nursery() as nursery:
with a, b:
nursery.spawn(
trio.run_in_worker_thread,
partial(ssl_echo_serve_sync, b, **kwargs))

await yield_(_network.SocketStream(tsocket.from_stdlib_socket(a)))
# exiting the 'with a, b' context manager closes the sockets, which
# should force the thread to shut down (possibly with an error)

# Fixture that gives a properly set up SSLStream connected to a trio-test-1
# echo server (running in a thread)
@contextmanager
def ssl_echo_server(**kwargs):
with ssl_echo_server_raw(**kwargs) as sock:
yield tssl.SSLStream(
sock, CLIENT_CTX, server_hostname="trio-test-1.example.org")
@acontextmanager
@async_generator
async def ssl_echo_server(**kwargs):
async with ssl_echo_server_raw(**kwargs) as sock:
await yield_(tssl.SSLStream(
sock, CLIENT_CTX, server_hostname="trio-test-1.example.org"))


# The weird in-memory server ... thing.
Expand Down Expand Up @@ -326,14 +350,14 @@ def test_exports():
# certificate checking (even though this is really Python's responsibility)
async def test_ssl_client_basics():
# Everything OK
with ssl_echo_server() as s:
async with ssl_echo_server() as s:
assert not s.server_side
await s.send_all(b"x")
assert await s.receive_some(1) == b"x"
await s.graceful_close()

# Didn't configure the CA file, should fail
with ssl_echo_server_raw(expect_fail=True) as sock:
async with ssl_echo_server_raw(expect_fail=True) as sock:
client_ctx = stdlib_ssl.create_default_context()
s = tssl.SSLStream(
sock, client_ctx, server_hostname="trio-test-1.example.org")
Expand All @@ -343,7 +367,7 @@ async def test_ssl_client_basics():
assert isinstance(excinfo.value.__cause__, tssl.SSLError)

# Trusted CA, but wrong host name
with ssl_echo_server_raw(expect_fail=True) as sock:
async with ssl_echo_server_raw(expect_fail=True) as sock:
s = tssl.SSLStream(
sock, CLIENT_CTX, server_hostname="trio-test-2.example.org")
assert not s.server_side
Expand Down Expand Up @@ -380,7 +404,7 @@ def client():


async def test_attributes():
with ssl_echo_server_raw(expect_fail=True) as sock:
async with ssl_echo_server_raw(expect_fail=True) as sock:
good_ctx = CLIENT_CTX
bad_ctx = stdlib_ssl.create_default_context()
s = tssl.SSLStream(
Expand Down Expand Up @@ -470,7 +494,7 @@ async def receiver(s):
chunk = await s.receive_some(CHUNK_SIZE // 2)
received += chunk

with ssl_echo_server() as s:
async with ssl_echo_server() as s:
async with _core.open_nursery() as nursery:
nursery.spawn(sender, s)
nursery.spawn(receiver, s)
Expand Down Expand Up @@ -672,7 +696,7 @@ async def wait_send_all_might_not_block(self):


async def test_checkpoints():
with ssl_echo_server() as s:
async with ssl_echo_server() as s:
with assert_yields():
await s.do_handshake()
with assert_yields():
Expand All @@ -694,14 +718,14 @@ async def test_checkpoints():
with assert_yields():
await s.unwrap()

with ssl_echo_server() as s:
async with ssl_echo_server() as s:
await s.do_handshake()
with assert_yields():
await s.graceful_close()


async def test_send_all_empty_string():
with ssl_echo_server() as s:
async with ssl_echo_server() as s:
await s.do_handshake()

# underlying SSLObject interprets writing b"" as indicating an EOF,
Expand Down

0 comments on commit 190ebd4

Please sign in to comment.