diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 0359fd885b2..62851b6edca 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -207,10 +207,16 @@ def _sendfile_system(self, req, resp, fobj, count): `count` should be an integer > 0. """ + transport = req.transport + + if transport.get_extra_info("sslcontext"): + yield from self._sendfile_fallback(req, resp, fobj, count) + return + yield from resp.drain() loop = req.app.loop - out_fd = req.transport.get_extra_info("socket").fileno() + out_fd = transport.get_extra_info("socket").fileno() in_fd = fobj.fileno() fut = asyncio.Future(loop=loop) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 4bdfbda2200..596587d587f 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -5,11 +5,16 @@ import os.path import socket import unittest -from aiohttp import log, web, request, FormData, ClientSession +from aiohttp import log, web, request, FormData, ClientSession, TCPConnector from aiohttp.multidict import MultiDict from aiohttp.protocol import HttpVersion, HttpVersion10, HttpVersion11 from aiohttp.streams import EOF_MARKER +try: + import ssl +except: + ssl = False + class WebFunctionalSetupMixin: @@ -34,7 +39,7 @@ def find_unused_port(self): return port @asyncio.coroutine - def create_server(self, method, path, handler=None): + def create_server(self, method, path, handler=None, ssl_ctx=None): app = web.Application(loop=self.loop) if handler: app.router.add_route(method, path, handler) @@ -44,8 +49,9 @@ def create_server(self, method, path, handler=None): debug=True, keep_alive_on=False, access_log=log.access_logger) srv = yield from self.loop.create_server( - self.handler, '127.0.0.1', port) - url = "http://127.0.0.1:{}".format(port) + path + self.handler, '127.0.0.1', port, ssl=ssl_ctx) + protocol = "https" if ssl_ctx else "http" + url = "{}://127.0.0.1:{}".format(protocol, port) + path self.addCleanup(srv.close) return app, srv, url @@ -732,8 +738,10 @@ def go(): class StaticFileMixin(WebFunctionalSetupMixin): @asyncio.coroutine - def create_server(self, method, path): - app, srv, url = yield from super().create_server(method, path) + def create_server(self, method, path, ssl_ctx=None): + app, srv, url = yield from super().create_server( + method, path, ssl_ctx=ssl_ctx + ) app.router.add_static = self.patch_sendfile(app.router.add_static) return app, srv, url @@ -768,6 +776,45 @@ def go(dirname, filename): filename = 'data.unknown_mime_type' self.loop.run_until_complete(go(here, filename)) + @unittest.skipUnless(ssl, "ssl not supported") + def test_static_file_ssl(self): + + @asyncio.coroutine + def go(dirname, filename): + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ssl_ctx.load_cert_chain( + os.path.join(dirname, 'sample.crt'), + os.path.join(dirname, 'sample.key') + ) + app, _, url = yield from self.create_server( + 'GET', '/static/' + filename, ssl_ctx=ssl_ctx + ) + app.router.add_static('/static', dirname) + + conn = TCPConnector(verify_ssl=False, loop=self.loop) + session = ClientSession(connector=conn) + + resp = yield from session.request('GET', url) + self.assertEqual(200, resp.status) + txt = yield from resp.text() + self.assertEqual('file content', txt.rstrip()) + ct = resp.headers['CONTENT-TYPE'] + self.assertEqual('application/octet-stream', ct) + self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None) + resp.close() + + resp = yield from session.request('GET', url + 'fake') + self.assertEqual(404, resp.status) + resp.close() + + resp = yield from session.request('GET', url + '/../../') + self.assertEqual(404, resp.status) + resp.close() + + here = os.path.dirname(__file__) + filename = 'data.unknown_mime_type' + self.loop.run_until_complete(go(here, filename)) + def test_static_file_with_content_type(self): @asyncio.coroutine