Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BrokerConnection receive bytes pipe #1032

Merged
merged 2 commits into from
Aug 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,25 +605,14 @@ def _poll(self, timeout, sleep=True):
continue

self._idle_expiry_manager.update(conn.node_id)

# Accumulate as many responses as the connection has pending
while conn.in_flight_requests:
response = conn.recv() # Note: conn.recv runs callbacks / errbacks

# Incomplete responses are buffered internally
# while conn.in_flight_requests retains the request
if not response:
break
responses.append(response)
responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks

# Check for additional pending SSL bytes
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
# TODO: optimize
for conn in self._conns.values():
if conn not in processed and conn.connected() and conn._sock.pending():
response = conn.recv()
if response:
responses.append(response)
responses.extend(conn.recv())

for conn in six.itervalues(self._conns):
if conn.requests_timed_out():
Expand All @@ -635,6 +624,7 @@ def _poll(self, timeout, sleep=True):

if self._sensors:
self._sensors.io_time.record((time.time() - end_select) * 1000000000)

self._maybe_close_oldest_connection()
return responses

Expand Down
161 changes: 84 additions & 77 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import copy
import errno
import logging
import io
from random import shuffle, uniform
import socket
import time
Expand All @@ -18,6 +17,7 @@
from kafka.protocol.api import RequestHeader
from kafka.protocol.admin import SaslHandShakeRequest
from kafka.protocol.commit import GroupCoordinatorResponse, OffsetFetchRequest
from kafka.protocol.frame import KafkaBytes
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.fetch import FetchRequest
from kafka.protocol.types import Int32
Expand Down Expand Up @@ -234,9 +234,9 @@ def __init__(self, host, port, afi, **configs):
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
self._sasl_auth_future = None
self._rbuffer = io.BytesIO()
self._header = KafkaBytes(4)
self._rbuffer = None
self._receiving = False
self._next_payload_bytes = 0
self.last_attempt = 0
self._processing = False
self._correlation_id = 0
Expand Down Expand Up @@ -629,17 +629,19 @@ def close(self, error=None):
self.state = ConnectionStates.DISCONNECTED
self.last_attempt = time.time()
self._sasl_auth_future = None
self._receiving = False
self._next_payload_bytes = 0
self._rbuffer.seek(0)
self._rbuffer.truncate()
self._reset_buffer()
if error is None:
error = Errors.Cancelled(str(self))
while self.in_flight_requests:
ifr = self.in_flight_requests.popleft()
ifr.future.failure(error)
self.config['state_change_callback'](self)

def _reset_buffer(self):
self._receiving = False
self._header.seek(0)
self._rbuffer = None

