Skip to content

Commit

Permalink
fix(metrics): fix race when accessing metric registry (#2409)
Browse files Browse the repository at this point in the history
A race condition was introduced in
5b04c98 (feat(metrics): track
consumer-fetch-response-size) when passing the metric registry around to
get additional metrics. Notably, `handleResponsePromise()` could access
the registry after the broker has been closed and is tentatively being
reopened. This triggers a data race because `b.metricRegistry` is being
set during `Open()` (as it is part of the configuration).

We fix this by reverting the addition of `handleResponsePromise()` as a
method to `Broker`. Instead, we provide it with the metric registry as
an argument. An alternative would have been to get the metric registry
before the `select` call. However, removing it as a method makes it
clearer that this function is not allowed to access the broker internals
as they are not protected by the lock and the broker may not be alive
any more.

All the following calls to `b.metricRegistry` are done while the lock is
held:

- inside `Open()`, the lock is held, including inside the goroutine
- inside `Close()`, the lock is held
- `AsyncProduce()` has a contract that it must be called while the broker
  is open, we keep a copy of the metric registry to use inside the callback
- `sendInternal()` has a contract that the lock should be held
- `authenticateViaSASLv1()` is called from `Open()` and
  `sendWithPromise()`, both of them holding the lock
- `sendAndReceiveSASLHandshake()` is called from
  `authenticateViaSASLv0/v1()`, which are called from `Open()` and
  `sendWithPromise()`

I am unsure about `responseReceiver()`, however, it is also calling
`b.readFull()` which accesses `b.conn`, so I suppose it is safe.

This leaves `sendAndReceive()` which is calling `send()`, which is
calling `sendWithPromise()` which takes the lock. We move the lock to
`sendAndReceive()` instead. `send()` is only called from
`sendAndReceive()`, and we take the lock in the other caller of
`sendWithPromise()` (`AsyncProduce()`).

The test has been stolen from #2393 from @samuelhewitt. #2393 is an
alternative proposal using a RW lock to protect `b.metricRegistry`.

Fix #2320
  • Loading branch information
vincentbernat authored Jan 10, 2023
1 parent 67d977b commit b0eda59
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 10 deletions.
23 changes: 14 additions & 9 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ type ProduceCallback func(*ProduceResponse, error)
//
// Make sure not to Close the broker in the callback as it will lead to a deadlock.
func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error {
metricRegistry := b.metricRegistry
needAcks := request.RequiredAcks != NoResponse
// Use a nil promise when no acks is required
var promise *responsePromise
Expand All @@ -446,7 +447,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
return
}

if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil {
if err := versionedDecode(packets, res, request.version(), metricRegistry); err != nil {
// Malformed response
cb(nil, err)
return
Expand All @@ -459,6 +460,8 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
}
}

b.lock.Lock()
defer b.lock.Unlock()
return b.sendWithPromise(request, promise)
}

Expand Down Expand Up @@ -939,6 +942,7 @@ func (b *Broker) write(buf []byte) (n int, err error) {
return b.conn.Write(buf)
}

