Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for HTTP(S) proxies to connect() #422

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ Client

.. automodule:: websockets.client

.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds)
.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None, **kwds)

.. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None)

.. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None)
.. automethod:: handshake(uri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None)

Shared
......
Expand Down
157 changes: 133 additions & 24 deletions websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import collections.abc
import sys
import urllib.request

from .exceptions import (
InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError
Expand All @@ -18,11 +19,13 @@
)
from .http import USER_AGENT, basic_auth_header, build_headers, read_response
from .protocol import WebSocketCommonProtocol
from .uri import parse_uri
from .uri import parse_proxy_uri, parse_uri


__all__ = ['connect', 'WebSocketClientProtocol']

USE_SYSTEM_PROXY = object()


class WebSocketClientProtocol(WebSocketCommonProtocol):
"""
Expand Down Expand Up @@ -196,7 +199,67 @@ def process_subprotocol(headers, available_subprotocols):
return subprotocol

@asyncio.coroutine
def handshake(self, wsuri, origin=None, available_extensions=None,
def proxy_connect(self, proxy_uri, uri, ssl=None, server_hostname=None):
request = ['CONNECT {uri.host}:{uri.port} HTTP/1.1'.format(uri=uri)]

headers = []

if uri.port == (443 if uri.secure else 80): # pragma: no cover
headers.append(('Host', uri.host))
else:
headers.append(('Host', '{uri.host}:{uri.port}'.format(uri=uri)))

if proxy_uri.user_info:
headers.append((
'Proxy-Authorization',
basic_auth_header(*proxy_uri.user_info),
))

request.extend('{}: {}'.format(k, v) for k, v in headers)
request.append('\r\n')
request = '\r\n'.join(request).encode()

self.writer.write(request)

status_code, headers = yield from read_response(self.reader)

if not 200 <= status_code < 300:
# TODO improve error handling
raise ValueError("proxy error: HTTP {}".format(status_code))

if ssl is not None:
# Wrap socket with TLS. This ugly hack will be necessary until
# https://bugs.python.org/issue23749 is resolved and websockets
# drops support for all early Python versions.
if not asyncio.sslproto._is_sslproto_available():
raise ValueError(
"connecting to a wss:// server through a proxy isn't "
"supported on Python < 3.5")
old_protocol = self
old_transport = self.writer.transport
ssl_connected = asyncio.Future()
new_protocol = asyncio.sslproto.SSLProtocol(
loop=self.loop,
app_protocol=old_protocol,
# taken from _create_connection_transport
sslcontext=None if isinstance(ssl, bool) else ssl,
waiter=ssl_connected,
server_side=False,
server_hostname=server_hostname,
call_connection_made=False,
)
new_transport = new_protocol._app_transport

# Surgery without anesthesia.
old_transport._protocol = new_protocol
self.reader._transport = new_transport
self.writer._transport = new_transport

new_protocol.connection_made(old_transport)
yield from ssl_connected

@asyncio.coroutine
def handshake(self, uri, origin=None, available_extensions=None,
available_subprotocols=None, extra_headers=None):
"""
Perform the client side of the opening handshake.
Expand All @@ -220,13 +283,13 @@ def handshake(self, wsuri, origin=None, available_extensions=None,
set_header = lambda k, v: request_headers.append((k, v))
is_header_set = lambda k: k in dict(request_headers).keys()

if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
set_header('Host', wsuri.host)
if uri.port == (443 if uri.secure else 80): # pragma: no cover
set_header('Host', uri.host)
else:
set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port))
set_header('Host', '{uri.host}:{uri.port}'.format(uri=uri))

if wsuri.user_info:
set_header(*basic_auth_header(*wsuri.user_info))
if uri.user_info:
set_header('Authorization', basic_auth_header(*uri.user_info))

if origin is not None:
set_header('Origin', origin)
Expand Down Expand Up @@ -257,7 +320,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None,
key = build_request(set_header)

yield from self.write_http_request(
wsuri.resource_name, request_headers)
uri.resource_name, request_headers)

status_code, response_headers = yield from self.read_http_response()
get_header = lambda k: response_headers.get(k, '')
Expand Down Expand Up @@ -318,6 +381,12 @@ class Connect:
* ``compression`` is a shortcut to configure compression extensions;
by default it enables the "permessage-deflate" extension; set it to
``None`` to disable compression
* ``proxy`` defines the HTTP proxy for establishing the connection; by
default, :func:`connect` uses proxies configured in the environment or
the system (see :func:`~urllib.request.getproxies` for details); set
``proxy`` to ``None`` to disable this behavior
* ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
settings for connecting to a ``https://`` proxy; it defaults to ``True``

