Skip to content

Commit

Permalink
Add SASL SCRAM-SHA-512 and SCRAM-SHA-256 mechanismes
Browse files Browse the repository at this point in the history
  • Loading branch information
iyedbennour committed Feb 27, 2019
1 parent 6bc31ae commit 5557853
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 22 deletions.
123 changes: 118 additions & 5 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const (
SASLTypeOAuth = "OAUTHBEARER"
// SASLTypePlaintext represents the SASL/PLAIN mechanism
SASLTypePlaintext = "PLAIN"
// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
// server negotiate SASL auth using opaque packets.
SASLHandshakeV0 = int16(0)
Expand Down Expand Up @@ -92,6 +96,20 @@ type AccessTokenProvider interface {
Token() (*AccessToken, error)
}

// SCRAMClient is a an interface to a SCRAM
// client implementation.
type SCRAMClient interface {
// Begin prepares the client for the SCRAM exchange
// with the server with a user name and a password
Begin(userName, password, authzID string) error
// Step steps client through the SCRAM exchange. It is
// called repeatedly until it errors or `Done` returns true.
Step(challenge string) (response string, err error)
// Done should return true when the SCRAM conversation
// is over.
Done() bool
}

type responsePromise struct {
requestTime time.Time
correlationID int32
Expand Down Expand Up @@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() {
}

func (b *Broker) authenticateViaSASL() error {
if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
switch b.conf.Net.SASL.Mechanism {
case SASLTypeOAuth:
return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
return b.sendAndReceiveSASLSCRAMv1()
default:
return b.sendAndReceiveSASLPlainAuth()
}
return b.sendAndReceiveSASLPlainAuth()

}

func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}

req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
Expand Down Expand Up @@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err
Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
return res.Err
}
Logger.Print("Successful SASL handshake")
Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
return nil
}

Expand Down Expand Up @@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
return nil
}

func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
return err
}

scramClient := b.conf.Net.SASL.SCRAMClient
if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
}

msg, err := scramClient.Step("")
if err != nil {
return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())

}

for !scramClient.Done() {
requestTime := time.Now()
correlationID := b.correlationID
bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
if err != nil {
Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
return err
}

b.updateOutgoingCommunicationMetrics(bytesWritten)
b.correlationID++
challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
if err != nil {
Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
return err
}

b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
msg, err = scramClient.Step(string(challenge))
if err != nil {
Logger.Println("SASL authentication failed", err)
return err
}
}
Logger.Println("SASL authentication succeeded")
return nil
}

func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
rb := &SaslAuthenticateRequest{msg}
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
}
if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
return 0, err
}
return b.conn.Write(buf)
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
bytesRead, err := io.ReadFull(b.conn, buf)
if err != nil {
return nil, err
}
header := responseHeader{}
err = decode(buf, &header)
if err != nil {
return nil, err
}
if header.correlationID != correlationID {
return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
}
buf = make([]byte, header.length-correlationIDSize)
c, err := io.ReadFull(b.conn, buf)
bytesRead += c
if err != nil {
return nil, err
}
res := &SaslAuthenticateResponse{}
if err := versionedDecode(buf, res, 0); err != nil {
return nil, err
}
if err != nil {
return nil, err
}
if res.Err != ErrNoError {
return nil, res.Err
}
return res.SaslAuthBytes, nil
}

// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
// https://tools.ietf.org/html/rfc7628
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
Expand Down
141 changes: 135 additions & 6 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,12 @@ func TestSASLOAuthBearer(t *testing.T) {
// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`))

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
if test.mockAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
}

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth})

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}
Expand Down Expand Up @@ -248,6 +244,139 @@ func TestSASLOAuthBearer(t *testing.T) {
}
}

// A mock scram client.
type MockSCRAMClient struct {
done bool
}

func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) {
return nil
}

func (m *MockSCRAMClient) Step(challenge string) (response string, err error) {
if challenge == "" {
return "ping", nil
}
if challenge == "pong" {
m.done = true
return "", nil
}
return "", errors.New("failed to authenticate :(")
}

func (m *MockSCRAMClient) Done() bool {
return m.done
}

var _ SCRAMClient = &MockSCRAMClient{}

