From 72ff6c575025b2196458e90d69aa8681baf5ec33 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 22 Mar 2021 09:51:12 +0000 Subject: [PATCH] Fix race when using HTTP proxy. (#181) After `start_tls` we need to manually call `connection_made` on the protocol to tell it about the new underlying transport. This gives a brief window where the protocol can receive data without having a transport to write to, causing issues with the APNS connections where it assumes it can write once it starts reading data. We fix this by wrapping the protocol in a buffer that simply buffers incoming data until `connection_made` is called. Co-authored-by: Patrick Cloke --- changelog.d/181.bugfix | 1 + sygnal/helper/proxy/proxy_asyncio.py | 50 ++++++++- tests/asyncio_test_helpers.py | 38 +++++++ tests/test_httpproxy_asyncio.py | 150 +++++++++++++++++++++++++++ 4 files changed, 237 insertions(+), 2 deletions(-) create mode 100644 changelog.d/181.bugfix diff --git a/changelog.d/181.bugfix b/changelog.d/181.bugfix new file mode 100644 index 00000000..d1595022 --- /dev/null +++ b/changelog.d/181.bugfix @@ -0,0 +1 @@ +Fix bug when using a HTTP proxy where connections would sometimes fail to establish. diff --git a/sygnal/helper/proxy/proxy_asyncio.py b/sygnal/helper/proxy/proxy_asyncio.py index 73f43377..39f3049b 100644 --- a/sygnal/helper/proxy/proxy_asyncio.py +++ b/sygnal/helper/proxy/proxy_asyncio.py @@ -22,6 +22,8 @@ from ssl import Purpose, SSLContext, create_default_context from typing import Callable, Optional, Tuple, Union +import attr + from sygnal.exceptions import ProxyConnectError from sygnal.helper.proxy import decompose_http_proxy_url @@ -146,19 +148,25 @@ async def switch_over_when_ready(self) -> Tuple[BaseTransport, Protocol]: # unreachable raise RuntimeError("Left over bytes should not occur with TLS") + # There is a race where the `new_protocol` may get given data before + # we manage to call `connection_made` on it, which can lead to + # exceptions if the protocol then tries to write to the transport + # that is has been given yet. + buffered_protocol = _BufferedWrapperProtocol(new_protocol) + # be careful not to use the `transport` ever again after passing it # to start_tls — we overwrite our variable with the TLS-wrapped # transport to avoid that! transport = await self._event_loop.start_tls( self._transport, - new_protocol, + buffered_protocol, self._sslcontext, server_hostname=self._target_hostport[0], ) # start_tls does NOT call connection_made on new_protocol, so we # must do it ourselves - new_protocol.connection_made(transport) + buffered_protocol.connection_made(transport) else: # no wrapping required for non-TLS transport = self._transport @@ -171,6 +179,8 @@ async def switch_over_when_ready(self) -> Tuple[BaseTransport, Protocol]: # pass over dangling bytes if applicable new_protocol.data_received(left_over_bytes) + logger.debug("Finished switching protocol") + return transport, new_protocol def data_received(self, data: bytes) -> None: @@ -332,3 +342,39 @@ def __getattr__(self, item): We use this to delegate other method calls to the real EventLoop. """ return getattr(self._wrapped_loop, item) + + +@attr.s(slots=True, auto_attribs=True) +class _BufferedWrapperProtocol(Protocol): + """Wraps a protocol to buffer any incoming data received before + `connection_made` is called. + """ + + _protocol: Protocol + _connected: bool = False + _buffer: bytearray = attr.Factory(bytearray) + + def connection_made(self, transport: BaseTransport): + self._connected = True + self._protocol.connection_made(transport) + if self._buffer: + self._protocol.data_received(self._buffer) + self._buffer = bytearray() + + def connection_lost(self, exc: Optional[Exception]): + self._protocol.connection_lost(exc) + + def pause_writing(self): + self._protocol.pause_writing() + + def resume_writing(self): + self._protocol.resume_writing() + + def data_received(self, data: bytes): + if self._connected: + self._protocol.data_received(data) + else: + self._buffer.extend(data) + + def eof_received(self): + return self._protocol.eof_received() diff --git a/tests/asyncio_test_helpers.py b/tests/asyncio_test_helpers.py index 9096f34d..20a7fdf2 100644 --- a/tests/asyncio_test_helpers.py +++ b/tests/asyncio_test_helpers.py @@ -70,6 +70,14 @@ def call_later( ): self.call_at(self._time + delay, callback, *args, context=context) + # We're meant to return a canceller, but can cheat and return a no-op one + # instead. + class _Canceller: + def cancel(self): + pass + + return _Canceller() + def call_at( self, when: float, @@ -114,6 +122,10 @@ def __init__(self): # Whether this transport was closed self.closed = False + # We need to explicitly mark that this connection allows start tls, + # otherwise `loop.start_tls` will raise an exception. + self._start_tls_compatible = True + def reset_mock(self) -> None: self.buffer = b"" self.eofed = False @@ -189,3 +201,29 @@ def write(self, data: bytes) -> None: self.transport.write(data) else: self._to_transmit += data + + +class EchoProtocol(Protocol): + """A protocol that immediately echoes all data it receives""" + + def __init__(self): + self._to_transmit = b"" + self.received_bytes = b"" + self.transport = None + + def data_received(self, data: bytes) -> None: + self.received_bytes += data + assert self.transport + self.transport.write(data) + + def connection_made(self, transport: transports.BaseTransport) -> None: + assert isinstance(transport, Transport) + self.transport = transport + if self._to_transmit: + transport.write(self._to_transmit) + + def write(self, data: bytes) -> None: + if self.transport: + self.transport.write(data) + else: + self._to_transmit += data diff --git a/tests/test_httpproxy_asyncio.py b/tests/test_httpproxy_asyncio.py index e88189ae..d173ddfb 100644 --- a/tests/test_httpproxy_asyncio.py +++ b/tests/test_httpproxy_asyncio.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import ssl from asyncio import AbstractEventLoop, BaseTransport, Protocol, Task from typing import Optional, Tuple, cast @@ -21,10 +22,16 @@ from tests import testutils from tests.asyncio_test_helpers import ( + EchoProtocol, MockProtocol, MockTransport, TimelessEventLoopWrapper, ) +from tests.twisted_test_helpers import ( + create_test_cert_file, + get_test_ca_cert_file, + get_test_key_file, +) class AsyncioHttpProxyTest(testutils.TestCase): @@ -191,3 +198,146 @@ def test_connect_failure(self): # check our protocol did not receive anything, because it was an HTTP- # level error, not actually a connection to our target. self.assertEqual(fake_protocol.received_bytes, b"") + + +class AsyncioHttpProxyTLSTest(testutils.TestCase): + """Test that using a HTTPS proxy works. + + This is a bit convoluted to try and test that we don't hit a race where the + new client protocol can receive data before `connection_made` is called, + which can cause problems if it tries to write to the connection that it + hasn't been given yet. + """ + + def config_setup(self, config): + super().config_setup(config) + config["apps"]["com.example.spqr"] = { + "type": "tests.test_pushgateway_api_v1.TestPushkin" + } + self.base_loop = asyncio.new_event_loop() + augmented_loop = TimelessEventLoopWrapper(self.base_loop) # type: ignore + asyncio.set_event_loop(cast(AbstractEventLoop, augmented_loop)) + + self.loop = augmented_loop + + self.proxy_context = ssl.create_default_context() + self.proxy_context.load_verify_locations(get_test_ca_cert_file()) + self.proxy_context.set_ciphers("DEFAULT") + + def make_fake_proxy( + self, host: str, port: int, proxy_credentials: Optional[Tuple[str, str]], + ) -> Tuple[EchoProtocol, MockTransport, "Task[Tuple[BaseTransport, Protocol]]"]: + # Task[Tuple[MockTransport, MockProtocol]] + + # make a fake proxy + fake_proxy = MockTransport() + + # We connect with an echo protocol to test that we can always write when + # we receive data. + fake_protocol = EchoProtocol() + + # create a HTTP CONNECT proxy client protocol + http_connect_protocol = HttpConnectProtocol( + target_hostport=(host, port), + proxy_credentials=proxy_credentials, + protocol_factory=lambda: fake_protocol, + sslcontext=self.proxy_context, + loop=None, + ) + switch_over_task = self.loop.create_task( + http_connect_protocol.switch_over_when_ready() + ) + # check the task is not somehow already marked as done before we even + # receive anything. + self.assertFalse(switch_over_task.done()) + # connect the proxy client to the proxy + fake_proxy.set_protocol(http_connect_protocol) + http_connect_protocol.connection_made(fake_proxy) + return fake_protocol, fake_proxy, switch_over_task + + def test_connect_no_credentials(self): + """ + Tests the proxy connection procedure when there is no basic auth. + """ + host = "example.org" + port = 443 + proxy_credentials = None + fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy( + host, port, proxy_credentials + ) + + # Check that the proxy got the proper CONNECT request. + self.assertEqual(fake_proxy.buffer, b"CONNECT example.org:443 HTTP/1.0\r\n\r\n") + # Reset the proxy mock + fake_proxy.reset_mock() + + # pretend we got a happy response + fake_proxy.pretend_to_receive(b"HTTP/1.0 200 Connection Established\r\n\r\n") + + # Since we're talking TLS we need to create a server TLS connection that + # we can use to talk to each other. + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain( + create_test_cert_file([b"DNS:example.org"]), keyfile=get_test_key_file() + ) + context.set_ciphers("DEFAULT") + + # Note that we have to use a different event loop wrapper here as we + # want that server side setup to finish before the client side setup, so + # that we can trigger any races. + server_loop = TimelessEventLoopWrapper(self.base_loop) # type: ignore + server_transport = MockTransport() + proxy_ft = server_loop.create_task( + server_loop.start_tls( + server_transport, + MockProtocol(), + context, + server_hostname=host, + server_side=True, + ) + ) + + # Advance event loop because we have to let coroutines be executed + self.loop.advance(1.0) + server_loop.advance(1.0) + + # We manually copy the bytes between the fake_proxy transport and our + # created TLS transport. We do this for each step in the TLS handshake. + + # Client -> Server + server_transport.pretend_to_receive(fake_proxy.buffer) + fake_proxy.buffer = b"" + + # Server -> Client + fake_proxy.pretend_to_receive(server_transport.buffer) + server_transport.buffer = b"" + + # Client -> Server + server_transport.pretend_to_receive(fake_proxy.buffer) + fake_proxy.buffer = b"" + + # We *only* advance the server side loop so that we can send data before + # the client has called `connection_made` on the new protocol. + server_loop.advance(0.1) + + # Server -> Client application data. + server_plain_transport = proxy_ft.result() + server_plain_transport.write(b"begin beep boop\r\n\r\n~~ :) ~~") + fake_proxy.pretend_to_receive(server_transport.buffer) + server_transport.buffer = b"" + + self.loop.advance(1.0) + + # *now* we should have switched over from the HTTP CONNECT protocol + # to the user protocol (in our case, a MockProtocol). + self.assertTrue(switch_over_task.done()) + + transport, protocol = switch_over_task.result() + + # check it was our protocol that was returned + self.assertIs(protocol, fake_protocol) + + # check our protocol received exactly the bytes meant for it + self.assertEqual( + fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~" + )