def send(self, request):
"""send request, return Future()

Expand Down Expand Up @@ -713,11 +715,11 @@ def recv(self):
# fail all the pending request futures
if self.in_flight_requests:
self.close(Errors.ConnectionError('Socket not connected during recv with in-flight-requests'))
return None
return ()

elif not self.in_flight_requests:
log.warning('%s: No in-flight-requests to recv', self)
return None
return ()

response = self._recv()
if not response and self.requests_timed_out():
Expand All @@ -726,103 +728,108 @@ def recv(self):
self.close(error=Errors.RequestTimedOutError(
'Request timed out after %s ms' %
self.config['request_timeout_ms']))
return None
return ()
return response

def _recv(self):
# Not receiving is the state of reading the payload header
if not self._receiving:
responses = []
SOCK_CHUNK_BYTES = 4096
while True:
try:
bytes_to_read = 4 - self._rbuffer.tell()
data = self._sock.recv(bytes_to_read)
data = self._sock.recv(SOCK_CHUNK_BYTES)
# We expect socket.recv to raise an exception if there is not
# enough data to read the full bytes_to_read
# but if the socket is disconnected, we will get empty data
# without an exception raised
if not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
return None
self._rbuffer.write(data)
break
else:
responses.extend(self.receive_bytes(data))
if len(data) < SOCK_CHUNK_BYTES:
break
except SSLWantReadError:
return None
break
except ConnectionError as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK:
return None
log.exception('%s: Error receiving 4-byte payload header -'
break
log.exception('%s: Error receiving network data'
' closing socket', self)
self.close(error=Errors.ConnectionError(e))
return None
except BlockingIOError:
if six.PY3:
return None
raise

if self._rbuffer.tell() == 4:
self._rbuffer.seek(0)
self._next_payload_bytes = Int32.decode(self._rbuffer)
# reset buffer and switch state to receiving payload bytes
self._rbuffer.seek(0)
self._rbuffer.truncate()
self._receiving = True
elif self._rbuffer.tell() > 4:
raise Errors.KafkaError('this should not happen - are you threading?')

if self._receiving:
staged_bytes = self._rbuffer.tell()
try:
bytes_to_read = self._next_payload_bytes - staged_bytes
data = self._sock.recv(bytes_to_read)
# We expect socket.recv to raise an exception if there is not
# enough data to read the full bytes_to_read
# but if the socket is disconnected, we will get empty data
# without an exception raised
if bytes_to_read and not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
return None
self._rbuffer.write(data)
except SSLWantReadError:
return None
except ConnectionError as e:
# Extremely small chance that we have exactly 4 bytes for a
# header, but nothing to read in the body yet
if six.PY2 and e.errno == errno.EWOULDBLOCK:
return None
log.exception('%s: Error in recv', self)
self.close(error=Errors.ConnectionError(e))
return None
break
except BlockingIOError:
if six.PY3:
return None
break
raise
return responses

staged_bytes = self._rbuffer.tell()
if staged_bytes > self._next_payload_bytes:
self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))

if staged_bytes != self._next_payload_bytes:
return None
def receive_bytes(self, data):
i = 0
n = len(data)
responses = []
if self._sensors:
self._sensors.bytes_received.record(n)
while i < n:

# Not receiving is the state of reading the payload header
if not self._receiving:
bytes_to_read = min(4 - self._header.tell(), n - i)
self._header.write(data[i:i+bytes_to_read])
i += bytes_to_read

if self._header.tell() == 4:
self._header.seek(0)
nbytes = Int32.decode(self._header)
# reset buffer and switch state to receiving payload bytes
self._rbuffer = KafkaBytes(nbytes)
self._receiving = True
elif self._header.tell() > 4:
raise Errors.KafkaError('this should not happen - are you threading?')


if self._receiving:
total_bytes = len(self._rbuffer)
staged_bytes = self._rbuffer.tell()
bytes_to_read = min(total_bytes - staged_bytes, n - i)
self._rbuffer.write(data[i:i+bytes_to_read])
i += bytes_to_read

staged_bytes = self._rbuffer.tell()
if staged_bytes > total_bytes:
self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))

if staged_bytes != total_bytes:
break

self._receiving = False
self._next_payload_bytes = 0
if self._sensors:
self._sensors.bytes_received.record(4 + self._rbuffer.tell())
self._rbuffer.seek(0)
response = self._process_response(self._rbuffer)
self._rbuffer.seek(0)
self._rbuffer.truncate()
return response
self._receiving = False
self._rbuffer.seek(0)
resp = self._process_response(self._rbuffer)
if resp is not None:
responses.append(resp)
self._reset_buffer()
return responses

def _process_response(self, read_buffer):
assert not self._processing, 'Recursion not supported'
self._processing = True
ifr = self.in_flight_requests.popleft()
recv_correlation_id = Int32.decode(read_buffer)

if not self.in_flight_requests:
error = Errors.CorrelationIdError(
'%s: No in-flight-request found for server response'
' with correlation ID %d'
% (self, recv_correlation_id))
self.close(error)
self._processing = False
return None
else:
ifr = self.in_flight_requests.popleft()

if self._sensors:
self._sensors.request_time.record((time.time() - ifr.timestamp) * 1000)

# verify send/recv correlation ids match
recv_correlation_id = Int32.decode(read_buffer)

# 0.8.2 quirk
if (self.config['api_version'] == (0, 8, 2) and
Expand Down
30 changes: 30 additions & 0 deletions kafka/protocol/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class KafkaBytes(bytearray):
    """Fixed-size byte buffer with a file-like read/write cursor.

    Subclasses bytearray so the payload can be decoded in place, while
    exposing read/write/seek/tell so it can stand in for an io.BytesIO
    when staging network frames.
    """

    def __init__(self, size):
        """Allocate a zero-filled buffer of ``size`` bytes, cursor at 0."""
        super(KafkaBytes, self).__init__(size)
        # Current read/write position within the buffer
        self._idx = 0

    def read(self, nbytes=None):
        """Return up to ``nbytes`` bytes from the cursor and advance it.

        If ``nbytes`` is None, return everything from the cursor to the
        end of the buffer. Reads past the end return a short (possibly
        empty) result, matching file.read semantics.
        """
        if nbytes is None:
            nbytes = len(self) - self._idx
        start = self._idx
        self._idx += nbytes
        # Clamp the cursor so over-long reads simply return fewer bytes
        if self._idx > len(self):
            self._idx = len(self)
        return bytes(self[start:self._idx])

    def write(self, data):
        """Write ``data`` at the cursor and advance it.

        Slice assignment extends the underlying bytearray if the write
        runs past the current end.
        """
        start = self._idx
        self._idx += len(data)
        self[start:self._idx] = data

    def seek(self, idx, whence=0):
        """Move the cursor, mirroring io.IOBase.seek.

        whence=0 (default): absolute position ``idx``;
        whence=1: relative to the current position;
        whence=2: relative to the end of the buffer.

        The ``whence`` parameter is a backward-compatible addition so
        callers written against BytesIO (e.g. ``seek(-4, 1)``) keep
        working without translation.
        """
        if whence == 0:
            self._idx = idx
        elif whence == 1:
            self._idx += idx
        elif whence == 2:
            self._idx = len(self) + idx
        else:
            raise ValueError('Invalid whence: %r' % (whence,))

    def tell(self):
        """Return the current cursor position."""
        return self._idx

    def __str__(self):
        return 'KafkaBytes(%d)' % len(self)

    def __repr__(self):
        return str(self)
7 changes: 4 additions & 3 deletions kafka/protocol/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..codec import (has_gzip, has_snappy, has_lz4,
gzip_decode, snappy_decode,
lz4_decode, lz4_decode_old_kafka)
from .frame import KafkaBytes
from .struct import Struct
from .types import (
Int8, Int32, Int64, Bytes, Schema, AbstractType
Expand Down Expand Up @@ -155,10 +156,10 @@ class MessageSet(AbstractType):
@classmethod
def encode(cls, items):
# RecordAccumulator encodes messagesets internally
if isinstance(items, io.BytesIO):
if isinstance(items, (io.BytesIO, KafkaBytes)):
size = Int32.decode(items)
# rewind and return all the bytes
items.seek(-4, 1)
items.seek(items.tell() - 4)
return items.read(size + 4)

encoded_values = []
Expand Down Expand Up @@ -198,7 +199,7 @@ def decode(cls, data, bytes_to_read=None):

@classmethod
def repr(cls, messages):
if isinstance(messages, io.BytesIO):
if isinstance(messages, (KafkaBytes, io.BytesIO)):
offset = messages.tell()
decoded = cls.decode(messages)
messages.seek(offset)
Expand Down