Skip to content

Commit

Permalink
Add SSL related params to ClientSession.request
Browse files Browse the repository at this point in the history
  • Loading branch information
cecton committed Sep 13, 2017
1 parent 23f3348 commit 69bf0a5
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 13 deletions.
9 changes: 7 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def _request(self, method, url, *,
read_until_eof=True,
proxy=None,
proxy_auth=None,
timeout=sentinel):
timeout=sentinel,
verify_ssl=None,
fingerprint=None,
ssl_context=None):

# NOTE: timeout clamps existing connect and read timeouts. We cannot
# set the default to None because we need to detect if the user wants
Expand Down Expand Up @@ -225,7 +228,9 @@ def _request(self, method, url, *,
expect100=expect100, loop=self._loop,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self, auto_decompress=self._auto_decompress)
session=self, auto_decompress=self._auto_decompress,
verify_ssl=verify_ssl, fingerprint=fingerprint,
ssl_context=ssl_context)

# connection timeout
try:
Expand Down
52 changes: 51 additions & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import collections
import io
import json
import ssl
import sys
import traceback
import warnings
from hashlib import md5, sha1, sha256
from http.cookies import CookieError, Morsel
from urllib.request import getproxies

Expand Down Expand Up @@ -34,6 +36,16 @@
'RequestInfo', ('url', 'method', 'headers'))


HASHFUNC_BY_DIGESTLEN = {
16: md5,
20: sha1,
32: sha256,
}


_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)


class ClientRequest:

GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
Expand Down Expand Up @@ -66,7 +78,13 @@ def __init__(self, method, url, *,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None, proxy_from_env=False,
timer=None, session=None, auto_decompress=True):
timer=None, session=None, auto_decompress=True,
verify_ssl=None, fingerprint=None, ssl_context=None):

