@@ -629,6 +629,154 @@ async def runner():
629629
630630class Test_UV_TCP (_TestTCP , tb .UVTestCase ):
631631
632+ def test_create_server_buffered_1 (self ):
633+ SIZE = 123123
634+
635+ class Proto (asyncio .BaseProtocol ):
636+ def connection_made (self , tr ):
637+ self .tr = tr
638+ self .recvd = b''
639+ self .data = bytearray (50 )
640+ self .buf = memoryview (self .data )
641+
642+ def get_buffer (self ):
643+ return self .buf
644+
645+ def buffer_updated (self , nbytes ):
646+ self .recvd += self .buf [:nbytes ]
647+ if self .recvd == b'a' * SIZE :
648+ self .tr .write (b'hello' )
649+
650+ def eof_received (self ):
651+ pass
652+
653+ async def test ():
654+ port = tb .find_free_port ()
655+ srv = await self .loop .create_server (Proto , '127.0.0.1' , port )
656+
657+ s = socket .socket (socket .AF_INET )
658+ with s :
659+ s .setblocking (False )
660+ await self .loop .sock_connect (s , ('127.0.0.1' , port ))
661+ await self .loop .sock_sendall (s , b'a' * SIZE )
662+ d = await self .loop .sock_recv (s , 100 )
663+ self .assertEqual (d , b'hello' )
664+
665+ srv .close ()
666+ await srv .wait_closed ()
667+
668+ self .loop .run_until_complete (test ())
669+
670+ def test_create_server_buffered_2 (self ):
671+ class ProtoExc (asyncio .BaseProtocol ):
672+ def __init__ (self ):
673+ self ._lost_exc = None
674+
675+ def get_buffer (self ):
676+ 1 / 0
677+
678+ def buffer_updated (self , nbytes ):
679+ pass
680+
681+ def connection_lost (self , exc ):
682+ self ._lost_exc = exc
683+
684+ def eof_received (self ):
685+ pass
686+
687+ class ProtoZeroBuf1 (asyncio .BaseProtocol ):
688+ def __init__ (self ):
689+ self ._lost_exc = None
690+
691+ def get_buffer (self ):
692+ return bytearray (0 )
693+
694+ def buffer_updated (self , nbytes ):
695+ pass
696+
697+ def connection_lost (self , exc ):
698+ self ._lost_exc = exc
699+
700+ def eof_received (self ):
701+ pass
702+
703+ class ProtoZeroBuf2 (asyncio .BaseProtocol ):
704+ def __init__ (self ):
705+ self ._lost_exc = None
706+
707+ def get_buffer (self ):
708+ return memoryview (bytearray (0 ))
709+
710+ def buffer_updated (self , nbytes ):
711+ pass
712+
713+ def connection_lost (self , exc ):
714+ self ._lost_exc = exc
715+
716+ def eof_received (self ):
717+ pass
718+
719+ class ProtoUpdatedError (asyncio .BaseProtocol ):
720+ def __init__ (self ):
721+ self ._lost_exc = None
722+
723+ def get_buffer (self ):
724+ return memoryview (bytearray (100 ))
725+
726+ def buffer_updated (self , nbytes ):
727+ raise RuntimeError ('oups' )
728+
729+ def connection_lost (self , exc ):
730+ self ._lost_exc = exc
731+
732+ def eof_received (self ):
733+ pass
734+
735+ async def test (proto_factory , exc_type , exc_re ):
736+ port = tb .find_free_port ()
737+ proto = proto_factory ()
738+ srv = await self .loop .create_server (
739+ lambda : proto , '127.0.0.1' , port )
740+
741+ try :
742+ s = socket .socket (socket .AF_INET )
743+ with s :
744+ s .setblocking (False )
745+ await self .loop .sock_connect (s , ('127.0.0.1' , port ))
746+ await self .loop .sock_sendall (s , b'a' )
747+ d = await self .loop .sock_recv (s , 100 )
748+ if not d :
749+ raise ConnectionResetError
750+ except ConnectionResetError :
751+ pass
752+ else :
753+ self .fail ("server didn't abort the connection" )
754+ return
755+ finally :
756+ srv .close ()
757+ await srv .wait_closed ()
758+
759+ if proto ._lost_exc is None :
760+ self .fail ("connection_lost() was not called" )
761+ return
762+
763+ with self .assertRaisesRegex (exc_type , exc_re ):
764+ raise proto ._lost_exc
765+
766+ self .loop .set_exception_handler (lambda loop , ctx : None )
767+
768+ self .loop .run_until_complete (
769+ test (ProtoExc , RuntimeError , 'unhandled error .* get_buffer' ))
770+
771+ self .loop .run_until_complete (
772+ test (ProtoZeroBuf1 , RuntimeError , 'unhandled error .* get_buffer' ))
773+
774+ self .loop .run_until_complete (
775+ test (ProtoZeroBuf2 , RuntimeError , 'unhandled error .* get_buffer' ))
776+
777+ self .loop .run_until_complete (
778+ test (ProtoUpdatedError , RuntimeError , r'^oups$' ))
779+
632780 def test_transport_get_extra_info (self ):
633781 # This tests is only for uvloop. asyncio should pass it
634782 # too in Python 3.6.
0 commit comments