diff --git a/telnetlib3/server.py b/telnetlib3/server.py index a9b808e..2d61287 100755 --- a/telnetlib3/server.py +++ b/telnetlib3/server.py @@ -14,6 +14,7 @@ import asyncio import logging import signal +from weakref import proxy # local from . import server_base @@ -111,7 +112,7 @@ def check_negotiation(self, final=False): encoding = self.encoding(outgoing=True, incoming=True) if not self.waiter_encoding.done() and result: self.log.debug('encoding complete: {0!r}'.format(encoding)) - self.waiter_encoding.set_result(self) + self.waiter_encoding.set_result(proxy(self)) elif (not self.waiter_encoding.done() and self.writer.remote_option.get(TTYPE) is False): @@ -120,13 +121,13 @@ def check_negotiation(self, final=False): # the distant end would not support it, declaring encoding failed. self.log.debug('encoding failed after {0:1.2f}s: {1}' .format(self.duration, encoding)) - self.waiter_encoding.set_result(self) + self.waiter_encoding.set_result(proxy(self)) return parent elif not self.waiter_encoding.done() and final: self.log.debug('encoding failed after {0:1.2f}s: {1}' .format(self.duration, encoding)) - self.waiter_encoding.set_result(self) + self.waiter_encoding.set_result(proxy(self)) return parent return parent and result diff --git a/telnetlib3/server_base.py b/telnetlib3/server_base.py index b7e482b..e81e3b8 100644 --- a/telnetlib3/server_base.py +++ b/telnetlib3/server_base.py @@ -4,6 +4,7 @@ import logging import datetime import sys +from weakref import proxy from .stream_writer import (TelnetWriter, TelnetWriterUnicode) from .stream_reader import (TelnetReader, TelnetReaderUnicode) @@ -85,7 +86,11 @@ def connection_lost(self, exc): self._transport.close() self._waiter_connected.cancel() if self.shell is None: - self._waiter_closed.set_result(self) + self._waiter_closed.set_result(proxy(self)) + + # break circular refrences. + self._transport = None + self.reader.fn_encoding = None def connection_made(self, transport): """ @@ -134,7 +139,7 @@ def begin_shell(self, result): if asyncio.iscoroutine(coro): fut = self._loop.create_task(coro) fut.add_done_callback( - lambda fut_obj: self._waiter_closed.set_result(self)) + lambda fut_obj: self._waiter_closed.set_result(proxy(self))) def data_received(self, data): """Process bytes received by transport.""" @@ -277,11 +282,11 @@ def _check_negotiation_timer(self): if self.check_negotiation(final=final): self.log.debug('negotiation complete after {:1.2f}s.' .format(self.duration)) - self._waiter_connected.set_result(self) + self._waiter_connected.set_result(proxy(self)) elif final: self.log.debug('negotiation failed after {:1.2f}s.' .format(self.duration)) - self._waiter_connected.set_result(self) + self._waiter_connected.set_result(proxy(self)) else: # keep re-queuing until complete self._check_later = self._loop.call_later( diff --git a/telnetlib3/stream_writer.py b/telnetlib3/stream_writer.py index 798a808..30eaefb 100644 --- a/telnetlib3/stream_writer.py +++ b/telnetlib3/stream_writer.py @@ -182,6 +182,17 @@ def __init__(self, transport, protocol, *, client=False, server=False, # Base protocol methods + def close(self): + super().close() + # break circular refs + self._ext_callback.clear() + self._ext_send_callback.clear() + self._slc_callback.clear() + self._iac_callback.clear() + self.fn_encoding = None + self._protocol = None + self._transport = None + def __repr__(self): """Description of stream encoding state.""" info = ['TelnetWriter'] diff --git a/telnetlib3/tests/test_shell.py b/telnetlib3/tests/test_shell.py index fa50951..c4b9d7a 100644 --- a/telnetlib3/tests/test_shell.py +++ b/telnetlib3/tests/test_shell.py @@ -1,6 +1,7 @@ """Test the server's shell(reader, writer) callback.""" # std imports import asyncio +import weakref # local imports import telnetlib3 @@ -22,6 +23,7 @@ async def test_telnet_server_shell_as_coroutine(event_loop, bind_host, from telnetlib3.telopt import IAC, DO, WONT, TTYPE # given, _waiter = asyncio.Future() + _saved = weakref.WeakSet() send_input = 'Alpha' expect_output = 'Beta' expect_hello = IAC + DO + TTYPE @@ -33,6 +35,10 @@ def shell(reader, writer): inp = yield from reader.readexactly(len(send_input)) assert inp == send_input writer.write(expect_output) + _saved.add(writer) + _saved.add(writer._protocol) + yield from writer.drain() + writer.close() # exercise, await telnetlib3.create_server( @@ -61,6 +67,9 @@ def shell(reader, writer): # verify, assert server_output.decode('ascii') == expect_output + + # no leaks + assert len(_saved) == 0 @pytest.mark.asyncio