if verify_ssl is False and ssl_context is not None:
raise ValueError(
"Either disable ssl certificate validation by "
"verify_ssl=False or specify ssl_context, not both.")

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -89,6 +107,8 @@ def __init__(self, method, url, *,
self.response_class = response_class or ClientResponse
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress
self._verify_ssl = verify_ssl
self._ssl_context = ssl_context

if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
Expand All @@ -101,6 +121,7 @@ def __init__(self, method, url, *,
self.update_content_encoding(data)
self.update_auth(auth)
self.update_proxy(proxy, proxy_auth, proxy_from_env)
self.update_fingerprint(fingerprint)

self.update_body_from_data(data)
self.update_transfer_encoding()
Expand Down Expand Up @@ -307,6 +328,35 @@ def update_proxy(self, proxy, proxy_auth, proxy_from_env):
self.proxy = proxy
self.proxy_auth = proxy_auth

def update_fingerprint(self, fingerprint):
if fingerprint:
digestlen = len(fingerprint)
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
if not hashfunc:
raise ValueError('fingerprint has invalid length')
elif hashfunc is md5 or hashfunc is sha1:
warnings.simplefilter('always')
warnings.warn('md5 and sha1 are insecure and deprecated. '
'Use sha256.',
DeprecationWarning, stacklevel=2)
self._hashfunc = hashfunc
self._fingerprint = fingerprint

@property
def verify_ssl(self):
"""Do check for ssl certifications?"""
return self._verify_ssl

@property
def fingerprint(self):
"""Expected ssl certificate fingerprint."""
return self._fingerprint

@property
def ssl_context(self):
"""SSLContext instance for https requests."""
return self._ssl_context

def keep_alive(self):
if self.version < HttpVersion10:
# keep alive not supported at all
Expand Down
57 changes: 49 additions & 8 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,12 +723,49 @@ def _create_connection(self, req):

return proto

@asyncio.coroutine
def _create_direct_connection(self, req):
def _get_ssl_context(self, req):
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
1. if ssl_context is specified in req, use it
2. if _ssl_context is specified in self, use it
3. otherwise:
1. if verify_ssl is not specified in req, use self.ssl_context
(will generate a default context according to self.verify_ssl)
2. if verify_ssl is True in req, generate a default SSL context
3. if verify_ssl is False in req, generate a SSL context that
won't verify
"""
if req.ssl:
sslcontext = self.ssl_context
sslcontext = req.ssl_context or self._ssl_context
if not sslcontext:
if req.verify_ssl is None:
sslcontext = self.ssl_context
elif req.verify_ssl:
sslcontext = ssl.create_default_context()
else:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.options |= _SSL_OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
else:
sslcontext = None
return sslcontext

def _get_fingerprint_and_hashfunc(self, req):
if req.fingerprint:
return (req.fingerprint, req._hashfunc)
elif self.fingerprint:
return (self.fingerprint, self._hashfunc)
else:
return (None, None)

@asyncio.coroutine
def _create_direct_connection(self, req):
sslcontext = self._get_ssl_context(req)
fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req)

hosts = yield from self._resolve_host(req.url.raw_host, req.port)
exc = None
Expand All @@ -744,7 +781,7 @@ def _create_direct_connection(self, req):
server_hostname=hinfo['hostname'] if sslcontext else None,
local_addr=self._local_addr)
has_cert = transp.get_extra_info('sslcontext')
if has_cert and self._fingerprint:
if has_cert and fingerprint:
sock = transp.get_extra_info('socket')
if not hasattr(sock, 'getpeercert'):
# Workaround for asyncio 3.5.0
Expand All @@ -754,8 +791,8 @@ def _create_direct_connection(self, req):
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = self._hashfunc(cert).digest()
expected = self._fingerprint
got = hashfunc(cert).digest()
expected = fingerprint
if got != expected:
transp.close()
if not self._cleanup_closed_disabled:
Expand All @@ -774,7 +811,10 @@ def _create_proxy_connection(self, req):
hdrs.METH_GET, req.proxy,
headers={hdrs.HOST: req.headers[hdrs.HOST]},
auth=req.proxy_auth,
loop=self._loop)
loop=self._loop,
verify_ssl=req.verify_ssl,
fingerprint=req.fingerprint,
ssl_context=req.ssl_context)
try:
# create connection to proxy server
transport, proto = yield from self._create_direct_connection(
Expand All @@ -790,6 +830,7 @@ def _create_proxy_connection(self, req):
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth

if req.ssl:
sslcontext = self._get_ssl_context(req)
# For HTTPS requests over HTTP proxy
# we must notify proxy to tunnel connection
# so we send CONNECT command:
Expand Down Expand Up @@ -831,7 +872,7 @@ def _create_proxy_connection(self, req):
transport.close()

transport, proto = yield from self._loop.create_connection(
self._factory, ssl=self.ssl_context, sock=rawsock,
self._factory, ssl=sslcontext, sock=rawsock,
server_hostname=req.host)
finally:
proxy_resp.close()
Expand Down
24 changes: 24 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8

import asyncio
import hashlib
import io
import os.path
import urllib.parse
Expand Down Expand Up @@ -1130,3 +1131,26 @@ def create_connection(req):
resp.close()
session.close()
conn.close()


def test_verify_ssl_false_with_ssl_context(loop):
with pytest.raises(ValueError):
ClientRequest('get', URL('http://python.org'), verify_ssl=False,
ssl_context=mock.Mock(), loop=loop)


def test_bad_fingerprint(loop):
with pytest.raises(ValueError):
req = ClientRequest('get', URL('http://python.org'),
fingerprint=b'invalid', loop=loop)


def test_insecure_fingerprint(loop):
with pytest.warns(DeprecationWarning):
req = ClientRequest('get', URL('http://python.org'),
fingerprint=hashlib.md5(b"foo").digest(),
loop=loop)
with pytest.warns(DeprecationWarning):
req = ClientRequest('get', URL('http://python.org'),
fingerprint=hashlib.sha1(b"foo").digest(),
loop=loop)
106 changes: 104 additions & 2 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import gc
import hashlib
import os
import socket
import unittest
Expand All @@ -14,6 +15,17 @@


class TestProxy(unittest.TestCase):
fingerprint = hashlib.sha256(b"foo").digest()
response_mock_attrs = {
'status': 200,
}
mocked_response = mock.Mock(**response_mock_attrs)
clientrequest_mock_attrs = {
'return_value._hashfunc.return_value.digest.return_value': fingerprint,
'return_value.fingerprint': fingerprint,
'return_value.send.return_value.start':
make_mocked_coro(mocked_response),
}

def setUp(self):
self.loop = asyncio.new_event_loop()
Expand Down Expand Up @@ -57,6 +69,94 @@ def test_connect(self):
self.assertEqual(proxy_req.headers['Host'], 'www.python.org')
self.assertIs(proxy_req.loop, self.loop)

@mock.patch('aiohttp.connector.ClientRequest', **clientrequest_mock_attrs)
def test_connect_req_verify_ssl_true(self, ClientRequestMock):
req = ClientRequest(
'GET', URL('https://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
verify_ssl=True,
)

proto = mock.Mock()
connector = aiohttp.TCPConnector(loop=self.loop)
connector._create_proxy_connection = mock.MagicMock(
side_effect=connector._create_proxy_connection)
connector._create_direct_connection = mock.MagicMock(
side_effect=connector._create_direct_connection)
connector._resolve_host = make_mocked_coro([mock.MagicMock()])

self.loop.create_connection = make_mocked_coro(
(proto.transport, proto))
self.loop.run_until_complete(connector.connect(req))

connector._create_proxy_connection.assert_called_with(req)
((proxy_req,), _) = connector._create_direct_connection.call_args
proxy_req.send.assert_called_with(mock.ANY)

@mock.patch('aiohttp.connector.ClientRequest', **clientrequest_mock_attrs)
def test_connect_req_verify_ssl_false(self, ClientRequestMock):
req = ClientRequest(
'GET', URL('https://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
verify_ssl=False,
)

proto = mock.Mock()
connector = aiohttp.TCPConnector(loop=self.loop)
connector._create_proxy_connection = mock.MagicMock(
side_effect=connector._create_proxy_connection)
connector._create_direct_connection = mock.MagicMock(
side_effect=connector._create_direct_connection)
connector._resolve_host = make_mocked_coro([mock.MagicMock()])

self.loop.create_connection = make_mocked_coro(
(proto.transport, proto))
self.loop.run_until_complete(connector.connect(req))

connector._create_proxy_connection.assert_called_with(req)
((proxy_req,), _) = connector._create_direct_connection.call_args
proxy_req.send.assert_called_with(mock.ANY)

@mock.patch('aiohttp.connector.ClientRequest', **clientrequest_mock_attrs)
def test_connect_req_fingerprint_ssl_context(self, ClientRequestMock):
ssl_context = mock.Mock()
attrs = {
'return_value.ssl_context': ssl_context,
}
ClientRequestMock.configure_mock(**attrs)
req = ClientRequest(
'GET', URL('https://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
verify_ssl=True,
fingerprint=self.fingerprint,
ssl_context=ssl_context,
)

proto = mock.Mock()
connector = aiohttp.TCPConnector(loop=self.loop)
connector._create_proxy_connection = mock.MagicMock(
side_effect=connector._create_proxy_connection)
connector._create_direct_connection = mock.MagicMock(
side_effect=connector._create_direct_connection)
connector._resolve_host = make_mocked_coro([mock.MagicMock()])

transport_attrs = {
'get_extra_info.return_value.getpeercert.return_value': b"foo"
}
transport = mock.Mock(**transport_attrs)
self.loop.create_connection = make_mocked_coro(
(transport, proto))
self.loop.run_until_complete(connector.connect(req))

connector._create_proxy_connection.assert_called_with(req)
((proxy_req,), _) = connector._create_direct_connection.call_args
self.assertTrue(proxy_req.verify_ssl)
self.assertEqual(proxy_req.fingerprint, req.fingerprint)
self.assertIs(proxy_req.ssl_context, req.ssl_context)

def test_proxy_auth(self):
with self.assertRaises(ValueError) as ctx:
ClientRequest(
Expand Down Expand Up @@ -136,7 +236,8 @@ def test_auth(self, ClientRequestMock):
ClientRequestMock.assert_called_with(
'GET', URL('http://proxy.example.com'),
auth=aiohttp.helpers.BasicAuth('user', 'pass'),
loop=mock.ANY, headers=mock.ANY)
loop=mock.ANY, headers=mock.ANY, fingerprint=None,
ssl_context=None, verify_ssl=None)
conn.close()

def test_auth_utf8(self):
Expand Down Expand Up @@ -178,7 +279,8 @@ def test_auth_from_url(self, ClientRequestMock):

ClientRequestMock.assert_called_with(
'GET', URL('http://user:pass@proxy.example.com'),
auth=None, loop=mock.ANY, headers=mock.ANY)
auth=None, loop=mock.ANY, headers=mock.ANY, fingerprint=None,
ssl_context=None, verify_ssl=None)
conn.close()

@mock.patch('aiohttp.connector.ClientRequest')
Expand Down

0 comments on commit 69bf0a5

Please sign in to comment.