Skip to content

Commit

Permalink
clean up more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Oct 8, 2018
1 parent 44a0788 commit d14da81
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions distributed/comm/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def read(self, deserializers=None):
else:
frame = b''
frames.append(frame)
except asyncio.streams.IncompleteReadError as e:
except (asyncio.streams.IncompleteReadError, EnvironmentError) as e:
self.reader = None
self.writer = None
if not shutting_down():
Expand Down Expand Up @@ -169,25 +169,33 @@ async def write(self, msg, serializers=None, on_error='message'):
except asyncio.streams.IncompleteReadError as e:
self.reader = None
self.writer = None
raise CommClosedError()
# convert_stream_closed_error(self, e)
# raise CommClosedError()
convert_stream_closed_error(self, e)

return sum(map(nbytes, frames))

@gen.coroutine
def close(self):
writer, self.writer = self.writer, None
if not is_closed(writer):
self._finalizer.detach()
writer.close()
yield writer.wait_closed()
try:
yield writer.drain()
sock = writer._transport.get_extra_info('socket')
if sock:
sock.shutdown(socket.SHUT_RDWR)
except EnvironmentError:
pass
finally:
self._finalizer.detach()
writer.close()
# yield writer.wait_closed()

def abort(self):
writer, self.writer = self.writer, None
if not is_closed(writer):
self._finalizer.detach()
writer.close()
yield writer.wait_closed()
# yield writer.wait_closed()

def closed(self):
return is_closed(self.writer)
Expand Down Expand Up @@ -309,18 +317,19 @@ def _check_started(self):
if self.server is None:
raise ValueError("invalid operation on non-started TCPListener")

async def _handle_stream(self, reader, writer, *args, **kwargs):
@gen.coroutine
def _handle_stream(self, reader, writer, *args, **kwargs):
host, ip = writer.get_extra_info('peername')
address = self.prefix + unparse_host_port(host, ip)
reader, writer = await self._prepare_stream(reader, writer, address)
reader, writer = yield self._prepare_stream(reader, writer, address)
# if stream is None:
# # Preparation failed
# return
logger.debug("Incoming connection from %r to %r",
address, self.contact_address)
local_address = self.prefix + get_stream_address(reader)
comm = self.comm_class(reader, writer, local_address, address, self.deserialize)
await self.comm_handler(comm)
yield self.comm_handler(comm)

def get_host_port(self):
"""
Expand Down

0 comments on commit d14da81

Please sign in to comment.