Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional BrokerConnection locks to synchronize protocol/IFR state #1768

Merged
merged 2 commits into from
Apr 2, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 85 additions & 61 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,14 @@ def _try_authenticate_plain(self, future):
self.config['sasl_plain_password']]).encode('utf-8'))
size = Int32.encode(len(msg))
try:
self._send_bytes_blocking(size + msg)
with self._lock:
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)
# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)

except ConnectionError as e:
log.exception("%s: Error receiving reply from server", self)
Expand All @@ -614,6 +617,9 @@ def _try_authenticate_gssapi(self, future):
).canonicalize(gssapi.MechType.kerberos)
log.debug('%s: GSSAPI name: %s', self, gssapi_name)

self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
Expand Down Expand Up @@ -656,13 +662,16 @@ def _try_authenticate_gssapi(self, future):
self._send_bytes_blocking(size + msg)

except ConnectionError as e:
self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
except Exception as e:
self._lock.release()
return future.failure(e)

self._lock.release()
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
return future.success(True)

Expand All @@ -671,6 +680,9 @@ def _try_authenticate_oauth(self, future):

msg = bytes(self._build_oauth_client_request().encode("utf-8"))
size = Int32.encode(len(msg))
self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)
Expand All @@ -680,11 +692,14 @@ def _try_authenticate_oauth(self, future):
data = self._recv_bytes_blocking(4)

except ConnectionError as e:
self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)

self._lock.release()

if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
return future.failure(error)
Expand Down Expand Up @@ -784,26 +799,33 @@ def close(self, error=None):
will be failed with this exception.
Default: kafka.errors.KafkaConnectionError.
"""
if self.state is ConnectionStates.DISCONNECTED:
if error is not None:
log.warning('%s: Duplicate close() with error: %s', self, error)
return
log.info('%s: Closing connection. %s', self, error or '')
self.state = ConnectionStates.DISCONNECTING
self.config['state_change_callback'](self)
self._update_reconnect_backoff()
self._close_socket()
self.state = ConnectionStates.DISCONNECTED
self._sasl_auth_future = None
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
api_version=self.config['api_version'])
if error is None:
error = Errors.Cancelled(str(self))
while self.in_flight_requests:
(_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
with self._lock:
if self.state is ConnectionStates.DISCONNECTED:
return
log.info('%s: Closing connection. %s', self, error or '')
self.state = ConnectionStates.DISCONNECTING
self.config['state_change_callback'](self)
self._update_reconnect_backoff()
self._close_socket()
self.state = ConnectionStates.DISCONNECTED
self._sasl_auth_future = None
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
api_version=self.config['api_version'])
if error is None:
error = Errors.Cancelled(str(self))
ifrs = list(self.in_flight_requests.items())
self.in_flight_requests.clear()
self.config['state_change_callback'](self)

# drop lock before processing futures
for (_correlation_id, (future, _timestamp)) in ifrs:
future.failure(error)
self.config['state_change_callback'](self)

def _can_send_recv(self):
    """Tell whether the connection state permits socket traffic.

    Returns:
        bool: True when ``self.state`` is ``AUTHENTICATING`` or
        ``CONNECTED``; False for every other state.
    """
    usable_states = (
        ConnectionStates.AUTHENTICATING,
        ConnectionStates.CONNECTED,
    )
    return self.state in usable_states

def send(self, request, blocking=True):
"""Queue request for async network send, return Future()"""
Expand All @@ -817,18 +839,20 @@ def send(self, request, blocking=True):
return self._send(request, blocking=blocking)

def _send(self, request, blocking=True):
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
future = Future()
with self._lock:
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))

correlation_id = self._protocol.send_request(request)

log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
self.in_flight_requests[correlation_id] = (future, sent_time)
else:
future.success(None)
log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
self.in_flight_requests[correlation_id] = (future, sent_time)
else:
future.success(None)

# Attempt to replicate behavior from prior to introduction of
# send_pending_requests() / async sends
Expand All @@ -839,16 +863,15 @@ def _send(self, request, blocking=True):

def send_pending_requests(self):
"""Can block on network if request is larger than send_buffer_bytes"""
if self.state not in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED):
return Errors.NodeNotReadyError(str(self))
with self._lock:
data = self._protocol.send_bytes()
try:
# In the future we might manage an internal write buffer
# and send bytes asynchronously. For now, just block
# sending each request payload
total_bytes = self._send_bytes_blocking(data)
with self._lock:
if not self._can_send_recv():
return Errors.NodeNotReadyError(str(self))
# In the future we might manage an internal write buffer
# and send bytes asynchronously. For now, just block
# sending each request payload
data = self._protocol.send_bytes()
total_bytes = self._send_bytes_blocking(data)
if self._sensors:
self._sensors.bytes_sent.record(total_bytes)
return total_bytes
Expand All @@ -868,18 +891,6 @@ def recv(self):

Return list of (response, future) tuples
"""
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
log.warning('%s cannot recv: socket not connected', self)
# If requests are pending, we should close the socket and
# fail all the pending request futures
if self.in_flight_requests:
self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
return ()

