Skip to content

Commit

Permalink
Tests for the HTTP Proxy support (asyncio and common only) (#151)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
  • Loading branch information
reivilibre and clokep committed Sep 1, 2020
1 parent b2a74db commit 53015c7
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog.d/151.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tests for HTTP Proxy support.
191 changes: 191 additions & 0 deletions tests/asyncio_test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import logging
import types
from asyncio import AbstractEventLoop, transports
from asyncio.protocols import BaseProtocol, Protocol
from asyncio.transports import Transport
from contextvars import Context
from typing import Any, Callable, List, Optional, Tuple

logger = logging.getLogger(__name__)


class TimelessEventLoopWrapper:
@property # type: ignore
def __class__(self):
"""
Fakes isinstance(this, AbstractEventLoop) so we can set_event_loop
without fail.
"""
return self._wrapped_loop.__class__

def __init__(self, wrapped_loop: AbstractEventLoop):
self._wrapped_loop = wrapped_loop
self._time = 0.0
self._to_be_called: List[Tuple[float, Any, Any, Any]] = []

def advance(self, time_delta: float):
target_time = self._time + time_delta
logger.debug(
"advancing from %f by %f (%d in queue)",
self._time,
time_delta,
len(self._to_be_called),
)
while self._time < target_time and self._to_be_called:
# pop off the next callback from the queue
next_time, next_callback, args, _context = self._to_be_called[0]
if next_time > target_time:
# this isn't allowed to run yet
break
logger.debug("callback at %f on %r", next_time, next_callback)
self._to_be_called = self._to_be_called[1:]
self._time = next_time
next_callback(*args)

# no more tasks can run now but advance to the time anyway
self._time = target_time

def __getattr__(self, item: str):
"""
We use this to delegate other method calls to the real EventLoop.
"""
value = getattr(self._wrapped_loop, item)
if isinstance(value, types.MethodType):
# rebind this method to be called on us
# this makes the wrapped class use our overridden methods when
# available.
# we have to do this because methods are bound to the underlying
# event loop, which will call `self.call_later` or something
# which won't normally hit us because we are not an actual subtype.
return types.MethodType(value.__func__, self)
else:
return value

def call_later(
self,
delay: float,
callback: Callable,
*args: Any,
context: Optional[Context] = None,
):
self.call_at(self._time + delay, callback, *args, context=context)

def call_at(
self,
when: float,
callback: Callable,
*args: Any,
context: Optional[Context] = None,
):
logger.debug(f"Calling {callback} at %f...", when)
self._to_be_called.append((when, callback, args, context))

# re-sort list in ascending time order
self._to_be_called.sort(key=lambda x: x[0])

def call_soon(
self, callback: Callable, *args: Any, context: Optional[Context] = None
):
return self.call_later(0, callback, *args, context=context)

def time(self) -> float:
return self._time


class MockTransport(Transport):
"""
A transport intended to be driven by tests.
Stores received data into a buffer.
"""

def __init__(self):
# Holds bytes received
self.buffer = b""

# Whether we reached the end of file/stream
self.eofed = False

# Whether the connection was aborted
self.aborted = False

# The protocol attached to this transport
self.protocol = None

# Whether this transport was closed
self.closed = False

def reset_mock(self) -> None:
self.buffer = b""
self.eofed = False
self.aborted = False
self.closed = False

def is_reading(self) -> bool:
return True

def pause_reading(self) -> None:
pass # NOP

def resume_reading(self) -> None:
pass # NOP

def set_write_buffer_limits(self, high: int = None, low: int = None) -> None:
pass # NOP

def get_write_buffer_size(self) -> int:
"""Return the current size of the write buffer."""
raise NotImplementedError

def write(self, data: bytes) -> None:
self.buffer += data

def write_eof(self) -> None:
self.eofed = True

def can_write_eof(self) -> bool:
return True

def abort(self) -> None:
self.aborted = True

def pretend_to_receive(self, data: bytes) -> None:
proto = self.get_protocol()
assert isinstance(proto, Protocol)
proto.data_received(data)

def set_protocol(self, protocol: BaseProtocol) -> None:
self.protocol = protocol

def get_protocol(self) -> BaseProtocol:
assert isinstance(self.protocol, BaseProtocol)
return self.protocol

def close(self) -> None:
self.closed = True


class MockProtocol(Protocol):
"""
A protocol intended to be driven by tests.
Stores received data into a buffer.
"""

def __init__(self):
self._to_transmit = b""
self.received_bytes = b""
self.transport = None

def data_received(self, data: bytes) -> None:
self.received_bytes += data

def connection_made(self, transport: transports.BaseTransport) -> None:
assert isinstance(transport, Transport)
self.transport = transport
if self._to_transmit:
transport.write(self._to_transmit)

def write(self, data: bytes) -> None:
if self.transport:
self.transport.write(data)
else:
self._to_transmit += data
193 changes: 193 additions & 0 deletions tests/test_httpproxy_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from asyncio import AbstractEventLoop, BaseTransport, Protocol, Task
from typing import Optional, Tuple, cast

from sygnal.exceptions import ProxyConnectError
from sygnal.helper.proxy.proxy_asyncio import HttpConnectProtocol

from tests import testutils
from tests.asyncio_test_helpers import (
MockProtocol,
MockTransport,
TimelessEventLoopWrapper,
)


class AsyncioHttpProxyTest(testutils.TestCase):
def config_setup(self, config):
super().config_setup(config)
config["apps"]["com.example.spqr"] = {
"type": "tests.test_pushgateway_api_v1.TestPushkin"
}
base_loop = asyncio.new_event_loop()
augmented_loop = TimelessEventLoopWrapper(base_loop) # type: ignore
asyncio.set_event_loop(cast(AbstractEventLoop, augmented_loop))

self.loop = augmented_loop

def make_fake_proxy(
self, host: str, port: int, proxy_credentials: Optional[Tuple[str, str]]
) -> Tuple[MockProtocol, MockTransport, "Task[Tuple[BaseTransport, Protocol]]"]:
# Task[Tuple[MockTransport, MockProtocol]]
# make a fake proxy
fake_proxy = MockTransport()
# make a fake protocol that we fancy using through the proxy
fake_protocol = MockProtocol()
# create a HTTP CONNECT proxy client protocol
http_connect_protocol = HttpConnectProtocol(
target_hostport=(host, port),
proxy_credentials=proxy_credentials,
protocol_factory=lambda: fake_protocol,
sslcontext=None,
loop=None,
)
switch_over_task = asyncio.get_event_loop().create_task(
http_connect_protocol.switch_over_when_ready()
)
# check the task is not somehow already marked as done before we even
# receive anything.
self.assertFalse(switch_over_task.done())
# connect the proxy client to the proxy
fake_proxy.set_protocol(http_connect_protocol)
http_connect_protocol.connection_made(fake_proxy)
return fake_protocol, fake_proxy, switch_over_task

def test_connect_no_credentials(self):
"""
Tests the proxy connection procedure when there is no basic auth.
"""
host = "example.org"
port = 443
proxy_credentials = None
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy(
host, port, proxy_credentials
)

# Check that the proxy got the proper CONNECT request.
self.assertEqual(fake_proxy.buffer, b"CONNECT example.org:443 HTTP/1.0\r\n\r\n")
# Reset the proxy mock
fake_proxy.reset_mock()

# pretend we got a happy response with some dangling bytes from the
# target protocol
fake_proxy.pretend_to_receive(
b"HTTP/1.0 200 Connection Established\r\n\r\n"
b"begin beep boop\r\n\r\n~~ :) ~~"
)

# advance event loop because we have to let coroutines be executed
self.loop.advance(1.0)

# *now* we should have switched over from the HTTP CONNECT protocol
# to the user protocol (in our case, a MockProtocol).
self.assertTrue(switch_over_task.done())

transport, protocol = switch_over_task.result()

# check it was our protocol that was returned
self.assertIs(protocol, fake_protocol)

# check our protocol received exactly the bytes meant for it
self.assertEqual(
fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~"
)

def test_connect_correct_credentials(self):
"""
Tests the proxy connection procedure when there is basic auth.
"""
host = "example.org"
port = 443
proxy_credentials = ("user", "secret")
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy(
host, port, proxy_credentials
)

# Check that the proxy got the proper CONNECT request with the
# correctly-encoded credentials
self.assertEqual(
fake_proxy.buffer,
b"CONNECT example.org:443 HTTP/1.0\r\n"
b"Proxy-Authorization: basic dXNlcjpzZWNyZXQ=\r\n\r\n",
)
# Reset the proxy mock
fake_proxy.reset_mock()

# pretend we got a happy response with some dangling bytes from the
# target protocol
fake_proxy.pretend_to_receive(
b"HTTP/1.0 200 Connection Established\r\n\r\n"
b"begin beep boop\r\n\r\n~~ :) ~~"
)

# advance event loop because we have to let coroutines be executed
self.loop.advance(1.0)

# *now* we should have switched over from the HTTP CONNECT protocol
# to the user protocol (in our case, a MockProtocol).
self.assertTrue(switch_over_task.done())

transport, protocol = switch_over_task.result()

# check it was our protocol that was returned
self.assertIs(protocol, fake_protocol)

# check our protocol received exactly the bytes meant for it
self.assertEqual(
fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~"
)

def test_connect_failure(self):
"""
Test that our task fails properly when we cannot make a connection through
the proxy.
"""
host = "example.org"
port = 443
proxy_credentials = ("user", "secret")
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy(
host, port, proxy_credentials
)

# Check that the proxy got the proper CONNECT request with the
# correctly-encoded credentials.
self.assertEqual(
fake_proxy.buffer,
b"CONNECT example.org:443 HTTP/1.0\r\n"
b"Proxy-Authorization: basic dXNlcjpzZWNyZXQ=\r\n\r\n",
)
# Reset the proxy mock
fake_proxy.reset_mock()

# For the sake of this test, pretend the credentials are incorrect so
# send a sad response with a HTML error page
fake_proxy.pretend_to_receive(
b"HTTP/1.0 401 Unauthorised\r\n\r\n<HTML>... some error here ...</HTML>"
)

# advance event loop because we have to let coroutines be executed
self.loop.advance(1.0)

# *now* this future should have completed
self.assertTrue(switch_over_task.done())

# but we should have failed
self.assertIsInstance(switch_over_task.exception(), ProxyConnectError)

# check our protocol did not receive anything, because it was an HTTP-
# level error, not actually a connection to our target.
self.assertEqual(fake_protocol.received_bytes, b"")
Loading

0 comments on commit 53015c7

Please sign in to comment.