From 23490000ce95d5cf04f87690e466d058110442fb Mon Sep 17 00:00:00 2001
From: William Barnhart
Date: Mon, 18 Mar 2024 11:24:00 -0400
Subject: [PATCH] Support custom SASL mechanisms including AWS MSK (#170)

* Support custom SASL mechanisms

There is some interest in supporting various SASL mechanisms not currently included in the library:

* #2110 (DMS)
* #2204 (SSPI)
* #2232 (AWS_MSK_IAM)

Adding these mechanisms to the core library may be undesirable due to:

* Increased maintenance burden.
* Unavailable testing environments.
* Vendor specificity.

This commit provides a quick prototype for a pluggable SASL system.

---

**Example**

To define a custom SASL mechanism, a module must implement two methods:

```py
def validate_config(conn):
    # Check configuration values, available libraries, etc.
    assert conn.config['vendor_specific_setting'] is not None, (
        'vendor_specific_setting required when sasl_mechanism=MY_SASL'
    )

def try_authenticate(conn, future):
    # Do authentication routine and return resolved Future with failed
    # or succeeded state.
```

And then the custom mechanism should be registered before initializing a KafkaAdminClient, KafkaConsumer, or KafkaProducer:

```py
import kafka.sasl
from kafka import KafkaProducer

import my_sasl

kafka.sasl.register_mechanism('MY_SASL', my_sasl)
producer = KafkaProducer(sasl_mechanism='MY_SASL')
```

---

**Notes**

**ABCs**

This prototype does not implement an ABC for custom SASL mechanisms. Using an ABC would reduce a few of the explicit assertions involved with registering a mechanism and is a viable option. Due to differing feature sets between py2/py3 this option was not explored, but it shouldn't be difficult; a rough sketch follows.
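Something like the following, assuming only the two-method contract described above (the `SaslMechanism` name is invented here; nothing like it exists in the library yet):

```py
import abc


class SaslMechanism(abc.ABC):
    """Hypothetical base class for pluggable SASL mechanisms."""

    @abc.abstractmethod
    def validate_config(self, conn):
        """Raise AssertionError for missing or invalid config values."""

    @abc.abstractmethod
    def try_authenticate(self, conn, future):
        """Run the authentication routine and return a resolved Future."""
```

Registering instances of such a class instead of bare modules would let `register_mechanism` drop its `callable(...)` assertions, since instantiating an incomplete subclass already raises a TypeError.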
**Private Methods**

This prototype relies on some methods that are currently marked as **private** in `BrokerConnection`.

* `._can_send_recv`
* `._lock`
* `._recv_bytes_blocking`
* `._send_bytes_blocking`

A pluggable system would require stable interfaces for these actions.

**Alternative Approach**

If the module-scoped dict modification in `register_mechanism` feels too clunky, maybe the additional mechanisms can be specified via an argument when initializing one of the `Kafka*` classes? A sketch of that hypothetical API follows.
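For example, under the assumption of an invented `sasl_custom_mechanisms` parameter (nothing of the sort exists yet; registration via `kafka.sasl.register_mechanism` is what this prototype actually implements):

```py
from kafka import KafkaProducer

import my_sasl

# Hypothetical API: pass extra mechanisms at construction time instead of
# mutating the module-scoped kafka.sasl.MECHANISMS dict.
producer = KafkaProducer(
    sasl_mechanism='MY_SASL',
    sasl_custom_mechanisms={'MY_SASL': my_sasl},  # invented parameter
)
```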
* Add test_msk.py by @mattoberle

* add msk to __init__ and check for extension in conn.py

* rename try_authenticate in msk.py

* fix imports

* fix imports

* add botocore to requirements-dev.txt

* add boto3 to requirements-dev.txt

* add awscli to requirements-dev.txt

* add awscli to workflow since it takes too long to install normally

* just install botocore i guess

* just install boto3 i guess

* force reinstall awscli

* try something weird

* ok now the dang tests should work and if they don't i'll cry

* skip the msk test for now...

* Revert "skip the msk test for now..."

This reverts commit 1c29667ccfd2cbd2a7b00a5328ee0556362d7ef4.

* skip the msk test for now...

* nvm just needed to update tox lol

* Update kafka/sasl/gssapi.py

Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>

* Update kafka/sasl/oauthbearer.py

Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>

* Update kafka/sasl/plain.py

Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>

* Update kafka/sasl/scram.py

Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>

* Update kafka/sasl/msk.py

Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>

---------

Co-authored-by: Matt Oberle
Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com>
---
 kafka/conn.py             | 279 ++++----------------------------------
 kafka/sasl/__init__.py    |  54 ++++++++
 kafka/sasl/gssapi.py      | 100 ++++++++++++++
 kafka/sasl/msk.py         | 231 +++++++++++++++++++++++++++++++
 kafka/sasl/oauthbearer.py |  80 +++++++++++
 kafka/sasl/plain.py       |  58 ++++++++
 kafka/sasl/scram.py       |  68 ++++++++++
 requirements-dev.txt      |   1 +
 test/test_msk.py          |  70 ++++++++++
 tox.ini                   |   1 +
 10 files changed, 687 insertions(+), 255 deletions(-)
 create mode 100644 kafka/sasl/__init__.py
 create mode 100644 kafka/sasl/gssapi.py
 create mode 100644 kafka/sasl/msk.py
 create mode 100644 kafka/sasl/oauthbearer.py
 create mode 100644 kafka/sasl/plain.py
 create mode 100644 kafka/sasl/scram.py
 create mode 100644 test/test_msk.py

diff --git a/kafka/conn.py b/kafka/conn.py
index d04acce3e..f253cbda1 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -2,7 +2,6 @@
 import copy
 import errno
-import io
 import logging
 from random import shuffle, uniform
@@ -14,25 +13,26 @@
 from kafka.vendor import selectors34 as selectors

 import socket
-import struct
 import threading
 import time

 from kafka.vendor import six

+from kafka import sasl
 import kafka.errors as Errors
 from kafka.future import Future
 from kafka.metrics.stats import Avg, Count, Max, Rate
-from kafka.oauth.abstract import AbstractTokenProvider
-from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest
+from kafka.protocol.admin import (
+    DescribeAclsRequest_v2,
+    DescribeClientQuotasRequest,
+    SaslHandShakeRequest,
+)
 from kafka.protocol.commit import OffsetFetchRequest
 from kafka.protocol.offset import OffsetRequest
 from kafka.protocol.produce import ProduceRequest
 from kafka.protocol.metadata import MetadataRequest
 from kafka.protocol.fetch import FetchRequest
 from kafka.protocol.parser import KafkaProtocol
-from kafka.protocol.types import Int32, Int8
-from kafka.scram import ScramClient
 from kafka.version import __version__
@@ -83,6 +83,12 @@ class SSLWantWriteError(Exception):
     gssapi = None
     GSSError = None

+# needed for AWS_MSK_IAM authentication:
+try:
+    from botocore.session import Session as BotoSession
+except ImportError:
+    # no botocore available, will disable AWS_MSK_IAM mechanism
+    BotoSession = None

 AFI_NAMES = {
     socket.AF_UNSPEC: "unspecified",
@@ -227,7 +233,6 @@ class BrokerConnection(object):
         'sasl_oauth_token_provider': None
     }
     SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
-    SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")

     def __init__(self, host, port, afi, **configs):
         self.host = host
@@ -256,26 +261,19 @@ def __init__(self, host, port, afi, **configs):
         assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, (
             'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS))
+
         if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
             assert ssl_available, "Python wasn't built with SSL support"
+        if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
+            assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
+            assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
+
         if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
-            assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
-                'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS))
-            if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'):
-                assert self.config['sasl_plain_username'] is not None, (
-                    'sasl_plain_username required for PLAIN or SCRAM sasl'
-                )
-                assert self.config['sasl_plain_password'] is not None, (
-                    'sasl_plain_password required for PLAIN or SCRAM sasl'
-                )
-            if self.config['sasl_mechanism'] == 'GSSAPI':
-                assert gssapi is not None, 'GSSAPI lib not available'
-                assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
-            if self.config['sasl_mechanism'] == 'OAUTHBEARER':
-                token_provider = self.config['sasl_oauth_token_provider']
-                assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
-                assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
+            assert self.config['sasl_mechanism'] in sasl.MECHANISMS, (
+                'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys()))
+            )
+            sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self)
         # This is not a general lock / this class is not generally thread-safe yet
         # However, to avoid pushing responsibility for maintaining
         # per-connection locks to the upstream client, we will use this lock to
@@ -553,19 +551,9 @@ def _handle_sasl_handshake_response(self, future, response):
                 Errors.UnsupportedSaslMechanismError(
                    'Kafka broker does not support %s sasl mechanism. Enabled mechanisms are: %s'
                    % (self.config['sasl_mechanism'], response.enabled_mechanisms)))
-        elif self.config['sasl_mechanism'] == 'PLAIN':
-            return self._try_authenticate_plain(future)
-        elif self.config['sasl_mechanism'] == 'GSSAPI':
-            return self._try_authenticate_gssapi(future)
-        elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
-            return self._try_authenticate_oauth(future)
-        elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
-            return self._try_authenticate_scram(future)
-        else:
-            return future.failure(
-                Errors.UnsupportedSaslMechanismError(
-                    'kafka-python does not support SASL mechanism %s' %
-                    self.config['sasl_mechanism']))
+
+        try_authenticate = sasl.MECHANISMS[self.config['sasl_mechanism']].try_authenticate
+        return try_authenticate(self, future)

     def _send_bytes(self, data):
         """Send some data via non-blocking IO
@@ -619,225 +607,6 @@ def _recv_bytes_blocking(self, n):
         finally:
             self._sock.settimeout(0.0)

-    def _try_authenticate_plain(self, future):
-        if self.config['security_protocol'] == 'SASL_PLAINTEXT':
-            log.warning('%s: Sending username and password in the clear', self)
-
-        data = b''
-        # Send PLAIN credentials per RFC-4616
-        msg = bytes('\0'.join([self.config['sasl_plain_username'],
-                               self.config['sasl_plain_username'],
-                               self.config['sasl_plain_password']]).encode('utf-8'))
-        size = Int32.encode(len(msg))
-
-        err = None
-        close = False
-        with self._lock:
-            if not self._can_send_recv():
-                err = Errors.NodeNotReadyError(str(self))
-                close = False
-            else:
-                try:
-                    self._send_bytes_blocking(size + msg)
-
-                    # The server will send a zero sized message (that is Int32(0)) on success.
-                    # The connection is closed on failure
-                    data = self._recv_bytes_blocking(4)
-
-                except (ConnectionError, TimeoutError) as e:
-                    log.exception("%s: Error receiving reply from server", self)
-                    err = Errors.KafkaConnectionError("%s: %s" % (self, e))
-                    close = True
-
-        if err is not None:
-            if close:
-                self.close(error=err)
-            return future.failure(err)
-
-        if data != b'\x00\x00\x00\x00':
-            error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
-            return future.failure(error)
-
-        log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
-        return future.success(True)
-
-    def _try_authenticate_scram(self, future):
-        if self.config['security_protocol'] == 'SASL_PLAINTEXT':
-            log.warning('%s: Exchanging credentials in the clear', self)
-
-        scram_client = ScramClient(
-            self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism']
-        )
-
-        err = None
-        close = False
-        with self._lock:
-            if not self._can_send_recv():
-                err = Errors.NodeNotReadyError(str(self))
-                close = False
-            else:
-                try:
-                    client_first = scram_client.first_message().encode('utf-8')
-                    size = Int32.encode(len(client_first))
-                    self._send_bytes_blocking(size + client_first)
-
-                    (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
-                    server_first = self._recv_bytes_blocking(data_len).decode('utf-8')
-                    scram_client.process_server_first_message(server_first)
-
-                    client_final = scram_client.final_message().encode('utf-8')
-                    size = Int32.encode(len(client_final))
-                    self._send_bytes_blocking(size + client_final)
-
-                    (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
-                    server_final = self._recv_bytes_blocking(data_len).decode('utf-8')
-                    scram_client.process_server_final_message(server_final)
-
-                except (ConnectionError, TimeoutError) as e:
-                    log.exception("%s: Error receiving reply from server", self)
-                    err = Errors.KafkaConnectionError("%s: %s" % (self, e))
-                    close = True
-
-        if err is not None:
-            if close:
-                self.close(error=err)
-            return future.failure(err)
-
-        log.info(
-            '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism']
-        )
-        return future.success(True)
-
-    def _try_authenticate_gssapi(self, future):
-        kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
-        auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name
-        gssapi_name = gssapi.Name(
-            auth_id,
-            name_type=gssapi.NameType.hostbased_service
-        ).canonicalize(gssapi.MechType.kerberos)
-        log.debug('%s: GSSAPI name: %s', self, gssapi_name)
-
-        err = None
-        close = False
-        with self._lock:
-            if not self._can_send_recv():
-                err = Errors.NodeNotReadyError(str(self))
-                close = False
-            else:
-                # Establish security context and negotiate protection level
-                # For reference RFC 2222, section 7.2.1
-                try:
-                    # Exchange tokens until authentication either succeeds or fails
-                    client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
-                    received_token = None
-                    while not client_ctx.complete:
-                        # calculate an output token from kafka token (or None if first iteration)
-                        output_token = client_ctx.step(received_token)
-
-                        # pass output token to kafka, or send empty response if the security
-                        # context is complete (output token is None in that case)
-                        if output_token is None:
-                            self._send_bytes_blocking(Int32.encode(0))
-                        else:
-                            msg = output_token
-                            size = Int32.encode(len(msg))
-                            self._send_bytes_blocking(size + msg)
-
-                        # The server will send a token back. Processing of this token either
-                        # establishes a security context, or it needs further token exchange.
-                        # The gssapi will be able to identify the needed next step.
-                        # The connection is closed on failure.
-                        header = self._recv_bytes_blocking(4)
-                        (token_size,) = struct.unpack('>i', header)
-                        received_token = self._recv_bytes_blocking(token_size)
-
-                    # Process the security layer negotiation token, sent by the server
-                    # once the security context is established.
-
-                    # unwraps message containing supported protection levels and msg size
-                    msg = client_ctx.unwrap(received_token).message
-                    # Kafka currently doesn't support integrity or confidentiality security layers, so we
-                    # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
-                    # by the server
-                    msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
-                    # add authorization identity to the response, GSS-wrap and send it
-                    msg = client_ctx.wrap(msg + auth_id.encode(), False).message
-                    size = Int32.encode(len(msg))
-                    self._send_bytes_blocking(size + msg)
-
-                except (ConnectionError, TimeoutError) as e:
-                    log.exception("%s: Error receiving reply from server", self)
-                    err = Errors.KafkaConnectionError("%s: %s" % (self, e))
-                    close = True
-                except Exception as e:
-                    err = e
-                    close = True
-
-        if err is not None:
-            if close:
-                self.close(error=err)
-            return future.failure(err)
-
-        log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
-        return future.success(True)
-
-    def _try_authenticate_oauth(self, future):
-        data = b''
-
-        msg = bytes(self._build_oauth_client_request().encode("utf-8"))
-        size = Int32.encode(len(msg))
-
-        err = None
-        close = False
-        with self._lock:
-            if not self._can_send_recv():
-                err = Errors.NodeNotReadyError(str(self))
-                close = False
-            else:
-                try:
-                    # Send SASL OAuthBearer request with OAuth token
-                    self._send_bytes_blocking(size + msg)
-
-                    # The server will send a zero sized message (that is Int32(0)) on success.
-                    # The connection is closed on failure
-                    data = self._recv_bytes_blocking(4)
-
-                except (ConnectionError, TimeoutError) as e:
-                    log.exception("%s: Error receiving reply from server", self)
-                    err = Errors.KafkaConnectionError("%s: %s" % (self, e))
-                    close = True
-
-        if err is not None:
-            if close:
-                self.close(error=err)
-            return future.failure(err)
-
-        if data != b'\x00\x00\x00\x00':
-            error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
-            return future.failure(error)
-
-        log.info('%s: Authenticated via OAuth', self)
-        return future.success(True)
-
-    def _build_oauth_client_request(self):
-        token_provider = self.config['sasl_oauth_token_provider']
-        return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions())
-
-    def _token_extensions(self):
-        """
-        Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER
-        initial request.
-        """
-        token_provider = self.config['sasl_oauth_token_provider']
-
-        # Only run if the #extensions() method is implemented by the clients Token Provider class
-        # Builds up a string separated by \x01 via a dict of key value pairs
-        if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0:
-            msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()])
-            return "\x01" + msg
-        else:
-            return ""

     def blacked_out(self):
         """
         Return true if we are disconnected from the given node and can't
diff --git a/kafka/sasl/__init__.py b/kafka/sasl/__init__.py
new file mode 100644
index 000000000..4a7f21a5f
--- /dev/null
+++ b/kafka/sasl/__init__.py
@@ -0,0 +1,54 @@
+import logging
+
+from kafka.sasl import gssapi, oauthbearer, plain, scram, msk
+
+log = logging.getLogger(__name__)
+
+MECHANISMS = {
+    'GSSAPI': gssapi,
+    'OAUTHBEARER': oauthbearer,
+    'PLAIN': plain,
+    'SCRAM-SHA-256': scram,
+    'SCRAM-SHA-512': scram,
+    'AWS_MSK_IAM': msk,
+}
+
+
+def register_mechanism(key, module):
+    """
+    Registers a custom SASL mechanism that can be used via sasl_mechanism={key}.
+
+    Example:
+        import kafka.sasl
+        from kafka import KafkaProducer
+        from mymodule import custom_sasl
+
+        kafka.sasl.register_mechanism('CUSTOM_SASL', custom_sasl)
+        producer = KafkaProducer(sasl_mechanism='CUSTOM_SASL')
+
+    Arguments:
+        key (str): The name of the mechanism returned by the broker and used
+            in the sasl_mechanism config value.
+        module (module): A module that implements the following methods...
+
+            def validate_config(conn: BrokerConnection) -> None:
+                # Raises an AssertionError for missing or invalid config values.
+
+            def try_authenticate(conn: BrokerConnection, future: Future) -> Future:
+                # Executes authentication routine and returns a resolved Future.
+
+    Raises:
+        AssertionError: The registered module does not define a required method.
+    """
+    assert callable(getattr(module, 'validate_config', None)), (
+        'Custom SASL mechanism {} must implement method #validate_config()'
+        .format(key)
+    )
+    assert callable(getattr(module, 'try_authenticate', None)), (
+        'Custom SASL mechanism {} must implement method #try_authenticate()'
+        .format(key)
+    )
+    if key in MECHANISMS:
+        log.warning('Overriding existing SASL mechanism {}'.format(key))
+
+    MECHANISMS[key] = module
diff --git a/kafka/sasl/gssapi.py b/kafka/sasl/gssapi.py
new file mode 100644
index 000000000..3daf7e148
--- /dev/null
+++ b/kafka/sasl/gssapi.py
@@ -0,0 +1,100 @@
+import io
+import logging
+import struct
+
+import kafka.errors as Errors
+from kafka.protocol.types import Int8, Int32
+
+try:
+    import gssapi
+    from gssapi.raw.misc import GSSError
+except ImportError:
+    gssapi = None
+    GSSError = None
+
+log = logging.getLogger(__name__)
+
+SASL_QOP_AUTH = 1
+
+
+def validate_config(conn):
+    assert gssapi is not None, (
+        'gssapi library required when sasl_mechanism=GSSAPI'
+    )
+    assert conn.config['sasl_kerberos_service_name'] is not None, (
+        'sasl_kerberos_service_name required when sasl_mechanism=GSSAPI'
+    )
+
+
+def try_authenticate(conn, future):
+    kerberos_domain_name = conn.config['sasl_kerberos_domain_name'] or conn.host
+    auth_id = conn.config['sasl_kerberos_service_name'] + '@' + kerberos_domain_name
+    gssapi_name = gssapi.Name(
+        auth_id,
+        name_type=gssapi.NameType.hostbased_service
+    ).canonicalize(gssapi.MechType.kerberos)
+    log.debug('%s: GSSAPI name: %s', conn, gssapi_name)
+
+    err = None
+    close = False
+    with conn._lock:
+        if not conn._can_send_recv():
+            err = Errors.NodeNotReadyError(str(conn))
+            close = False
+        else:
+            # Establish security context and negotiate protection level
+            # For reference RFC 2222, section 7.2.1
+            try:
+                # Exchange tokens until authentication either succeeds or fails
+                client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
+                received_token = None
+                while not client_ctx.complete:
+                    # calculate an output token from kafka token (or None if first iteration)
+                    output_token = client_ctx.step(received_token)
+
+                    # pass output token to kafka, or send empty response if the security
+                    # context is complete (output token is None in that case)
+                    if output_token is None:
+                        conn._send_bytes_blocking(Int32.encode(0))
+                    else:
+                        msg = output_token
+                        size = Int32.encode(len(msg))
+                        conn._send_bytes_blocking(size + msg)
+
+                    # The server will send a token back. Processing of this token either
+                    # establishes a security context, or it needs further token exchange.
+                    # The gssapi will be able to identify the needed next step.
+                    # The connection is closed on failure.
+                    header = conn._recv_bytes_blocking(4)
+                    (token_size,) = struct.unpack('>i', header)
+                    received_token = conn._recv_bytes_blocking(token_size)
+
+                # Process the security layer negotiation token, sent by the server
+                # once the security context is established.
+
+                # unwraps message containing supported protection levels and msg size
+                msg = client_ctx.unwrap(received_token).message
+                # Kafka currently doesn't support integrity or confidentiality
+                # security layers, so we simply set QoP to 'auth' only (first octet).
+                # We reuse the max message size proposed by the server
+                msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
+                # add authorization identity to the response, GSS-wrap and send it
+                msg = client_ctx.wrap(msg + auth_id.encode(), False).message
+                size = Int32.encode(len(msg))
+                conn._send_bytes_blocking(size + msg)
+
+            except (ConnectionError, TimeoutError) as e:
+                log.exception("%s: Error receiving reply from server", conn)
+                err = Errors.KafkaConnectionError(f"{conn}: {e}")
+                close = True
+            except Exception as e:
+                err = e
+                close = True
+
+    if err is not None:
+        if close:
+            conn.close(error=err)
+        return future.failure(err)
+
+    log.info('%s: Authenticated as %s via GSSAPI', conn, gssapi_name)
+    return future.success(True)
diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py
new file mode 100644
index 000000000..3f2d054e7
--- /dev/null
+++ b/kafka/sasl/msk.py
@@ -0,0 +1,231 @@
+import datetime
+import hashlib
+import hmac
+import json
+import string
+import struct
+import logging
+
+
+from kafka.vendor.six.moves import urllib
+from kafka.protocol.types import Int32
+import kafka.errors as Errors
+
+from botocore.session import Session as BotoSession  # importing it in advance is not an option apparently...
+
+
+def validate_config(conn):
+    # BrokerConnection dispatches to this method for every SASL mechanism,
+    # so mirror the AWS_MSK_IAM asserts performed in conn.py.
+    assert BotoSession is not None, (
+        'AWS_MSK_IAM requires the "botocore" package'
+    )
+    assert conn.config['security_protocol'] == 'SASL_SSL', (
+        'AWS_MSK_IAM requires security_protocol=SASL_SSL'
+    )
+
+
+def try_authenticate(self, future):
+
+    session = BotoSession()
+    credentials = session.get_credentials().get_frozen_credentials()
+    client = AwsMskIamClient(
+        host=self.host,
+        access_key=credentials.access_key,
+        secret_key=credentials.secret_key,
+        region=session.get_config_variable('region'),
+        token=credentials.token,
+    )
+
+    msg = client.first_message()
+    size = Int32.encode(len(msg))
+
+    err = None
+    close = False
+    with self._lock:
+        if not self._can_send_recv():
+            err = Errors.NodeNotReadyError(str(self))
+            close = False
+        else:
+            try:
+                self._send_bytes_blocking(size + msg)
+                data = self._recv_bytes_blocking(4)
+                # the reply is length-prefixed with a big-endian Int32
+                data = self._recv_bytes_blocking(struct.unpack('>i', data)[0])
+            except (ConnectionError, TimeoutError) as e:
+                logging.exception("%s: Error receiving reply from server", self)
+                err = Errors.KafkaConnectionError(f"{self}: {e}")
+                close = True
+
+    if err is not None:
+        if close:
+            self.close(error=err)
+        return future.failure(err)
+
+    logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
+    return future.success(True)
+
+
+class AwsMskIamClient:
+    UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'
+
+    def __init__(self, host, access_key, secret_key, region, token=None):
+        """
+        Arguments:
+            host (str): The hostname of the broker.
+            access_key (str): An AWS_ACCESS_KEY_ID.
+            secret_key (str): An AWS_SECRET_ACCESS_KEY.
+            region (str): An AWS_REGION.
+            token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
+                credentials.
+ """ + self.algorithm = 'AWS4-HMAC-SHA256' + self.expires = '900' + self.hashfunc = hashlib.sha256 + self.headers = [ + ('host', host) + ] + self.version = '2020_10_22' + + self.service = 'kafka-cluster' + self.action = '{}:Connect'.format(self.service) + + now = datetime.datetime.utcnow() + self.datestamp = now.strftime('%Y%m%d') + self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') + + self.host = host + self.access_key = access_key + self.secret_key = secret_key + self.region = region + self.token = token + + @property + def _credential(self): + return '{0.access_key}/{0._scope}'.format(self) + + @property + def _scope(self): + return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) + + @property + def _signed_headers(self): + """ + Returns (str): + An alphabetically sorted, semicolon-delimited list of lowercase + request header names. + """ + return ';'.join(sorted(k.lower() for k, _ in self.headers)) + + @property + def _canonical_headers(self): + """ + Returns (str): + A newline-delited list of header names and values. + Header names are lowercased. + """ + return '\n'.join(map(':'.join, self.headers)) + '\n' + + @property + def _canonical_request(self): + """ + Returns (str): + An AWS Signature Version 4 canonical request in the format: + \n + \n + \n + \n + \n + + """ + # The hashed_payload is always an empty string for MSK. + hashed_payload = self.hashfunc(b'').hexdigest() + return '\n'.join(( + 'GET', + '/', + self._canonical_querystring, + self._canonical_headers, + self._signed_headers, + hashed_payload, + )) + + @property + def _canonical_querystring(self): + """ + Returns (str): + A '&'-separated list of URI-encoded key/value pairs. + """ + params = [] + params.append(('Action', self.action)) + params.append(('X-Amz-Algorithm', self.algorithm)) + params.append(('X-Amz-Credential', self._credential)) + params.append(('X-Amz-Date', self.timestamp)) + params.append(('X-Amz-Expires', self.expires)) + if self.token: + params.append(('X-Amz-Security-Token', self.token)) + params.append(('X-Amz-SignedHeaders', self._signed_headers)) + + return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) + + @property + def _signing_key(self): + """ + Returns (bytes): + An AWS Signature V4 signing key generated from the secret_key, date, + region, service, and request type. + """ + key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) + key = self._hmac(key, self.region) + key = self._hmac(key, self.service) + key = self._hmac(key, 'aws4_request') + return key + + @property + def _signing_str(self): + """ + Returns (str): + A string used to sign the AWS Signature V4 payload in the format: + \n + \n + \n + + """ + canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() + return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) + + def _uriencode(self, msg): + """ + Arguments: + msg (str): A string to URI-encode. + + Returns (str): + The URI-encoded version of the provided msg, following the encoding + rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode + """ + return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) + + def _hmac(self, key, msg): + """ + Arguments: + key (bytes): A key to use for the HMAC digest. + msg (str): A value to include in the HMAC digest. + Returns (bytes): + An HMAC digest of the given key and msg. 
+ """ + return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() + + def first_message(self): + """ + Returns (bytes): + An encoded JSON authentication payload that can be sent to the + broker. + """ + signature = hmac.new( + self._signing_key, + self._signing_str.encode('utf-8'), + digestmod=self.hashfunc, + ).hexdigest() + msg = { + 'version': self.version, + 'host': self.host, + 'user-agent': 'kafka-python', + 'action': self.action, + 'x-amz-algorithm': self.algorithm, + 'x-amz-credential': self._credential, + 'x-amz-date': self.timestamp, + 'x-amz-signedheaders': self._signed_headers, + 'x-amz-expires': self.expires, + 'x-amz-signature': signature, + } + if self.token: + msg['x-amz-security-token'] = self.token + + return json.dumps(msg, separators=(',', ':')).encode('utf-8') diff --git a/kafka/sasl/oauthbearer.py b/kafka/sasl/oauthbearer.py new file mode 100644 index 000000000..2fab7c37b --- /dev/null +++ b/kafka/sasl/oauthbearer.py @@ -0,0 +1,80 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + token_provider = conn.config.get('sasl_oauth_token_provider') + assert token_provider is not None, ( + 'sasl_oauth_token_provider required when sasl_mechanism=OAUTHBEARER' + ) + assert callable(getattr(token_provider, 'token', None)), ( + 'sasl_oauth_token_provider must implement method #token()' + ) + + +def try_authenticate(conn, future): + data = b'' + + msg = bytes(_build_oauth_client_request(conn).encode("utf-8")) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', conn) + return future.success(True) + + +def _build_oauth_client_request(conn): + token_provider = conn.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format( + token_provider.token(), + _token_extensions(conn), + ) + + +def _token_extensions(conn): + """ + Return a string representation of the OPTIONAL key-value pairs that can be + sent with an OAUTHBEARER initial request. 
+ """ + token_provider = conn.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if (callable(getattr(token_provider, "extensions", None)) + and len(token_provider.extensions()) > 0): + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) + return "\x01" + msg + else: + return "" diff --git a/kafka/sasl/plain.py b/kafka/sasl/plain.py new file mode 100644 index 000000000..625a43f08 --- /dev/null +++ b/kafka/sasl/plain.py @@ -0,0 +1,58 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=PLAIN' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=PLAIN' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', conn) + + data = b'' + # Send PLAIN credentials per RFC-4616 + msg = bytes('\0'.join([conn.config['sasl_plain_username'], + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password']]).encode('utf-8')) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated as %s via PLAIN', conn, conn.config['sasl_plain_username']) + return future.success(True) diff --git a/kafka/sasl/scram.py b/kafka/sasl/scram.py new file mode 100644 index 000000000..4f3e60126 --- /dev/null +++ b/kafka/sasl/scram.py @@ -0,0 +1,68 @@ +import logging +import struct + +import kafka.errors as Errors +from kafka.protocol.types import Int32 +from kafka.scram import ScramClient + +log = logging.getLogger() + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=SCRAM-*' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=SCRAM-*' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Exchanging credentials in the clear', conn) + + scram_client = ScramClient( + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password'], + conn.config['sasl_mechanism'], + ) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + client_first = scram_client.first_message().encode('utf-8') + size = Int32.encode(len(client_first)) + 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1fa933da2..3f6e5542c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -15,3 +15,4 @@ Sphinx
 sphinx-rtd-theme
 tox
 xxhash
+botocore
\ No newline at end of file
diff --git a/test/test_msk.py b/test/test_msk.py
new file mode 100644
index 000000000..7fca53b3d
--- /dev/null
+++ b/test/test_msk.py
@@ -0,0 +1,70 @@
+import datetime
+import json
+
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock
+
+from kafka.sasl.msk import AwsMskIamClient
+
+
+def client_factory(token=None):
+
+    now = datetime.datetime.utcfromtimestamp(1629321911)
+    with mock.patch('kafka.sasl.msk.datetime') as mock_dt:
+        mock_dt.datetime.utcnow = mock.Mock(return_value=now)
+        return AwsMskIamClient(
+            host='localhost',
+            access_key='XXXXXXXXXXXXXXXXXXXX',
+            secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX',
+            region='us-east-1',
+            token=token,
+        )
+
+
+def test_aws_msk_iam_client_permanent_credentials():
+    client = client_factory(token=None)
+    msg = client.first_message()
+    assert msg
+    assert isinstance(msg, bytes)
+    actual = json.loads(msg)
+
+    expected = {
+        'version': '2020_10_22',
+        'host': 'localhost',
+        'user-agent': 'kafka-python',
+        'action': 'kafka-cluster:Connect',
+        'x-amz-algorithm': 'AWS4-HMAC-SHA256',
+        'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
+        'x-amz-date': '20210818T212511Z',
+        'x-amz-signedheaders': 'host',
+        'x-amz-expires': '900',
+        'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b',
+    }
+    assert actual == expected
+
+
+def test_aws_msk_iam_client_temporary_credentials():
+    client = client_factory(token='XXXXX')
+    msg = client.first_message()
+    assert msg
+    assert isinstance(msg, bytes)
+    actual = json.loads(msg)
+
+    expected = {
+        'version': '2020_10_22',
+        'host': 'localhost',
+        'user-agent': 'kafka-python',
+        'action': 'kafka-cluster:Connect',
+        'x-amz-algorithm': 'AWS4-HMAC-SHA256',
+        'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
+        'x-amz-date': '20210818T212511Z',
+        'x-amz-signedheaders': 'host',
+        'x-amz-expires': '900',
+        'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05',
+        'x-amz-security-token': 'XXXXX',
+    }
+    assert actual == expected
diff --git a/tox.ini b/tox.ini
index d9b1e36d4..3d8bfbbc4 100644
--- a/tox.ini
+++ b/tox.ini
@@ -28,6 +28,7 @@ deps =
     lz4
     xxhash
     crc32c
+    botocore
 commands =
     pytest {posargs:--pylint --pylint-rcfile=pylint.rc --pylint-error-types=EF --cov=kafka --cov-config=.covrc}
 setenv =