diff --git a/kafka/conn.py b/kafka/conn.py
index 33950dbbf..16ac4dc19 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -586,11 +586,14 @@ def _try_authenticate_plain(self, future):
                                   self.config['sasl_plain_password']]).encode('utf-8'))
         size = Int32.encode(len(msg))
         try:
-            self._send_bytes_blocking(size + msg)
+            with self._lock:
+                if not self._can_send_recv():
+                    return future.failure(Errors.NodeNotReadyError(str(self)))
+                self._send_bytes_blocking(size + msg)
 
-            # The server will send a zero sized message (that is Int32(0)) on success.
-            # The connection is closed on failure
-            data = self._recv_bytes_blocking(4)
+                # The server will send a zero sized message (that is Int32(0)) on success.
+                # The connection is closed on failure
+                data = self._recv_bytes_blocking(4)
 
         except ConnectionError as e:
             log.exception("%s: Error receiving reply from server", self)
@@ -614,6 +617,9 @@ def _try_authenticate_gssapi(self, future):
         ).canonicalize(gssapi.MechType.kerberos)
         log.debug('%s: GSSAPI name: %s', self, gssapi_name)
 
+        self._lock.acquire()
+        if not self._can_send_recv():
+            return future.failure(Errors.NodeNotReadyError(str(self)))
         # Establish security context and negotiate protection level
         # For reference RFC 2222, section 7.2.1
         try:
@@ -656,13 +662,16 @@ def _try_authenticate_gssapi(self, future):
             self._send_bytes_blocking(size + msg)
 
         except ConnectionError as e:
+            self._lock.release()
             log.exception("%s: Error receiving reply from server", self)
             error = Errors.KafkaConnectionError("%s: %s" % (self, e))
             self.close(error=error)
             return future.failure(error)
         except Exception as e:
+            self._lock.release()
             return future.failure(e)
 
+        self._lock.release()
         log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
         return future.success(True)
 
@@ -671,6 +680,9 @@ def _try_authenticate_oauth(self, future):
 
         msg = bytes(self._build_oauth_client_request().encode("utf-8"))
         size = Int32.encode(len(msg))
+        self._lock.acquire()
+        if not self._can_send_recv():
+            return future.failure(Errors.NodeNotReadyError(str(self)))
         try:
             # Send SASL OAuthBearer request with OAuth token
             self._send_bytes_blocking(size + msg)
@@ -680,11 +692,14 @@ def _try_authenticate_oauth(self, future):
             data = self._recv_bytes_blocking(4)
 
         except ConnectionError as e:
+            self._lock.release()
             log.exception("%s: Error receiving reply from server", self)
             error = Errors.KafkaConnectionError("%s: %s" % (self, e))
             self.close(error=error)
             return future.failure(error)
 
+        self._lock.release()
+
         if data != b'\x00\x00\x00\x00':
             error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
             return future.failure(error)
@@ -784,26 +799,33 @@ def close(self, error=None):
                 will be failed with this exception.
                 Default: kafka.errors.KafkaConnectionError.
         """
-        if self.state is ConnectionStates.DISCONNECTED:
-            if error is not None:
-                log.warning('%s: Duplicate close() with error: %s', self, error)
-            return
-        log.info('%s: Closing connection. %s', self, error or '')
-        self.state = ConnectionStates.DISCONNECTING
-        self.config['state_change_callback'](self)
-        self._update_reconnect_backoff()
-        self._close_socket()
-        self.state = ConnectionStates.DISCONNECTED
-        self._sasl_auth_future = None
-        self._protocol = KafkaProtocol(
-            client_id=self.config['client_id'],
-            api_version=self.config['api_version'])
-        if error is None:
-            error = Errors.Cancelled(str(self))
-        while self.in_flight_requests:
-            (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
+        with self._lock:
+            if self.state is ConnectionStates.DISCONNECTED:
+                return
+            log.info('%s: Closing connection. %s', self, error or '')
+            self.state = ConnectionStates.DISCONNECTING
+            self.config['state_change_callback'](self)
+            self._update_reconnect_backoff()
+            self._close_socket()
+            self.state = ConnectionStates.DISCONNECTED
+            self._sasl_auth_future = None
+            self._protocol = KafkaProtocol(
+                client_id=self.config['client_id'],
+                api_version=self.config['api_version'])
+            if error is None:
+                error = Errors.Cancelled(str(self))
+            ifrs = list(self.in_flight_requests.items())
+            self.in_flight_requests.clear()
+            self.config['state_change_callback'](self)
+
+        # drop lock before processing futures
+        for (_correlation_id, (future, _timestamp)) in ifrs:
             future.failure(error)
-        self.config['state_change_callback'](self)
+
+    def _can_send_recv(self):
+        """Return True iff socket is ready for requests / responses"""
+        return self.state in (ConnectionStates.AUTHENTICATING,
+                              ConnectionStates.CONNECTED)
 
     def send(self, request, blocking=True):
         """Queue request for async network send, return Future()"""
@@ -817,18 +839,20 @@ def send(self, request, blocking=True):
         return self._send(request, blocking=blocking)
 
     def _send(self, request, blocking=True):
-        assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
         future = Future()
         with self._lock:
+            if not self._can_send_recv():
+                return future.failure(Errors.NodeNotReadyError(str(self)))
+
             correlation_id = self._protocol.send_request(request)
 
-        log.debug('%s Request %d: %s', self, correlation_id, request)
-        if request.expect_response():
-            sent_time = time.time()
-            assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
-            self.in_flight_requests[correlation_id] = (future, sent_time)
-        else:
-            future.success(None)
+            log.debug('%s Request %d: %s', self, correlation_id, request)
+            if request.expect_response():
+                sent_time = time.time()
+                assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
+                self.in_flight_requests[correlation_id] = (future, sent_time)
+            else:
+                future.success(None)
 
         # Attempt to replicate behavior from prior to introduction of
         # send_pending_requests() / async sends
@@ -839,16 +863,15 @@ def _send(self, request, blocking=True):
 
     def send_pending_requests(self):
         """Can block on network if request is larger than send_buffer_bytes"""
-        if self.state not in (ConnectionStates.AUTHENTICATING,
-                              ConnectionStates.CONNECTED):
-            return Errors.NodeNotReadyError(str(self))
-        with self._lock:
-            data = self._protocol.send_bytes()
         try:
-            # In the future we might manage an internal write buffer
-            # and send bytes asynchronously. For now, just block
-            # sending each request payload
-            total_bytes = self._send_bytes_blocking(data)
+            with self._lock:
+                if not self._can_send_recv():
+                    return Errors.NodeNotReadyError(str(self))
+                # In the future we might manage an internal write buffer
+                # and send bytes asynchronously. For now, just block
+                # sending each request payload
+                data = self._protocol.send_bytes()
+                total_bytes = self._send_bytes_blocking(data)
             if self._sensors:
                 self._sensors.bytes_sent.record(total_bytes)
             return total_bytes
@@ -868,18 +891,6 @@ def recv(self):
 
         Return list of (response, future) tuples
         """
