77
88import redis
99from redis .asyncio .connection import (
10+ BaseParser ,
1011 Connection ,
1112 PythonParser ,
1213 UnixDomainSocketConnection ,
@@ -24,16 +25,19 @@ async def test_invalid_response(create_redis):
2425 r = await create_redis (single_connection_client = True )
2526
2627 raw = b"x"
28+ fake_stream = FakeStream (raw + b"\r \n " )
2729
28- parser : "PythonParser" = r .connection ._parser
29- if not isinstance (parser , PythonParser ):
30- pytest .skip ("PythonParser only" )
31- stream_mock = mock .Mock (parser ._stream )
32- stream_mock .readline .return_value = raw + b"\r \n "
33- with mock .patch .object (parser , "_stream" , stream_mock ):
30+ parser : BaseParser = r .connection ._parser
31+ with mock .patch .object (parser , "_stream" , fake_stream ):
3432 with pytest .raises (InvalidResponse ) as cm :
3533 await parser .read_response ()
36- assert str (cm .value ) == f"Protocol Error: { raw !r} "
34+ if isinstance (parser , PythonParser ):
35+ assert str (cm .value ) == f"Protocol Error: { raw !r} "
36+ else :
37+ assert (
38+ str (cm .value ) == f'Protocol error, got "{ raw .decode ()} " as reply type byte'
39+ )
40+ await r .connection .disconnect ()
3741
3842
3943@skip_if_server_version_lt ("4.0.0" )
@@ -115,26 +119,27 @@ async def test_connect_timeout_error_without_retry():
115119 assert str (e .value ) == "Timeout connecting to server"
116120
117121
118- class TestError (BaseException ):
119- pass
120-
121-
122- class InterruptingReader :
122+ class FakeStream :
123123 """
124124 A class simulating an asyncio input buffer, but raising a
125125 special exception every other read.
126126 """
127127
128- def __init__ (self , data ):
128+ class TestError (BaseException ):
129+ pass
130+
131+ def __init__ (self , data , interrupt_every = 0 ):
129132 self .data = data
130133 self .counter = 0
131134 self .pos = 0
135+ self .interrupt_every = interrupt_every
132136
133137 def tick (self ):
134138 self .counter += 1
135- # return
136- if (self .counter % 2 ) == 0 :
137- raise TestError ()
139+ if not self .interrupt_every :
140+ return
141+ if (self .counter % self .interrupt_every ) == 0 :
142+ raise self .TestError ()
138143
139144 async def read (self , want ):
140145 self .tick ()
@@ -176,12 +181,12 @@ async def test_connection_parse_response_resume(r: redis.Redis):
176181 b"$25\r \n hi\r \n there\r \n +how\r \n are\r \n you\r \n "
177182 )
178183
179- conn ._parser ._stream = InterruptingReader (message )
184+ conn ._parser ._stream = FakeStream (message , interrupt_every = 2 )
180185 for i in range (100 ):
181186 try :
182187 response = await conn .read_response ()
183188 break
184- except TestError :
189+ except FakeStream . TestError :
185190 pass
186191
187192 else :
0 commit comments