func TestSASLSCRAMSHAXXX(t *testing.T) {
testTable := []struct {
name string
mockHandshakeErr KError
mockSASLAuthErr KError
expectClientErr bool
scramClient *MockSCRAMClient
scramChallengeResp string
}{
{
name: "SASL/SCRAMSHAXXX successfull authentication",
mockHandshakeErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
{
name: "SASL/SCRAMSHAXXX SCRAM client step error client",
mockHandshakeErr: ErrNoError,
mockSASLAuthErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "gong",
expectClientErr: true,
},
{
name: "SASL/SCRAMSHAXXX server authentication error",
mockHandshakeErr: ErrNoError,
mockSASLAuthErr: ErrSASLAuthenticationFailed,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
{
name: "SASL/SCRAMSHAXXX unsupported SCRAM mechanism",
mockHandshakeErr: ErrUnsupportedSASLMechanism,
mockSASLAuthErr: ErrNoError,
scramClient: &MockSCRAMClient{},
scramChallengeResp: "pong",
},
}

for i, test := range testTable {

// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)
broker := NewBroker(mockBroker.Addr())
// broker executes SASL requests against mockBroker
broker.requestRate = metrics.NilMeter{}
broker.outgoingByteRate = metrics.NilMeter{}
broker.incomingByteRate = metrics.NilMeter{}
broker.requestSize = metrics.NilHistogram{}
broker.responseSize = metrics.NilHistogram{}
broker.responseRate = metrics.NilMeter{}
broker.requestLatency = metrics.NilHistogram{}

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})

if test.mockSASLAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
}
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
})

conf := NewConfig()
conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
conf.Net.SASL.SCRAMClient = test.scramClient

broker.conf = conf
dialer := net.Dialer{
Timeout: conf.Net.DialTimeout,
KeepAlive: conf.Net.KeepAlive,
LocalAddr: conf.Net.LocalAddr,
}

conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())

if err != nil {
t.Fatal(err)
}

broker.conn = conn

err = broker.authenticateViaSASL()

if test.mockSASLAuthErr != ErrNoError {
if test.mockSASLAuthErr != err {
t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
}
} else if test.mockHandshakeErr != ErrNoError {
if test.mockHandshakeErr != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
}
} else if test.expectClientErr && err == nil {
t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
} else if !test.expectClientErr && err != nil {
t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
}

mockBroker.Close()
}
}

func TestBuildClientInitialResponse(t *testing.T) {

testTable := []struct {
Expand Down
36 changes: 27 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@ type Config struct {
// (defaults to true). You should only set this to false if you're using
// a non-Kafka SASL proxy.
Handshake bool
//username and password for SASL/PLAIN authentication
//username and password for SASL/PLAIN or SASL/SCRAM authentication
User string
Password string
// authz id used for SASL/SCRAM authentication
SCRAMAuthzID string
// SCRAMClient is a user provided implementation of a SCRAM
// client used to perform the SCRAM exchange with the server.
SCRAMClient SCRAMClient
// TokenProvider is a user-defined callback for generating
// access tokens for SASL/OAUTHBEARER auth. See the
// AccessTokenProvider interface docs for proper implementation
Expand Down Expand Up @@ -475,22 +480,35 @@ func (c *Config) Validate() error {
case c.Net.KeepAlive < 0:
return ConfigurationError("Net.KeepAlive must be >= 0")
case c.Net.SASL.Enable:
// For backwards compatibility, empty mechanism value defaults to PLAIN
isSASLPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SASLTypePlaintext
if isSASLPlain {
if c.Net.SASL.Mechanism == "" {
c.Net.SASL.Mechanism = SASLTypePlaintext
}

switch c.Net.SASL.Mechanism {
case SASLTypePlaintext:
if c.Net.SASL.User == "" {
return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
}
if c.Net.SASL.Password == "" {
return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
}
} else if c.Net.SASL.Mechanism == SASLTypeOAuth {
case SASLTypeOAuth:
if c.Net.SASL.TokenProvider == nil {
return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider")
return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider")
}
case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
if c.Net.SASL.User == "" {
return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
}
if c.Net.SASL.Password == "" {
return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
}
if c.Net.SASL.SCRAMClient == nil {
return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
}
} else {
msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`",
SASLTypeOAuth, SASLTypePlaintext)
default:
msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",
SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512)
return ConfigurationError(msg)
}
}
Expand Down
Loading

0 comments on commit 5557853

Please sign in to comment.