diff --git a/CHANGES.txt b/CHANGES.txt index 3ba9fde2576..d48fdcc73de 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -14,3 +14,7 @@ CHANGES * default headers in ClientSession are now case-insensitive * Make '=' char and 'wss://' schema safe in urls #477 + +* `ClientResponse.close()` forces connection closing by default from now #479 + N.B. Backward incompatible change: was `.close(force=False) + Using `force` parameter for the method is deprecated: use `.release()` instead. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 53cdede61c4..0a945f1212f 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -598,7 +598,10 @@ def start(self, connection, read_until_eof=False): 'Can not load response cookies: %s', exc) return self - def close(self, force=False): + def close(self, force=True): + if not force: + warnings.warn("force parameter should be True", DeprecationWarning, + stacklevel=2) if self._closed: return @@ -609,29 +612,30 @@ def close(self, force=False): return if self._connection is not None: - if self.content and not self.content.at_eof(): - force = True + self._connection.close() + self._connection = None + self._cleanup_writer() - if force: - self._connection.close() - else: + @asyncio.coroutine + def release(self): + try: + content = self.content + if content is not None and not content.at_eof(): + chunk = yield from content.readany() + while chunk is not EOF_MARKER or chunk: + chunk = yield from content.readany() + finally: + if self._connection is not None: self._connection.release() if self._reader is not None: self._reader.unset_parser() + self._connection = None + self._cleanup_writer() - self._connection = None + def _cleanup_writer(self): if self._writer is not None and not self._writer.done(): self._writer.cancel() - self._writer = None - - @asyncio.coroutine - def release(self): - try: - chunk = yield from self.content.readany() - while chunk is not EOF_MARKER or chunk: - chunk = yield from self.content.readany() - finally: - self.close() + self._writer = None @asyncio.coroutine def wait_for_close(self): @@ -640,7 +644,7 @@ def wait_for_close(self): yield from self._writer finally: self._writer = None - self.close() + yield from self.release() @asyncio.coroutine def read(self, decode=False): @@ -649,10 +653,10 @@ def read(self, decode=False): try: self._content = yield from self.content.read() except: - self.close(True) + self.close() raise else: - self.close() + yield from self.release() data = self._content diff --git a/tests/test_client_response.py b/tests/test_client_response.py index a60416bf051..4798dcbe222 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -44,7 +44,6 @@ def test_close(self): self.response._connection = self.connection self.response.close() self.assertIsNone(self.response.connection) - self.assertTrue(self.connection.release.called) self.response.close() self.response.close() @@ -76,7 +75,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete(self.response.read()) self.assertEqual(res, b'payload') - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_read_and_release_connection_with_error(self): content = self.response.content = unittest.mock.Mock() @@ -87,7 +86,7 @@ def test_read_and_release_connection_with_error(self): self.assertRaises( ValueError, self.loop.run_until_complete, self.response.read()) - self.response.close.assert_called_with(True) + self.response.close.assert_called_with() def test_release(self): fut = asyncio.Future(loop=self.loop) @@ -97,7 +96,7 @@ def test_release(self): self.response.close = unittest.mock.Mock() self.loop.run_until_complete(self.response.release()) - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_read_and_close(self): self.response.read = unittest.mock.Mock() @@ -133,7 +132,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete(self.response.text()) self.assertEqual(res, '{"тест": "пройден"}') - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_text_custom_encoding(self): def side_effect(*args, **kwargs): @@ -150,7 +149,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete( self.response.text(encoding='cp1251')) self.assertEqual(res, '{"тест": "пройден"}') - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) self.assertFalse(self.response._get_encoding.called) def test_text_detect_encoding(self): @@ -166,7 +165,7 @@ def side_effect(*args, **kwargs): self.loop.run_until_complete(self.response.read()) res = self.loop.run_until_complete(self.response.text()) self.assertEqual(res, '{"тест": "пройден"}') - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_text_after_read(self): def side_effect(*args, **kwargs): @@ -181,7 +180,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete(self.response.text()) self.assertEqual(res, '{"тест": "пройден"}') - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_json(self): def side_effect(*args, **kwargs): @@ -196,7 +195,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete(self.response.json()) self.assertEqual(res, {'тест': 'пройден'}) - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_json_custom_loader(self): self.response.headers = { @@ -237,7 +236,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete( self.response.json(encoding='cp1251')) self.assertEqual(res, {'тест': 'пройден'}) - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) self.assertFalse(self.response._get_encoding.called) def test_json_detect_encoding(self): @@ -252,7 +251,7 @@ def side_effect(*args, **kwargs): res = self.loop.run_until_complete(self.response.json()) self.assertEqual(res, {'тест': 'пройден'}) - self.assertTrue(self.response.close.called) + self.assertIsNone(self.response._connection) def test_override_flow_control(self): class MyResponse(ClientResponse): @@ -269,3 +268,9 @@ def test_get_encoding_unknown(self, m_chardet): self.response.headers = {'CONTENT-TYPE': 'application/json'} self.assertEqual(self.response._get_encoding(), 'utf-8') + + def test_close_deprecated(self): + self.response._connection = self.connection + with self.assertWarns(DeprecationWarning): + self.response.close(force=False) + self.assertIsNone(self.response._connection)