From 55578535cf1d7f3ef0c01fc52bde8c33c344700c Mon Sep 17 00:00:00 2001 From: Iyed Bennour Date: Wed, 27 Feb 2019 20:23:26 +0100 Subject: [PATCH] Add SASL SCRAM-SHA-512 and SCRAM-SHA-256 mechanismes --- broker.go | 123 +++++++++++++++++++++++++++++++++++++-- broker_test.go | 141 +++++++++++++++++++++++++++++++++++++++++++-- config.go | 36 +++++++++--- config_test.go | 22 ++++++- response_header.go | 3 + 5 files changed, 303 insertions(+), 22 deletions(-) diff --git a/broker.go b/broker.go index 9129089ac..42e1f4529 100644 --- a/broker.go +++ b/broker.go @@ -56,6 +56,10 @@ const ( SASLTypeOAuth = "OAUTHBEARER" // SASLTypePlaintext represents the SASL/PLAIN mechanism SASLTypePlaintext = "PLAIN" + // SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism. + SASLTypeSCRAMSHA256 = "SCRAM-SHA-256" + // SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism. + SASLTypeSCRAMSHA512 = "SCRAM-SHA-512" // SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and // server negotiate SASL auth using opaque packets. SASLHandshakeV0 = int16(0) @@ -92,6 +96,20 @@ type AccessTokenProvider interface { Token() (*AccessToken, error) } +// SCRAMClient is a an interface to a SCRAM +// client implementation. +type SCRAMClient interface { + // Begin prepares the client for the SCRAM exchange + // with the server with a user name and a password + Begin(userName, password, authzID string) error + // Step steps client through the SCRAM exchange. It is + // called repeatedly until it errors or `Done` returns true. + Step(challenge string) (response string, err error) + // Done should return true when the SCRAM conversation + // is over. + Done() bool +} + type responsePromise struct { requestTime time.Time correlationID int32 @@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() { } func (b *Broker) authenticateViaSASL() error { - if b.conf.Net.SASL.Mechanism == SASLTypeOAuth { + switch b.conf.Net.SASL.Mechanism { + case SASLTypeOAuth: return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider) + case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512: + return b.sendAndReceiveSASLSCRAMv1() + default: + return b.sendAndReceiveSASLPlainAuth() } - return b.sendAndReceiveSASLPlainAuth() + } -func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error { - rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version} +func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error { + rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version} req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) @@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error()) return res.Err } - Logger.Print("Successful SASL handshake") + Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms) return nil } @@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error { return nil } +func (b *Broker) sendAndReceiveSASLSCRAMv1() error { + if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil { + return err + } + + scramClient := b.conf.Net.SASL.SCRAMClient + if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil { + return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error()) + } + + msg, err := scramClient.Step("") + if err != nil { + return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error()) + + } + + for !scramClient.Done() { + requestTime := time.Now() + correlationID := b.correlationID + bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg)) + if err != nil { + Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) + return err + } + + b.updateOutgoingCommunicationMetrics(bytesWritten) + b.correlationID++ + challenge, err := b.receiveSaslAuthenticateResponse(correlationID) + if err != nil { + Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) + return err + } + + b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime)) + msg, err = scramClient.Step(string(challenge)) + if err != nil { + Logger.Println("SASL authentication failed", err) + return err + } + } + Logger.Println("SASL authentication succeeded") + return nil +} + +func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) { + rb := &SaslAuthenticateRequest{msg} + req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} + buf, err := encode(req, b.conf.MetricRegistry) + if err != nil { + return 0, err + } + if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil { + return 0, err + } + return b.conn.Write(buf) +} + +func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) { + buf := make([]byte, responseLengthSize+correlationIDSize) + bytesRead, err := io.ReadFull(b.conn, buf) + if err != nil { + return nil, err + } + header := responseHeader{} + err = decode(buf, &header) + if err != nil { + return nil, err + } + if header.correlationID != correlationID { + return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID) + } + buf = make([]byte, header.length-correlationIDSize) + c, err := io.ReadFull(b.conn, buf) + bytesRead += c + if err != nil { + return nil, err + } + res := &SaslAuthenticateResponse{} + if err := versionedDecode(buf, res, 0); err != nil { + return nil, err + } + if err != nil { + return nil, err + } + if res.Err != ErrNoError { + return nil, res.Err + } + return res.SaslAuthBytes, nil +} + // Build SASL/OAUTHBEARER initial client response as described by RFC-7628 // https://tools.ietf.org/html/rfc7628 func buildClientInitialResponse(token *AccessToken) ([]byte, error) { diff --git a/broker_test.go b/broker_test.go index 009a2a66a..a3b17af4f 100644 --- a/broker_test.go +++ b/broker_test.go @@ -179,16 +179,12 @@ func TestSASLOAuthBearer(t *testing.T) { // mockBroker mocks underlying network logic and broker responses mockBroker := NewMockBroker(t, 0) - mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). - SetAuthBytes([]byte(`response_payload`)) - + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload")) if test.mockAuthErr != ErrNoError { mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr) } - mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t). - SetEnabledMechanisms([]string{SASLTypeOAuth}) - + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth}) if test.mockHandshakeErr != ErrNoError { mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) } @@ -248,6 +244,139 @@ func TestSASLOAuthBearer(t *testing.T) { } } +// A mock scram client. +type MockSCRAMClient struct { + done bool +} + +func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) { + return nil +} + +func (m *MockSCRAMClient) Step(challenge string) (response string, err error) { + if challenge == "" { + return "ping", nil + } + if challenge == "pong" { + m.done = true + return "", nil + } + return "", errors.New("failed to authenticate :(") +} + +func (m *MockSCRAMClient) Done() bool { + return m.done +} + +var _ SCRAMClient = &MockSCRAMClient{} + +func TestSASLSCRAMSHAXXX(t *testing.T) { + testTable := []struct { + name string + mockHandshakeErr KError + mockSASLAuthErr KError + expectClientErr bool + scramClient *MockSCRAMClient + scramChallengeResp string + }{ + { + name: "SASL/SCRAMSHAXXX successfull authentication", + mockHandshakeErr: ErrNoError, + scramClient: &MockSCRAMClient{}, + scramChallengeResp: "pong", + }, + { + name: "SASL/SCRAMSHAXXX SCRAM client step error client", + mockHandshakeErr: ErrNoError, + mockSASLAuthErr: ErrNoError, + scramClient: &MockSCRAMClient{}, + scramChallengeResp: "gong", + expectClientErr: true, + }, + { + name: "SASL/SCRAMSHAXXX server authentication error", + mockHandshakeErr: ErrNoError, + mockSASLAuthErr: ErrSASLAuthenticationFailed, + scramClient: &MockSCRAMClient{}, + scramChallengeResp: "pong", + }, + { + name: "SASL/SCRAMSHAXXX unsupported SCRAM mechanism", + mockHandshakeErr: ErrUnsupportedSASLMechanism, + mockSASLAuthErr: ErrNoError, + scramClient: &MockSCRAMClient{}, + scramChallengeResp: "pong", + }, + } + + for i, test := range testTable { + + // mockBroker mocks underlying network logic and broker responses + mockBroker := NewMockBroker(t, 0) + broker := NewBroker(mockBroker.Addr()) + // broker executes SASL requests against mockBroker + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp)) + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512}) + + if test.mockSASLAuthErr != ErrNoError { + mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr) + } + if test.mockHandshakeErr != ErrNoError { + mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) + } + + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + }) + + conf := NewConfig() + conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 + conf.Net.SASL.SCRAMClient = test.scramClient + + broker.conf = conf + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + LocalAddr: conf.Net.LocalAddr, + } + + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + broker.conn = conn + + err = broker.authenticateViaSASL() + + if test.mockSASLAuthErr != ErrNoError { + if test.mockSASLAuthErr != err { + t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } + } else if test.mockHandshakeErr != ErrNoError { + if test.mockHandshakeErr != err { + t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } + } else if test.expectClientErr && err == nil { + t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) + } else if !test.expectClientErr && err != nil { + t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) + } + + mockBroker.Close() + } +} + func TestBuildClientInitialResponse(t *testing.T) { testTable := []struct { diff --git a/config.go b/config.go index 9495b7f5a..6fa8bb940 100644 --- a/config.go +++ b/config.go @@ -61,9 +61,14 @@ type Config struct { // (defaults to true). You should only set this to false if you're using // a non-Kafka SASL proxy. Handshake bool - //username and password for SASL/PLAIN authentication + //username and password for SASL/PLAIN or SASL/SCRAM authentication User string Password string + // authz id used for SASL/SCRAM authentication + SCRAMAuthzID string + // SCRAMClient is a user provided implementation of a SCRAM + // client used to perform the SCRAM exchange with the server. + SCRAMClient SCRAMClient // TokenProvider is a user-defined callback for generating // access tokens for SASL/OAUTHBEARER auth. See the // AccessTokenProvider interface docs for proper implementation @@ -475,22 +480,35 @@ func (c *Config) Validate() error { case c.Net.KeepAlive < 0: return ConfigurationError("Net.KeepAlive must be >= 0") case c.Net.SASL.Enable: - // For backwards compatibility, empty mechanism value defaults to PLAIN - isSASLPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SASLTypePlaintext - if isSASLPlain { + if c.Net.SASL.Mechanism == "" { + c.Net.SASL.Mechanism = SASLTypePlaintext + } + + switch c.Net.SASL.Mechanism { + case SASLTypePlaintext: if c.Net.SASL.User == "" { return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled") } if c.Net.SASL.Password == "" { return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled") } - } else if c.Net.SASL.Mechanism == SASLTypeOAuth { + case SASLTypeOAuth: if c.Net.SASL.TokenProvider == nil { - return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider") + return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider") + } + case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512: + if c.Net.SASL.User == "" { + return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled") + } + if c.Net.SASL.Password == "" { + return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled") + } + if c.Net.SASL.SCRAMClient == nil { + return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient") } - } else { - msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`", - SASLTypeOAuth, SASLTypePlaintext) + default: + msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`", + SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512) return ConfigurationError(msg) } } diff --git a/config_test.go b/config_test.go index 3ea718d04..3ba9a2023 100644 --- a/config_test.go +++ b/config_test.go @@ -91,14 +91,32 @@ func TestNetConfigValidates(t *testing.T) { cfg.Net.SASL.Mechanism = "AnIncorrectSASLMechanism" cfg.Net.SASL.TokenProvider = &DummyTokenProvider{} }, - "The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER` and `PLAIN`"}, + "The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER`, `PLAIN`, `SCRAM-SHA-256` and `SCRAM-SHA-512`"}, {"SASL.Mechanism.OAUTHBEARER - Missing token provider", func(cfg *Config) { cfg.Net.SASL.Enable = true cfg.Net.SASL.Mechanism = SASLTypeOAuth cfg.Net.SASL.TokenProvider = nil }, - "An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider"}, + "An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider"}, + {"SASL.Mechanism SCRAM-SHA-256 - Missing SCRAM client", + func(cfg *Config) { + cfg.Net.SASL.Enable = true + cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256 + cfg.Net.SASL.SCRAMClient = nil + cfg.Net.SASL.User = "user" + cfg.Net.SASL.Password = "stong_password" + }, + "A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"}, + {"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client", + func(cfg *Config) { + cfg.Net.SASL.Enable = true + cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 + cfg.Net.SASL.SCRAMClient = nil + cfg.Net.SASL.User = "user" + cfg.Net.SASL.Password = "stong_password" + }, + "A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"}, } for i, test := range tests { diff --git a/response_header.go b/response_header.go index f3f4d27d6..7a7591851 100644 --- a/response_header.go +++ b/response_header.go @@ -2,6 +2,9 @@ package sarama import "fmt" +const responseLengthSize = 4 +const correlationIDSize = 4 + type responseHeader struct { length int32 correlationID int32