elif not self.in_flight_requests:
log.warning('%s: No in-flight-requests to recv', self)
return ()

responses = self._recv()
if not responses and self.requests_timed_out():
log.warning('%s timed out after %s ms. Closing connection.',
Expand All @@ -892,7 +903,8 @@ def recv(self):
# augment responses w/ correlation_id, future, and timestamp
for i, (correlation_id, response) in enumerate(responses):
try:
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
with self._lock:
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
except KeyError:
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
return ()
Expand All @@ -908,6 +920,12 @@ def recv(self):
def _recv(self):
"""Take all available bytes from socket, return list of any responses from parser"""
recvd = []
self._lock.acquire()
if not self._can_send_recv():
log.warning('%s cannot recv: socket not connected', self)
self._lock.release()
return ()

while len(recvd) < self.config['sock_chunk_buffer_count']:
try:
data = self._sock.recv(self.config['sock_chunk_bytes'])
Expand All @@ -917,6 +935,7 @@ def _recv(self):
# without an exception raised
if not data:
log.error('%s: socket disconnected', self)
self._lock.release()
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
return []
else:
Expand All @@ -929,11 +948,13 @@ def _recv(self):
break
log.exception('%s: Error receiving network data'
' closing socket', self)
self._lock.release()
self.close(error=Errors.KafkaConnectionError(e))
return []
except BlockingIOError:
if six.PY3:
break
self._lock.release()
raise

recvd_data = b''.join(recvd)
Expand All @@ -943,20 +964,23 @@ def _recv(self):
try:
responses = self._protocol.receive_bytes(recvd_data)
except Errors.KafkaProtocolError as e:
self._lock.release()
self.close(e)
return []
else:
self._lock.release()
return responses

def requests_timed_out(self):
    """Return True if the oldest in-flight request has exceeded the timeout.

    ``self.in_flight_requests`` maps correlation ids to
    ``(future, sent_time)`` tuples; the earliest ``sent_time`` is compared
    against ``config['request_timeout_ms']`` (milliseconds).

    Returns:
        bool: True when at least one in-flight request was sent
        ``request_timeout_ms`` or more ago; False when nothing is in
        flight or nothing has expired yet.
    """
    # Snapshot the oldest timestamp while holding the lock: other methods
    # add to and remove from in_flight_requests, and scanning a dict that
    # changes size mid-iteration raises RuntimeError.
    with self._lock:
        if not self.in_flight_requests:
            return False
        oldest_at = min(sent_at for (_future, sent_at)
                        in self.in_flight_requests.values())

    # The comparison itself needs no shared state, so do it outside the lock.
    timeout = self.config['request_timeout_ms'] / 1000.0
    return time.time() >= oldest_at + timeout

def _handle_api_version_response(self, response):
error_type = Errors.for_code(response.error_code)
Expand Down