diff --git a/CHANGES.txt b/CHANGES.txt index 222c6739611..ae84cde4ab9 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -2,7 +2,7 @@ CHANGES ======= 0.18.0a0 (XX-XX-XXXX) -------------------- +--------------------- - Use errors.HttpProcessingError.message as HTTP error reason and message #459 diff --git a/aiohttp/web.py b/aiohttp/web.py index 95ed41de5c1..706fbb04ed6 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -92,7 +92,7 @@ def handle_request(self, message, payload): except HTTPException as exc: resp = exc - resp_msg = resp.start(request) + resp_msg = yield from resp.prepare(request) yield from resp.write_eof() # notify server about keep-alive diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 120a312e02a..3c07ef89621 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -425,9 +425,14 @@ def _copy_cookies(self): self.headers.add(hdrs.SET_COOKIE, value) @property - def started(self): + def prepared(self): return self._resp_impl is not None + @property + def started(self): + warnings.warn('use Response.prepared instead', DeprecationWarning) + return self.prepared + @property def status(self): return self._status @@ -612,27 +617,39 @@ def _start_pre_check(self, request): return None def _start_compression(self, request): - def start(coding): + def _start(coding): if coding != ContentCoding.identity: self.headers[hdrs.CONTENT_ENCODING] = coding.value self._resp_impl.add_compression_filter(coding.value) self.content_length = None if self._compression_force: - start(self._compression_force) + _start(self._compression_force) else: accept_encoding = request.headers.get( hdrs.ACCEPT_ENCODING, '').lower() for coding in ContentCoding: if coding.value in accept_encoding: - start(coding) + _start(coding) return def start(self, request): + warnings.warn('use .prepare(request) instead', DeprecationWarning) resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl + return self._start(request) + + @asyncio.coroutine + def prepare(self, request): + resp_impl = self._start_pre_check(request) + if resp_impl is not None: + return resp_impl + + return self._start(request) + + def _start(self, request): self._req = request keep_alive = self._keep_alive if keep_alive is None: diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 62851b6edca..f0f0db94833 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -281,7 +281,7 @@ def handle(self, request): file_size = st.st_size resp.content_length = file_size - resp.start(request) + yield from resp.prepare(request) with open(filepath, 'rb') as f: yield from self._sendfile(request, resp, f, file_size) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index ce7e4eeec9b..f3736eb6cfe 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -35,12 +35,19 @@ def __init__(self, *, self._autoclose = autoclose self._autoping = autoping - def start(self, request): + @asyncio.coroutine + def prepare(self, request): # make pre-check to don't hide it by do_handshake() exceptions resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl + parser, protocol, writer = self._pre_start(request) + resp_impl = yield from super().prepare(request) + self._post_start(request, parser, protocol, writer) + return resp_impl + + def _pre_start(self, request): try: status, headers, parser, writer, protocol = do_handshake( request.method, request.headers, request.transport, @@ -59,17 +66,27 @@ def start(self, request): for k, v in headers: self.headers[k] = v self.force_close() + return parser, protocol, writer - resp_impl = super().start(request) - + def _post_start(self, request, parser, protocol, writer): self._reader = request._reader.set_parser(parser) self._writer = writer self._protocol = protocol self._loop = request.app.loop + def start(self, request): + warnings.warn('use .prepare(request) instead', DeprecationWarning) + # make pre-check to don't hide it by do_handshake() exceptions + resp_impl = self._start_pre_check(request) + if resp_impl is not None: + return resp_impl + + parser, protocol, writer = self._pre_start(request) + resp_impl = super().start(request) + self._post_start(request, parser, protocol, writer) return resp_impl - def can_start(self, request): + def can_prepare(self, request): if self._writer is not None: raise RuntimeError('Already started') try: @@ -81,6 +98,10 @@ def can_start(self, request): else: return True, protocol + def can_start(self, request): + warnings.warn('use .can_prepare(request) instead', DeprecationWarning) + return self.can_prepare(request) + @property def closed(self): return self._closed @@ -98,7 +119,7 @@ def exception(self): def ping(self, message='b'): if self._writer is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if self._closed: raise RuntimeError('websocket connection is closing') self._writer.ping(message) @@ -106,14 +127,14 @@ def ping(self, message='b'): def pong(self, message='b'): # unsolicited pong if self._writer is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if self._closed: raise RuntimeError('websocket connection is closing') self._writer.pong(message) def send_str(self, data): if self._writer is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if self._closed: raise RuntimeError('websocket connection is closing') if not isinstance(data, str): @@ -122,7 +143,7 @@ def send_str(self, data): def send_bytes(self, data): if self._writer is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if self._closed: raise RuntimeError('websocket connection is closing') if not isinstance(data, (bytes, bytearray, memoryview)): @@ -151,7 +172,7 @@ def write_eof(self): @asyncio.coroutine def close(self, *, code=1000, message=b''): if self._writer is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if not self._closed: self._closed = True @@ -190,7 +211,7 @@ def close(self, *, code=1000, message=b''): @asyncio.coroutine def receive(self): if self._reader is None: - raise RuntimeError('Call .start() first') + raise RuntimeError('Call .prepare() first') if self._waiting: raise RuntimeError('Concurrent call to receive() is not allowed') @@ -239,7 +260,7 @@ def receive(self): self._waiting = False @asyncio.coroutine - def receive_msg(self): # pragma: no cover + def receive_msg(self): warnings.warn( 'receive_msg() coroutine is deprecated. use receive() instead', DeprecationWarning) diff --git a/docs/web.rst b/docs/web.rst index f0ab25b4ca9..e95c88da997 100644 --- a/docs/web.rst +++ b/docs/web.rst @@ -407,7 +407,7 @@ using response's methods: def websocket_handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) while True: msg = yield from ws.receive() diff --git a/docs/web_reference.rst b/docs/web_reference.rst index f6fd288dfab..f2f08398447 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -339,11 +339,10 @@ StreamResponse The most important thing you should know about *response* --- it is *Finite State Machine*. - That means you can do any manipulations with *headers*, - *cookies* and *status code* only before :meth:`start` - called. + That means you can do any manipulations with *headers*, *cookies* + and *status code* only before :meth:`prepare` coroutine is called. - Once you call :meth:`start` any change of + Once you call :meth:`prepare` any change of the *HTTP header* part will raise :exc:`RuntimeError` exception. Any :meth:`write` call after :meth:`write_eof` is also forbidden. @@ -355,11 +354,19 @@ StreamResponse parameter. Otherwise pass :class:`str` with arbitrary *status* explanation.. - .. attribute:: started + .. attribute:: prepared - Read-only :class:`bool` property, ``True`` if :meth:`start` has + Read-only :class:`bool` property, ``True`` if :meth:`prepare` has been called, ``False`` otherwise. + .. versionadded:: 0.18 + + .. attribute:: started + + Deprecated alias for :attr:`prepared`. + + .. deprecated:: 0.18 + .. attribute:: status Read-only property for *HTTP response status code*, :class:`int`. @@ -550,16 +557,30 @@ StreamResponse Send *HTTP header*. You should not change any header data after calling this method. + .. deprecated:: 0.18 + + Use :meth:`prepare` instead. + + .. coroutinemethod:: prepare(request) + + :param aiohttp.web.Request request: HTTP request object, that the + response answers. + + Send *HTTP header*. You should not change any header data after + calling this method. + + .. versionadded:: 0.18 + .. method:: write(data) Send byte-ish data as the part of *response BODY*. - :meth:`start` must be called before. + :meth:`prepare` must be called before. Raises :exc:`TypeError` if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview` instance. - Raises :exc:`RuntimeError` if :meth:`start` has not been called. + Raises :exc:`RuntimeError` if :meth:`prepare` has not been called. Raises :exc:`RuntimeError` if :meth:`write_eof` has been called. @@ -651,11 +672,23 @@ WebSocketResponse Class for handling server-side websockets. - After starting (by :meth:`start` call) the response you + After starting (by :meth:`prepare` call) the response you cannot use :meth:`~StreamResponse.write` method but should to communicate with websocket client by :meth:`send_str`, :meth:`receive` and others. + .. coroutinemethod:: prepare(request) + + Starts websocket. After the call you can use websocket methods. + + :param aiohttp.web.Request request: HTTP request object, that the + response answers. + + + :raises HTTPException: if websocket handshake has failed. + + .. versionadded:: 0.18 + .. method:: start(request) Starts websocket. After the call you can use websocket methods. @@ -666,12 +699,17 @@ WebSocketResponse :raises HTTPException: if websocket handshake has failed. - .. method:: can_start(request) + .. deprecated:: 0.18 + + Use :meth:`prepare` instead. + + .. method:: can_prepare(request) Performs checks for *request* data to figure out if websocket can be started on the request. - If :meth:`can_start` call is success then :meth:`start` will success too. + If :meth:`can_prepare` call is success then :meth:`prepare` will + success too. :param aiohttp.web.Request request: HTTP request object, that the response answers. @@ -684,6 +722,12 @@ WebSocketResponse .. note:: The method never raises exception. + .. method:: can_start(request) + + Deprecated alias for :meth:`can_prepare` + + .. deprecated:: 0.18 + .. attribute:: closed Read-only property, ``True`` if connection has been closed or in process diff --git a/examples/web_srv.py b/examples/web_srv.py index 29fdb82080b..780b4c2af44 100755 --- a/examples/web_srv.py +++ b/examples/web_srv.py @@ -15,7 +15,7 @@ def intro(request): binary = txt.encode('utf8') resp = StreamResponse() resp.content_length = len(binary) - resp.start(request) + yield from resp.prepare(request) resp.write(binary) return resp @@ -36,7 +36,7 @@ def hello(request): name = request.match_info.get('name', 'Anonymous') answer = ('Hello, ' + name).encode('utf8') resp.content_length = len(answer) - resp.start(request) + yield from resp.prepare(request) resp.write(answer) yield from resp.write_eof() return resp diff --git a/examples/web_ws.py b/examples/web_ws.py index 9499a97ded2..7dd5bdae023 100755 --- a/examples/web_ws.py +++ b/examples/web_ws.py @@ -17,7 +17,7 @@ def wshandler(request): with open(WS_FILE, 'rb') as fp: return Response(body=fp.read(), content_type='text/html') - resp.start(request) + yield from resp.prepare(request) print('Someone joined.') for ws in request.app['sockets']: ws.send_str('Someone joined') diff --git a/tests/autobahn/server.py b/tests/autobahn/server.py index bc33b6e05fe..e33691212a2 100644 --- a/tests/autobahn/server.py +++ b/tests/autobahn/server.py @@ -12,7 +12,7 @@ def wshandler(request): if not ok: return web.HTTPBadRequest() - ws.start(request) + yield from ws.prepare(request) while True: msg = yield from ws.receive() diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 8be61636029..611562619fd 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -50,7 +50,7 @@ def test_all_http_exceptions_exported(self): def test_HTTPOk(self): req = self.make_request() resp = web.HTTPOk() - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = self.buf.decode('utf8') self.assertRegex(txt, ('HTTP/1.1 200 OK\r\n' @@ -85,7 +85,7 @@ def test_HTTPFound(self): resp = web.HTTPFound(location='/redirect') self.assertEqual('/redirect', resp.location) self.assertEqual('/redirect', resp.headers['location']) - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = self.buf.decode('utf8') self.assertRegex(txt, ('HTTP/1.1 302 Found\r\n' @@ -110,7 +110,7 @@ def test_HTTPMethodNotAllowed(self): self.assertEqual('GET', resp.method) self.assertEqual(['POST', 'PUT'], resp.allowed_methods) self.assertEqual('POST,PUT', resp.headers['allow']) - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = self.buf.decode('utf8') self.assertRegex(txt, ('HTTP/1.1 405 Method Not Allowed\r\n' diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 596587d587f..91b40517993 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -715,7 +715,7 @@ def test_stream_response_multiple_chunks(self): def handler(request): resp = web.StreamResponse() resp.enable_chunked_encoding() - resp.start(request) + yield from resp.prepare(request) resp.write(b'x') resp.write(b'y') resp.write(b'z') diff --git a/tests/test_web_response.py b/tests/test_web_response.py index db9e1928edb..11fb5454871 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -147,16 +147,16 @@ def test_start(self, ResponseImpl): resp = StreamResponse() self.assertIsNone(resp.keep_alive) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertTrue(msg.send_headers.called) - self.assertIs(msg, resp.start(req)) + self.assertIs(msg, self.loop.run_until_complete(resp.prepare(req))) self.assertTrue(resp.keep_alive) req2 = self.make_request('GET', '/') with self.assertRaises(RuntimeError): - resp.start(req2) + self.loop.run_until_complete(resp.prepare(req2)) @mock.patch('aiohttp.web_reqrep.ResponseImpl') def test_chunked_encoding(self, ResponseImpl): @@ -167,7 +167,7 @@ def test_chunked_encoding(self, ResponseImpl): resp.enable_chunked_encoding() self.assertTrue(resp.chunked) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertTrue(msg.chunked) @mock.patch('aiohttp.web_reqrep.ResponseImpl') @@ -179,7 +179,7 @@ def test_chunk_size(self, ResponseImpl): resp.enable_chunked_encoding(chunk_size=8192) self.assertTrue(resp.chunked) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertTrue(msg.chunked) msg.add_chunking_filter.assert_called_with(8192) self.assertIsNotNone(msg.filter) @@ -192,7 +192,7 @@ def test_chunked_encoding_forbidden_for_http_10(self): with self.assertRaisesRegex( RuntimeError, "Using chunked encoding is forbidden for HTTP/1.0"): - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) @mock.patch('aiohttp.web_reqrep.ResponseImpl') def test_compression_no_accept(self, ResponseImpl): @@ -204,7 +204,7 @@ def test_compression_no_accept(self, ResponseImpl): resp.enable_compression() self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertFalse(msg.add_compression_filter.called) @mock.patch('aiohttp.web_reqrep.ResponseImpl') @@ -217,7 +217,7 @@ def test_force_compression_no_accept_backwards_compat(self, ResponseImpl): resp.enable_compression(force=True) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertTrue(msg.add_compression_filter.called) self.assertIsNotNone(msg.filter) @@ -230,7 +230,7 @@ def test_force_compression_false_backwards_compat(self, ResponseImpl): resp.enable_compression(force=False) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertFalse(msg.add_compression_filter.called) @mock.patch('aiohttp.web_reqrep.ResponseImpl') @@ -245,7 +245,7 @@ def test_compression_default_coding(self, ResponseImpl): resp.enable_compression() self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) msg.add_compression_filter.assert_called_with('deflate') self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) self.assertIsNotNone(msg.filter) @@ -260,7 +260,7 @@ def test_force_compression_deflate(self, ResponseImpl): resp.enable_compression(ContentCoding.deflate) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) msg.add_compression_filter.assert_called_with('deflate') self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) @@ -272,7 +272,7 @@ def test_force_compression_no_accept_deflate(self, ResponseImpl): resp.enable_compression(ContentCoding.deflate) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) msg.add_compression_filter.assert_called_with('deflate') self.assertEqual('deflate', resp.headers.get(hdrs.CONTENT_ENCODING)) @@ -286,7 +286,7 @@ def test_force_compression_gzip(self, ResponseImpl): resp.enable_compression(ContentCoding.gzip) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) msg.add_compression_filter.assert_called_with('gzip') self.assertEqual('gzip', resp.headers.get(hdrs.CONTENT_ENCODING)) @@ -298,7 +298,7 @@ def test_force_compression_no_accept_gzip(self, ResponseImpl): resp.enable_compression(ContentCoding.gzip) self.assertTrue(resp.compression) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) msg.add_compression_filter.assert_called_with('gzip') self.assertEqual('gzip', resp.headers.get(hdrs.CONTENT_ENCODING)) @@ -310,12 +310,13 @@ def test_delete_content_length_if_compression_enabled(self, ResponseImpl): resp.enable_compression(ContentCoding.gzip) - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.assertIsNone(resp.content_length) def test_write_non_byteish(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) with self.assertRaises(AssertionError): resp.write(123) @@ -328,7 +329,8 @@ def test_write_before_start(self): def test_cannot_write_after_eof(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) resp.write(b'data') self.writer.drain.return_value = () @@ -347,7 +349,8 @@ def test_cannot_write_eof_before_headers(self): def test_cannot_write_eof_twice(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) resp.write(b'data') self.writer.drain.return_value = () @@ -360,13 +363,15 @@ def test_cannot_write_eof_twice(self): def test_write_returns_drain(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) self.assertEqual((), resp.write(b'data')) def test_write_returns_empty_tuple_on_empty_data(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) self.assertEqual((), resp.write(b'')) @@ -460,14 +465,14 @@ def test_start_force_close(self): resp.force_close() self.assertFalse(resp.keep_alive) - msg = resp.start(req) + msg = self.loop.run_until_complete(resp.prepare(req)) self.assertFalse(resp.keep_alive) self.assertTrue(msg.closing) def test___repr__(self): req = self.make_request('GET', '/path/to') resp = StreamResponse(reason=301) - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.assertEqual("", repr(resp)) def test___repr__not_started(self): @@ -479,7 +484,7 @@ def test_keep_alive_http10(self): True, False) req = self.request_from_message(message) resp = StreamResponse() - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.assertFalse(resp.keep_alive) headers = CIMultiDict(Connection='keep-alive') @@ -487,7 +492,7 @@ def test_keep_alive_http10(self): False, False) req = self.request_from_message(message) resp = StreamResponse() - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.assertEqual(resp.keep_alive, True) def test_keep_alive_http09(self): @@ -496,9 +501,19 @@ def test_keep_alive_http09(self): False, False) req = self.request_from_message(message) resp = StreamResponse() - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.assertFalse(resp.keep_alive) + @mock.patch('aiohttp.web_reqrep.ResponseImpl') + def test_start_twice(self, ResponseImpl): + req = self.make_request('GET', '/') + resp = StreamResponse() + + with self.assertWarns(DeprecationWarning): + impl1 = resp.start(req) + impl2 = resp.start(req) + self.assertIs(impl1, impl2) + class TestResponse(unittest.TestCase): @@ -599,7 +614,7 @@ def append(data): self.writer.write.side_effect = append - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = buf.decode('utf8') self.assertRegex(txt, 'HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n' @@ -619,7 +634,7 @@ def append(data): self.writer.write.side_effect = append - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = buf.decode('utf8') self.assertRegex(txt, 'HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 4\r\n' @@ -640,7 +655,7 @@ def append(data): self.writer.write.side_effect = append - resp.start(req) + self.loop.run_until_complete(resp.prepare(req)) self.loop.run_until_complete(resp.write_eof()) txt = buf.decode('utf8') self.assertRegex(txt, 'HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n' @@ -669,12 +684,13 @@ def test_set_text_with_charset(self): def test_started_when_not_started(self): resp = StreamResponse() - self.assertFalse(resp.started) + self.assertFalse(resp.prepared) def test_started_when_started(self): resp = StreamResponse() - resp.start(self.make_request('GET', '/')) - self.assertTrue(resp.started) + self.loop.run_until_complete( + resp.prepare(self.make_request('GET', '/'))) + self.assertTrue(resp.prepared) def test_drain_before_start(self): diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index e2a6b48b333..bc8a7fb11a0 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -90,7 +90,7 @@ def test_receive_str_nonstring(self): def go(): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + yield from ws.prepare(req) @asyncio.coroutine def receive(): @@ -109,7 +109,7 @@ def test_receive_bytes_nonsbytes(self): def go(): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + yield from ws.prepare(req) @asyncio.coroutine def receive(): @@ -125,14 +125,14 @@ def receive(): def test_send_str_nonstring(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) with self.assertRaises(TypeError): ws.send_str(b'bytes') def test_send_bytes_nonbytes(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) with self.assertRaises(TypeError): ws.send_bytes('string') @@ -141,33 +141,33 @@ def test_write(self): with self.assertRaises(RuntimeError): ws.write(b'data') - def test_can_start_ok(self): + def test_can_prepare_ok(self): req = self.make_request('GET', '/') ws = WebSocketResponse(protocols=('chat',)) - self.assertEqual((True, 'chat'), ws.can_start(req)) + self.assertEqual((True, 'chat'), ws.can_prepare(req)) - def test_can_start_unknown_protocol(self): + def test_can_prepare_unknown_protocol(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - self.assertEqual((True, None), ws.can_start(req)) + self.assertEqual((True, None), ws.can_prepare(req)) - def test_can_start_invalid_method(self): + def test_can_prepare_invalid_method(self): req = self.make_request('POST', '/') ws = WebSocketResponse() - self.assertEqual((False, None), ws.can_start(req)) + self.assertEqual((False, None), ws.can_prepare(req)) - def test_can_start_without_upgrade(self): + def test_can_prepare_without_upgrade(self): req = self.make_request('GET', '/', headers=CIMultiDict({})) ws = WebSocketResponse() - self.assertEqual((False, None), ws.can_start(req)) + self.assertEqual((False, None), ws.can_prepare(req)) - def test_can_start_started(self): + def test_can_prepare_started(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) with self.assertRaisesRegex(RuntimeError, 'Already started'): - ws.can_start(req) + ws.can_prepare(req) def test_closed_after_ctor(self): ws = WebSocketResponse() @@ -177,7 +177,7 @@ def test_closed_after_ctor(self): def test_send_str_closed(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) with self.assertRaises(RuntimeError): ws.send_str('string') @@ -185,7 +185,7 @@ def test_send_str_closed(self): def test_send_bytes_closed(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) with self.assertRaises(RuntimeError): ws.send_bytes(b'bytes') @@ -193,7 +193,7 @@ def test_send_bytes_closed(self): def test_ping_closed(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) with self.assertRaises(RuntimeError): ws.ping() @@ -201,7 +201,7 @@ def test_ping_closed(self): def test_pong_closed(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) with self.assertRaises(RuntimeError): ws.pong() @@ -209,7 +209,7 @@ def test_pong_closed(self): def test_close_idempotent(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) writer = mock.Mock() ws._writer = writer self.assertTrue( @@ -222,14 +222,14 @@ def test_start_invalid_method(self): req = self.make_request('POST', '/') ws = WebSocketResponse() with self.assertRaises(HTTPMethodNotAllowed): - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) def test_start_without_upgrade(self): req = self.make_request('GET', '/', headers=CIMultiDict({})) ws = WebSocketResponse() with self.assertRaises(HTTPBadRequest): - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) def test_wait_closed_before_start(self): @@ -254,7 +254,7 @@ def go(): def test_write_eof_idempotent(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) @asyncio.coroutine @@ -268,7 +268,7 @@ def go(): def test_receive_exc_in_reader(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) exc = ValueError() res = asyncio.Future(loop=self.loop) @@ -287,7 +287,7 @@ def go(): def test_receive_cancelled(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) res = asyncio.Future(loop=self.loop) res.set_exception(asyncio.CancelledError()) @@ -300,7 +300,7 @@ def test_receive_cancelled(self): def test_receive_timeouterror(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) res = asyncio.Future(loop=self.loop) res.set_exception(asyncio.TimeoutError()) @@ -313,7 +313,7 @@ def test_receive_timeouterror(self): def test_receive_client_disconnected(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) exc = errors.ClientDisconnectedError() res = asyncio.Future(loop=self.loop) @@ -333,7 +333,7 @@ def go(): def test_multiple_receive_on_close_connection(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) self.loop.run_until_complete(ws.close()) self.loop.run_until_complete(ws.receive()) self.loop.run_until_complete(ws.receive()) @@ -345,7 +345,7 @@ def test_multiple_receive_on_close_connection(self): def test_concurrent_receive(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) ws._waiting = True self.assertRaises( @@ -356,7 +356,7 @@ def test_close_exc(self): reader = self.reader.set_parser.return_value = mock.Mock() ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) exc = ValueError() reader.read.return_value = asyncio.Future(loop=self.loop) @@ -376,7 +376,7 @@ def test_close_exc(self): def test_close_exc2(self): req = self.make_request('GET', '/') ws = WebSocketResponse() - ws.start(req) + self.loop.run_until_complete(ws.prepare(req)) exc = ValueError() self.writer.close.side_effect = exc @@ -390,3 +390,17 @@ def test_close_exc2(self): self.writer.close.side_effect = asyncio.CancelledError() self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, ws.close()) + + def test_start_twice_idempotent(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + with self.assertWarns(DeprecationWarning): + impl1 = ws.start(req) + impl2 = ws.start(req) + self.assertIs(impl1, impl2) + + def test_can_start_ok(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse(protocols=('chat',)) + with self.assertWarns(DeprecationWarning): + self.assertEqual((True, 'chat'), ws.can_start(req)) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index f4e410eca7b..0b62b672659 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -84,7 +84,7 @@ def test_send_recv_text(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) msg = yield from ws.receive_str() ws.send_str(msg+'/answer') yield from ws.close() @@ -119,7 +119,7 @@ def test_send_recv_bytes(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) msg = yield from ws.receive_bytes() ws.send_bytes(msg+b'/answer') @@ -154,7 +154,7 @@ def test_auto_pong_with_closing_by_peer(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) yield from ws.receive() msg = yield from ws.receive() @@ -186,7 +186,7 @@ def test_ping(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) ws.ping('data') yield from ws.receive() @@ -214,7 +214,7 @@ def test_client_ping(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse() - ws.start(request) + yield from ws.prepare(request) yield from ws.receive() closed.set_result(None) @@ -242,7 +242,7 @@ def test_pong(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse(autoping=False) - ws.start(request) + yield from ws.prepare(request) msg = yield from ws.receive() self.assertEqual(msg.tp, web.MsgType.ping) @@ -279,7 +279,7 @@ def handler(request): ws = web.WebSocketResponse() ws.set_status(200) self.assertEqual(200, ws.status) - ws.start(request) + yield from ws.prepare(request) self.assertEqual(101, ws.status) yield from ws.close() closed.set_result(None) @@ -302,7 +302,7 @@ def test_handle_protocol(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse(protocols=('foo', 'bar')) - ws.start(request) + yield from ws.prepare(request) yield from ws.close() self.assertEqual('bar', ws.protocol) closed.set_result(None) @@ -326,7 +326,7 @@ def test_server_close_handshake(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse(protocols=('foo', 'bar')) - ws.start(request) + yield from ws.prepare(request) yield from ws.close() closed.set_result(None) return ws @@ -352,7 +352,7 @@ def test_client_close_handshake(self): def handler(request): ws = web.WebSocketResponse( autoclose=False, protocols=('foo', 'bar')) - ws.start(request) + yield from ws.prepare(request) msg = yield from ws.receive() self.assertEqual(msg.tp, web.MsgType.close) @@ -387,7 +387,7 @@ def test_server_close_handshake_server_eats_client_messages(self): @asyncio.coroutine def handler(request): ws = web.WebSocketResponse(protocols=('foo', 'bar')) - ws.start(request) + yield from ws.prepare(request) yield from ws.close() closed.set_result(None) return ws @@ -411,3 +411,24 @@ def go(): response.close() self.loop.run_until_complete(go()) + + def test_receive_msg(self): + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + with self.assertWarns(DeprecationWarning): + msg = yield from ws.receive_msg() + self.assertEqual(msg.data, b'data') + yield from ws.close() + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from aiohttp.ws_connect(url, loop=self.loop) + resp.send_bytes(b'data') + yield from resp.close() + + self.loop.run_until_complete(go())