diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index ae2baa32ca0..08d6dbd89cd 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -550,8 +550,11 @@ def enable_chunked_encoding(self): def keep_alive(self): if self.keepalive is None: - if self.version <= HttpVersion10: - if self.headers.get('Connection') == 'keep-alive': + if self.version < HttpVersion10: + # keep alive not supported at all + return False + if self.version == HttpVersion10: + if self.headers.get(hdrs.CONNECTION) == 'keep-alive': return True else: # no headers means we close for Http 1.0 return False diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 556294d837b..0cf0919ab1e 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -18,7 +18,7 @@ CIMultiDict, MultiDictProxy, MultiDict) -from .protocol import Response as ResponseImpl +from .protocol import Response as ResponseImpl, HttpVersion10 from .streams import EOF_MARKER @@ -95,7 +95,10 @@ def __init__(self, app, message, payload, transport, reader, writer, *, self._post = None self._post_files_cache = None self._headers = CIMultiDictProxy(message.headers) - self._keep_alive = not message.should_close + if self._version < HttpVersion10: + self._keep_alive = False + else: + self._keep_alive = not message.should_close # matchdict, route_name, handler # or information about traversal lookup diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index ca2a42537ea..ab0b46cb88c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -5,7 +5,7 @@ import unittest from aiohttp import web, request, FormData from aiohttp.multidict import MultiDict -from aiohttp.protocol import HttpVersion10, HttpVersion11 +from aiohttp.protocol import HttpVersion, HttpVersion10, HttpVersion11 from aiohttp.streams import EOF_MARKER @@ -479,6 +479,24 @@ def go(): self.loop.run_until_complete(go()) + def test_http09_keep_alive_default(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + headers = {'Connection': 'keep-alive'} # should be ignored + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from request('GET', url, loop=self.loop, + headers=headers, + version=HttpVersion(0, 9)) + self.assertEqual('close', resp.headers['CONNECTION']) + + self.loop.run_until_complete(go()) + def test_http10_keep_alive_with_headers_close(self): @asyncio.coroutine diff --git a/tests/test_web_response.py b/tests/test_web_response.py index fea7cbac9d1..522fb458f02 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -4,7 +4,8 @@ from aiohttp import hdrs from aiohttp.multidict import CIMultiDict from aiohttp.web import Request, StreamResponse, Response -from aiohttp.protocol import RawRequestMessage, HttpVersion11, HttpVersion10 +from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 +from aiohttp.protocol import RawRequestMessage class TestStreamResponse(unittest.TestCase): @@ -356,7 +357,7 @@ def test_keep_alive_http10(self): req = self.request_from_message(message) resp = StreamResponse() resp.start(req) - self.assertEqual(resp.keep_alive, False) + self.assertFalse(resp.keep_alive) headers = CIMultiDict(Connection='keep-alive') message = RawRequestMessage('GET', '/', HttpVersion10, headers, @@ -366,6 +367,15 @@ def test_keep_alive_http10(self): resp.start(req) self.assertEqual(resp.keep_alive, True) + def test_keep_alive_http09(self): + headers = CIMultiDict(Connection='keep-alive') + message = RawRequestMessage('GET', '/', HttpVersion(0, 9), headers, + False, False) + req = self.request_from_message(message) + resp = StreamResponse() + resp.start(req) + self.assertFalse(resp.keep_alive) + class TestResponse(unittest.TestCase):