:func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening
Expand All @@ -331,7 +400,9 @@ def __init__(self, uri, *,
read_limit=2 ** 16, write_limit=2 ** 16,
loop=None, legacy_recv=False, klass=None,
origin=None, extensions=None, subprotocols=None,
extra_headers=None, compression='deflate', **kwds):
extra_headers=None, compression='deflate',
proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None,
ssl=None, sock=None, **kwds):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -343,12 +414,15 @@ def __init__(self, uri, *,
if create_protocol is None:
create_protocol = WebSocketClientProtocol

wsuri = parse_uri(uri)
if wsuri.secure:
kwds.setdefault('ssl', True)
elif kwds.get('ssl') is not None:
raise ValueError("connect() received a SSL context for a ws:// "
"URI, use a wss:// URI to enable TLS")
uri = parse_uri(uri)
if uri.secure:
if ssl is None:
ssl = True
elif ssl is not None:
raise ValueError(
"connect() received a TLS/SSL context for a ws:// URI;"
"use a wss:// URI to enable TLS",
)

if compression == 'deflate':
if extensions is None:
Expand All @@ -364,26 +438,55 @@ def __init__(self, uri, *,
raise ValueError("Unsupported compression: {}".format(compression))

factory = lambda: create_protocol(
host=wsuri.host, port=wsuri.port, secure=wsuri.secure,
host=uri.host, port=uri.port, secure=uri.secure,
timeout=timeout, max_size=max_size, max_queue=max_queue,
read_limit=read_limit, write_limit=write_limit,
loop=loop, legacy_recv=legacy_recv,
origin=origin, extensions=extensions, subprotocols=subprotocols,
extra_headers=extra_headers,
)

if kwds.get('sock') is None:
host, port = wsuri.host, wsuri.port
else:
if proxy_uri is USE_SYSTEM_PROXY:
proxies = urllib.request.getproxies()
if urllib.request.proxy_bypass(
'{uri.host}:{uri.port}'.format(uri=uri)):
proxy_uri = None
else:
# RFC 6455 recommends to prefer the proxy configured for HTTPS
# connections over the proxy configured for HTTP connections.
proxy_uri = proxies.get('https')
if proxy_uri is None and not uri.secure:
proxy_uri = proxies.get('http')

if proxy_uri is not None:
proxy_uri = parse_proxy_uri(proxy_uri)
if proxy_uri.secure:
if proxy_ssl is None:
proxy_ssl = True
elif proxy_ssl is not None:
raise ValueError(
"connect() received a TLS/SSL context for a HTTP proxy; "
"use a HTTPS proxy to enable TLS",
)

if sock is not None:
# If sock is given, host and port mustn't be specified.
host, port = None, None
conn_host, conn_port, conn_ssl = None, None, ssl
elif proxy_uri is not None:
conn_host, conn_port, conn_ssl = (
proxy_uri.host, proxy_uri.port, proxy_ssl)
else:
conn_host, conn_port, conn_ssl = uri.host, uri.port, ssl

self._wsuri = wsuri
self._origin = origin
self._proxy_uri = proxy_uri
self._uri = uri
if proxy_uri is not None:
self._ssl = ssl
self._server_hostname = kwds.pop('server_hostname', None)

# This is a coroutine object.
self._creating_connection = loop.create_connection(
factory, host, port, **kwds)
factory, conn_host, conn_port, ssl=conn_ssl, sock=sock, **kwds)

@asyncio.coroutine
def __aenter__(self):
Expand All @@ -397,8 +500,14 @@ def __await__(self):
transport, protocol = yield from self._creating_connection

try:
if self._proxy_uri is not None:
yield from protocol.proxy_connect(
self._proxy_uri, self._uri,
self._ssl, self._server_hostname,
)
yield from protocol.handshake(
self._wsuri, origin=self._origin,
self._uri,
origin=protocol.origin,
available_extensions=protocol.available_extensions,
available_subprotocols=protocol.available_subprotocols,
extra_headers=protocol.extra_headers,
Expand Down
2 changes: 1 addition & 1 deletion websockets/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,4 @@ def basic_auth_header(username, password):
assert ':' not in username
user_pass = '{}:{}'.format(username, password)
basic_credentials = base64.b64encode(user_pass.encode()).decode()
return ('Authorization', 'Basic ' + basic_credentials)
return 'Basic ' + basic_credentials
2 changes: 1 addition & 1 deletion websockets/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,5 @@ def test_basic_auth_header(self):
# Test vector from RFC 7617.
self.assertEqual(
basic_auth_header("Aladdin", "open sesame"),
('Authorization', 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='),
'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==',
)
36 changes: 34 additions & 2 deletions websockets/test_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,48 @@
'ws://localhost/path#fragment',
]

VALID_PROXY_URIS = [
(
'http://localhost',
(False, 'localhost', 80, None),
),
(
'https://localhost',
(True, 'localhost', 443, None),
),
(
'http://user:pass@localhost',
(False, 'localhost', 80, ('user', 'pass')),
),
]

INVALID_PROXY_URIS = [
'http://localhost/path',
'ws://localhost/',
'wss://localhost/',
]


class URITests(unittest.TestCase):

def test_success(self):
def test_parse_uri_success(self):
for uri, parsed in VALID_URIS:
with self.subTest(uri=uri):
self.assertEqual(parse_uri(uri), parsed)

def test_error(self):
def test_parse_uri_error(self):
for uri in INVALID_URIS:
with self.subTest(uri=uri):
with self.assertRaises(InvalidURI):
parse_uri(uri)

def test_parse_proxy_uri_success(self):
for uri, parsed in VALID_PROXY_URIS:
with self.subTest(uri=uri):
self.assertEqual(parse_proxy_uri(uri), parsed)

def test_parse_proxy_uri_error(self):
for uri in INVALID_PROXY_URIS:
with self.subTest(uri=uri):
with self.assertRaises(InvalidURI):
parse_proxy_uri(uri)
Loading