-        if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
-            log.warning('%s cannot recv: socket not connected', self)
-            # If requests are pending, we should close the socket and
-            # fail all the pending request futures
-            if self.in_flight_requests:
-                self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
-            return ()
-
-        elif not self.in_flight_requests:
-            log.warning('%s: No in-flight-requests to recv', self)
-            return ()
-
         responses = self._recv()
         if not responses and self.requests_timed_out():
             log.warning('%s timed out after %s ms. Closing connection.',
@@ -892,7 +903,8 @@ def recv(self):
         # augment respones w/ correlation_id, future, and timestamp
         for i, (correlation_id, response) in enumerate(responses):
             try:
-                (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+                with self._lock:
+                    (future, timestamp) = self.in_flight_requests.pop(correlation_id)
             except KeyError:
                 self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
                 return ()
@@ -908,6 +920,12 @@ def recv(self):
     def _recv(self):
         """Take all available bytes from socket, return list of any responses from parser"""
         recvd = []
+        self._lock.acquire()
+        if not self._can_send_recv():
+            log.warning('%s cannot recv: socket not connected', self)
+            self._lock.release()
+            return ()
+
         while len(recvd) < self.config['sock_chunk_buffer_count']:
             try:
                 data = self._sock.recv(self.config['sock_chunk_bytes'])
@@ -917,6 +935,7 @@ def _recv(self):
                 # without an exception raised
                 if not data:
                     log.error('%s: socket disconnected', self)
+                    self._lock.release()
                     self.close(error=Errors.KafkaConnectionError('socket disconnected'))
                     return []
                 else:
@@ -929,11 +948,13 @@ def _recv(self):
                     break
                 log.exception('%s: Error receiving network data'
                               ' closing socket', self)
+                self._lock.release()
                 self.close(error=Errors.KafkaConnectionError(e))
                 return []
             except BlockingIOError:
                 if six.PY3:
                     break
+                self._lock.release()
                 raise
 
         recvd_data = b''.join(recvd)
@@ -943,20 +964,23 @@ def _recv(self):
         try:
             responses = self._protocol.receive_bytes(recvd_data)
         except Errors.KafkaProtocolError as e:
+            self._lock.release()
             self.close(e)
             return []
         else:
+            self._lock.release()
             return responses
 
     def requests_timed_out(self):
-        if self.in_flight_requests:
-            get_timestamp = lambda v: v[1]
-            oldest_at = min(map(get_timestamp,
-                                self.in_flight_requests.values()))
-            timeout = self.config['request_timeout_ms'] / 1000.0
-            if time.time() >= oldest_at + timeout:
-                return True
-        return False
+        with self._lock:
+            if self.in_flight_requests:
+                get_timestamp = lambda v: v[1]
+                oldest_at = min(map(get_timestamp,
+                                    self.in_flight_requests.values()))
+                timeout = self.config['request_timeout_ms'] / 1000.0
+                if time.time() >= oldest_at + timeout:
+                    return True
+            return False
 
     def _handle_api_version_response(self, response):
         error_type = Errors.for_code(response.error_code)