Skip to content

Commit 7b202cc

Browse files
versusvoidfantix
authored andcommitted
Restore context on listen in UVStreamServer. Fix #305
1 parent ae44ec2 commit 7b202cc

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

tests/test_context.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,23 @@
22
import contextvars
33
import decimal
44
import random
5+
import socket
56
import weakref
67

78
from uvloop import _testbase as tb
89

910

11+
class _Protocol(asyncio.Protocol):
12+
def __init__(self, *, loop=None):
13+
self.done = asyncio.Future(loop=loop)
14+
15+
def connection_lost(self, exc):
16+
if exc is None:
17+
self.done.set_result(None)
18+
else:
19+
self.done.set_exception(exc)
20+
21+
1022
class _ContextBaseTests:
1123

1224
def test_task_decimal_context(self):
@@ -126,6 +138,40 @@ async def main():
126138
del tracked
127139
self.assertIsNone(ref())
128140

141+
def test_create_server_protocol_factory_context(self):
142+
cvar = contextvars.ContextVar('cvar', default='outer')
143+
factory_called_future = self.loop.create_future()
144+
proto = _Protocol(loop=self.loop)
145+
146+
def factory():
147+
try:
148+
self.assertEqual(cvar.get(), 'inner')
149+
except Exception as e:
150+
factory_called_future.set_exception(e)
151+
else:
152+
factory_called_future.set_result(None)
153+
154+
return proto
155+
156+
async def test():
157+
cvar.set('inner')
158+
port = tb.find_free_port()
159+
srv = await self.loop.create_server(factory, '127.0.0.1', port)
160+
161+
s = socket.socket(socket.AF_INET)
162+
with s:
163+
s.setblocking(False)
164+
await self.loop.sock_connect(s, ('127.0.0.1', port))
165+
166+
try:
167+
await factory_called_future
168+
finally:
169+
srv.close()
170+
await proto.done
171+
await srv.wait_closed()
172+
173+
self.loop.run_until_complete(test())
174+
129175

130176
class Test_UV_Context(_ContextBaseTests, tb.UVTestCase):
131177
pass

uvloop/handles/streamserver.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ cdef class UVStreamServer(UVSocketHandle):
77
object protocol_factory
88
bint opened
99
Server _server
10+
object listen_context
1011

1112
# All "inline" methods are final
1213

uvloop/handles/streamserver.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cdef class UVStreamServer(UVSocketHandle):
88
self.ssl_handshake_timeout = None
99
self.ssl_shutdown_timeout = None
1010
self.protocol_factory = None
11+
self.listen_context = None
1112

1213
cdef inline _init(self, Loop loop, object protocol_factory,
1314
Server server,
@@ -53,6 +54,8 @@ cdef class UVStreamServer(UVSocketHandle):
5354
if self.opened != 1:
5455
raise RuntimeError('unopened TCPServer')
5556

57+
self.listen_context = Context_CopyCurrent()
58+
5659
err = uv.uv_listen(<uv.uv_stream_t*> self._handle,
5760
self.backlog,
5861
__uv_streamserver_on_listen)
@@ -64,7 +67,7 @@ cdef class UVStreamServer(UVSocketHandle):
6467
cdef inline _on_listen(self):
6568
cdef UVStream client
6669

67-
protocol = self.protocol_factory()
70+
protocol = self.listen_context.run(self.protocol_factory)
6871

6972
if self.ssl is None:
7073
client = self._make_new_transport(protocol, None)

0 commit comments

Comments
 (0)