55import socket
66import threading
77import weakref
8+ from io import SEEK_END
89from itertools import chain
910from queue import Empty , Full , LifoQueue
1011from time import time
11- from typing import Optional
12+ from typing import Optional , Union
1213from urllib .parse import parse_qs , unquote , urlparse
1314
1415from redis .backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163164
164165
165166class SocketBuffer :
166- def __init__ (self , socket , socket_read_size , socket_timeout ):
167+ def __init__ (
168+ self , socket : socket .socket , socket_read_size : int , socket_timeout : float
169+ ):
167170 self ._sock = socket
168171 self .socket_read_size = socket_read_size
169172 self .socket_timeout = socket_timeout
170173 self ._buffer = io .BytesIO ()
171- # number of bytes written to the buffer from the socket
172- self .bytes_written = 0
173- # number of bytes read from the buffer
174- self .bytes_read = 0
175174
176- @property
177- def length (self ):
178- return self .bytes_written - self .bytes_read
175+ def unread_bytes (self ) -> int :
176+ """
177+ Remaining unread length of buffer
178+ """
179+ pos = self ._buffer .tell ()
180+ end = self ._buffer .seek (0 , SEEK_END )
181+ self ._buffer .seek (pos )
182+ return end - pos
179183
180- def _read_from_socket (self , length = None , timeout = SENTINEL , raise_on_timeout = True ):
184+ def _read_from_socket (
185+ self ,
186+ length : Optional [int ] = None ,
187+ timeout : Union [float , object ] = SENTINEL ,
188+ raise_on_timeout : Optional [bool ] = True ,
189+ ) -> bool :
181190 sock = self ._sock
182191 socket_read_size = self .socket_read_size
183- buf = self ._buffer
184- buf .seek (self .bytes_written )
185192 marker = 0
186193 custom_timeout = timeout is not SENTINEL
187194
195+ buf = self ._buffer
196+ current_pos = buf .tell ()
197+ buf .seek (0 , SEEK_END )
198+ if custom_timeout :
199+ sock .settimeout (timeout )
188200 try :
189- if custom_timeout :
190- sock .settimeout (timeout )
191201 while True :
192202 data = self ._sock .recv (socket_read_size )
193203 # an empty string indicates the server shutdown the socket
194204 if isinstance (data , bytes ) and len (data ) == 0 :
195205 raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
196206 buf .write (data )
197207 data_length = len (data )
198- self .bytes_written += data_length
199208 marker += data_length
200209
201210 if length is not None and length > marker :
@@ -215,55 +224,51 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215224 return False
216225 raise ConnectionError (f"Error while reading from socket: { ex .args } " )
217226 finally :
227+ buf .seek (current_pos )
218228 if custom_timeout :
219229 sock .settimeout (self .socket_timeout )
220230
221- def can_read (self , timeout ) :
222- return bool (self .length ) or self ._read_from_socket (
231+ def can_read (self , timeout : float ) -> bool :
232+ return bool (self .unread_bytes () ) or self ._read_from_socket (
223233 timeout = timeout , raise_on_timeout = False
224234 )
225235
226- def read (self , length ) :
236+ def read (self , length : int ) -> bytes :
227237 length = length + 2 # make sure to read the \r\n terminator
228238 # make sure we've read enough data from the socket
229- if length > self .length :
230- self ._read_from_socket (length - self .length )
239+ if length > self .unread_bytes :
240+ self ._read_from_socket (length - self .unread_bytes )
231241
232- self ._buffer .seek (self .bytes_read )
233242 data = self ._buffer .read (length )
234- self .bytes_read += len (data )
235243 return data [:- 2 ]
236244
237- def readline (self ):
245+ def readline (self ) -> bytes :
238246 buf = self ._buffer
239- buf .seek (self .bytes_read )
240247 data = buf .readline ()
241248 while not data .endswith (SYM_CRLF ):
242249 # there's more data in the socket that we need
243250 self ._read_from_socket ()
244- buf .seek (self .bytes_read )
245251 data = buf .readline ()
246252
247- self .bytes_read += len (data )
248253 return data [:- 2 ]
249254
250- def get_pos (self ):
255+ def get_pos (self ) -> int :
251256 """
252257 Get current read position
253258 """
254- return self .bytes_read
259+ return self ._buffer . tell ()
255260
256- def rewind (self , pos ) :
261+ def rewind (self , pos : int ) -> None :
257262 """
258263 Rewind the buffer to a specific position, to re-start reading
259264 """
260- self .bytes_read = pos
265+ self ._buffer . seek ( pos )
261266
262- def purge (self ):
267+ def purge (self ) -> None :
263268 """
264269 After a successful read, purge the read part of buffer
265270 """
266- unread = self .bytes_written - self . bytes_read
271+ unread = self .unread_bytes ()
267272
268273 # Only if we have read all of the buffer do we truncate, to
269274 # reduce the amount of memory thrashing. This heuristic
@@ -276,13 +281,10 @@ def purge(self):
276281 view = self ._buffer .getbuffer ()
277282 view [:unread ] = view [- unread :]
278283 self ._buffer .truncate (unread )
279- self .bytes_written = unread
280- self .bytes_read = 0
281284 self ._buffer .seek (0 )
282285
283- def close (self ):
286+ def close (self ) -> None :
284287 try :
285- self .bytes_written = self .bytes_read = 0
286288 self ._buffer .close ()
287289 except Exception :
288290 # issue #633 suggests the purge/close somehow raised a
@@ -498,6 +500,7 @@ def read_response(self, disable_decoding=False):
498500 return response
499501
500502
503+ DefaultParser : BaseParser
501504if HIREDIS_AVAILABLE :
502505 DefaultParser = HiredisParser
503506else :
0 commit comments