// b.lock must be held by caller
func (b *Broker) send(rb protocolBody, promiseResponse bool, responseHeaderVersion int16) (*responsePromise, error) {
var promise *responsePromise
if promiseResponse {
Expand All @@ -963,10 +967,8 @@ func makeResponsePromise(responseHeaderVersion int16) *responsePromise {
return promise
}

// b.lock must be held by caller
func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error {
b.lock.Lock()
defer b.lock.Unlock()

if b.conn == nil {
if b.connErr != nil {
return b.connErr
Expand Down Expand Up @@ -1022,6 +1024,8 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
}

func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
b.lock.Lock()
defer b.lock.Unlock()
responseHeaderVersion := int16(-1)
if res != nil {
responseHeaderVersion = res.headerVersion()
Expand All @@ -1036,13 +1040,13 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
return nil
}

return b.handleResponsePromise(req, res, promise)
return handleResponsePromise(req, res, promise, b.metricRegistry)
}

func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error {
func handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise, metricRegistry metrics.Registry) error {
select {
case buf := <-promise.packets:
return versionedDecode(buf, res, req.version(), b.metricRegistry)
return versionedDecode(buf, res, req.version(), metricRegistry)
case err := <-promise.errors:
return err
}
Expand Down Expand Up @@ -1185,6 +1189,7 @@ func (b *Broker) authenticateViaSASLv0() error {
}

func (b *Broker) authenticateViaSASLv1() error {
metricRegistry := b.metricRegistry
if b.conf.Net.SASL.Handshake {
handshakeRequest := &SaslHandshakeRequest{Mechanism: string(b.conf.Net.SASL.Mechanism), Version: b.conf.Net.SASL.Version}
handshakeResponse := new(SaslHandshakeResponse)
Expand All @@ -1195,7 +1200,7 @@ func (b *Broker) authenticateViaSASLv1() error {
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
return handshakeErr
}
handshakeErr = b.handleResponsePromise(handshakeRequest, handshakeResponse, prom)
handshakeErr = handleResponsePromise(handshakeRequest, handshakeResponse, prom, metricRegistry)
if handshakeErr != nil {
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
return handshakeErr
Expand All @@ -1215,7 +1220,7 @@ func (b *Broker) authenticateViaSASLv1() error {
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
return nil, authErr
}
authErr = b.handleResponsePromise(authenticateRequest, authenticateResponse, prom)
authErr = handleResponsePromise(authenticateRequest, authenticateResponse, prom, metricRegistry)
if authErr != nil {
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
return nil, authErr
Expand Down
2 changes: 1 addition & 1 deletion broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestSimpleBrokerCommunication(t *testing.T) {
pendingNotify <- brokerMetrics{bytesRead, bytesWritten}
})
broker := NewBroker(mb.Addr())
// Set the broker id in order to validate local broujhjker metrics
// Set the broker id in order to validate local broker metrics
broker.id = 0
conf := NewTestConfig()
conf.ApiVersionsRequest = false
Expand Down
107 changes: 107 additions & 0 deletions consumer_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package sarama

import (
"context"
"errors"
"sync"
"testing"
"time"
)

type handler struct {
Expand Down Expand Up @@ -93,3 +95,108 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) {
}()
wg.Wait()
}

// TestConsume_RaceTest repeatedly tries to create a consumer group against a
// mock broker until a four-second deadline expires. Creation is expected to
// fail on every attempt; the test then checks that at least one retry happened
// and that the loop ended because of the deadline rather than some other
// error. It is meant to be run with the race detector enabled (-race).
//
// NOTE(review): the mock metadata only advertises "mismatched-topic" with
// ErrUnknownTopicOrPartition; presumably that is what keeps NewConsumerGroup
// failing — confirm against the client implementation.
func TestConsume_RaceTest(t *testing.T) {
	const groupID = "test-group"
	const topic = "test-topic"
	const offsetStart = int64(1234)

	cfg := NewConfig()
	cfg.Version = V2_8_1_0
	cfg.Consumer.Return.Errors = true

	seedBroker := NewMockBroker(t, 1)

	joinGroupResponse := &JoinGroupResponse{}

	syncGroupResponse := &SyncGroupResponse{
		Version: 3, // sarama > 2.3.0.0 uses version 3
	}
	// Leverage mock response to get the MemberAssignment bytes
	mockSyncGroupResponse := NewMockSyncGroupResponse(t).SetMemberAssignment(&ConsumerGroupMemberAssignment{
		Version:  1,
		Topics:   map[string][]int32{topic: {0}}, // map "test-topic" to partition 0
		UserData: []byte{0x01},
	})
	syncGroupResponse.MemberAssignment = mockSyncGroupResponse.MemberAssignment

	heartbeatResponse := &HeartbeatResponse{
		Err: ErrNoError,
	}
	offsetFetchResponse := &OffsetFetchResponse{
		Version:        1,
		ThrottleTimeMs: 0,
		Err:            ErrNoError,
	}
	offsetFetchResponse.AddBlock(topic, 0, &OffsetFetchResponseBlock{
		Offset:      offsetStart,
		LeaderEpoch: 0,
		Metadata:    "",
		Err:         ErrNoError,
	})

	offsetResponse := &OffsetResponse{
		Version: 1,
	}
	offsetResponse.AddTopicPartition(topic, 0, offsetStart)

	metadataResponse := new(MetadataResponse)
	metadataResponse.AddBroker(seedBroker.Addr(), seedBroker.BrokerID())
	metadataResponse.AddTopic("mismatched-topic", ErrUnknownTopicOrPartition)

	handlerMap := map[string]MockResponse{
		"ApiVersionsRequest": NewMockApiVersionsResponse(t),
		"MetadataRequest":    NewMockSequence(metadataResponse),
		"OffsetRequest":      NewMockSequence(offsetResponse),
		"OffsetFetchRequest": NewMockSequence(offsetFetchResponse),
		"FindCoordinatorRequest": NewMockSequence(NewMockFindCoordinatorResponse(t).
			SetCoordinator(CoordinatorGroup, groupID, seedBroker)),
		"JoinGroupRequest": NewMockSequence(joinGroupResponse),
		"SyncGroupRequest": NewMockSequence(syncGroupResponse),
		"HeartbeatRequest": NewMockSequence(heartbeatResponse),
	}
	seedBroker.SetHandlerByMap(handlerMap)

	// Defer cancel so the context's resources are released even when the test
	// exits early through t.Fatal/t.Fatalf (which skip any trailing statements
	// but still run deferred calls).
	cancelCtx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
	defer cancel()

	defer seedBroker.Close()

	retryWait := 20 * time.Millisecond
	var err error
	clientRetries := 0
retry:
	for {
		var group ConsumerGroup
		group, err = NewConsumerGroup([]string{seedBroker.Addr()}, groupID, cfg)
		if err == nil {
			// Unexpected success: the test fails below, but do not leak the
			// group in the meantime. The Close error is irrelevant here.
			_ = group.Close()
			break
		}

		// The wait is doubled before it is used, so the first pause is 40ms.
		if retryWait < time.Minute {
			retryWait *= 2
		}

		clientRetries++

		timer := time.NewTimer(retryWait)
		select {
		case <-cancelCtx.Done():
			timer.Stop()
			err = cancelCtx.Err()
			break retry
		case <-timer.C:
			// Timer has fired; nothing left to stop.
		}
	}
	if err == nil {
		t.Fatalf("should not proceed to Consume")
	}

	if clientRetries <= 0 {
		t.Errorf("clientRetries = %v; want > 0", clientRetries)
	}

	// err is known to be non-nil here; the only acceptable reason for leaving
	// the retry loop is the deadline expiring.
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatal(err)
	}
}

0 comments on commit b0eda59

Please sign in to comment.