Skip to content

Commit

Permalink
Fix sarama consumer deadlock
Browse files Browse the repository at this point in the history
Signed-off-by: albertteoh <albert.teoh@logz.io>
  • Loading branch information
albertteoh committed Oct 25, 2020
1 parent bd59f13 commit 5652c13
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 37 deletions.
67 changes: 44 additions & 23 deletions cmd/ingester/app/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,13 @@ type Consumer struct {
partitionMapLock sync.Mutex
partitionsHeld int64
partitionsHeldGauge metrics.Gauge

messagesDoneChan chan string
errorsDoneChan chan string
doneWg sync.WaitGroup
}

type consumerState struct {
wg sync.WaitGroup
partitionConsumer sc.PartitionConsumer
}

Expand All @@ -68,6 +71,8 @@ func New(params Params) (*Consumer, error) {
deadlockDetector: deadlockDetector,
partitionIDToState: make(map[int32]*consumerState),
partitionsHeldGauge: partitionsHeldGauge(params.MetricsFactory),
messagesDoneChan: make(chan string),
errorsDoneChan: make(chan string),
}, nil
}

Expand All @@ -78,50 +83,65 @@ func (c *Consumer) Start() {
c.logger.Info("Starting main loop")
for pc := range c.internalConsumer.Partitions() {
c.partitionMapLock.Lock()
if p, ok := c.partitionIDToState[pc.Partition()]; ok {
// This is a guard against simultaneously draining messages
// from the last time the partition was assigned and
// processing new messages for the same partition, which may lead
// to the cleanup process not completing
p.wg.Wait()
}
c.partitionIDToState[pc.Partition()] = &consumerState{partitionConsumer: pc}
c.partitionIDToState[pc.Partition()].wg.Add(2)
c.partitionMapLock.Unlock()
c.partitionMetrics(pc.Partition()).startCounter.Inc(1)

c.doneWg.Add(2)
go c.handleMessages(pc)
go c.handleErrors(pc.Partition(), pc.Errors())
}
}()

// Expect to receive message and error handler "done" signals from each partition.
go waitForDoneSignals(c.messagesDoneChan, &c.doneWg, c.logger)
go waitForDoneSignals(c.errorsDoneChan, &c.doneWg, c.logger)
}

// waitForDoneSignals watches the doneChan for incoming "done" messages. If a message is received,
// the doneWg WaitGroup is decremented via a call to Done().
func waitForDoneSignals(doneChan <-chan string, doneWg *sync.WaitGroup, logger *zap.Logger) {
logger.Debug("Waiting for done signals")
for v := range doneChan {
logger.Debug("Received done signal", zap.String("msg", v))
doneWg.Done()
}
}

// Close closes the Consumer and underlying sarama consumer
func (c *Consumer) Close() error {
c.partitionMapLock.Lock()
for _, p := range c.partitionIDToState {
c.closePartition(p.partitionConsumer)
p.wg.Wait()
}
c.partitionMapLock.Unlock()
c.deadlockDetector.close()
// Close the internal consumer, which will close each partition consumers' message and error channels.
c.logger.Info("Closing parent consumer")
return c.internalConsumer.Close()
err := c.internalConsumer.Close()

c.logger.Debug("Closing deadlock detector")
c.deadlockDetector.close()

c.logger.Debug("Waiting for messages and errors to be handled")
c.doneWg.Wait()

c.logger.Debug("Closing message and error done channels")
close(c.messagesDoneChan)
close(c.errorsDoneChan)

return err
}

// handleMessages handles incoming Kafka messages on a channel. Upon the closure of the message channel,
// handleMessages will signal the messagesDoneChan to indicate the graceful shutdown of message handling is done.
func (c *Consumer) handleMessages(pc sc.PartitionConsumer) {
c.logger.Info("Starting message handler", zap.Int32("partition", pc.Partition()))
c.partitionMapLock.Lock()
c.partitionsHeld++
c.partitionsHeldGauge.Update(c.partitionsHeld)
wg := &c.partitionIDToState[pc.Partition()].wg
c.partitionMapLock.Unlock()
defer func() {
c.closePartition(pc)
wg.Done()
c.partitionMapLock.Lock()
c.partitionsHeld--
c.partitionsHeldGauge.Update(c.partitionsHeld)
c.partitionMapLock.Unlock()
c.messagesDoneChan <- "HandleMessages done"
}()

msgMetrics := c.newMsgMetrics(pc.Partition())
Expand Down Expand Up @@ -165,12 +185,13 @@ func (c *Consumer) closePartition(partitionConsumer sc.PartitionConsumer) {
c.logger.Info("Closed partition consumer", zap.Int32("partition", partitionConsumer.Partition()))
}

// handleErrors handles incoming Kafka consumer errors on a channel. Upon the closure of the error channel,
// handleErrors will signal the errorsDoneChan to indicate the graceful shutdown of error handling is done.
func (c *Consumer) handleErrors(partition int32, errChan <-chan *sarama.ConsumerError) {
c.logger.Info("Starting error handler", zap.Int32("partition", partition))
c.partitionMapLock.Lock()
wg := &c.partitionIDToState[partition].wg
c.partitionMapLock.Unlock()
defer wg.Done()
defer func() {
c.errorsDoneChan <- "HandleErrors done"
}()

errMetrics := c.newErrMetrics(partition)
for err := range errChan {
Expand Down
31 changes: 17 additions & 14 deletions cmd/ingester/app/consumer/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (s partitionConsumerWrapper) Topic() string {
return s.topic
}

func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer) *kmocks.Consumer {
func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer, mc *smocks.PartitionConsumer) *kmocks.Consumer {
pcha := make(chan cluster.PartitionConsumer, 1)
pcha <- &partitionConsumerWrapper{
topic: topic,
Expand All @@ -77,27 +77,26 @@ func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer)
}
saramaClusterConsumer := &kmocks.Consumer{}
saramaClusterConsumer.On("Partitions").Return((<-chan cluster.PartitionConsumer)(pcha))
saramaClusterConsumer.On("Close").Return(nil)
saramaClusterConsumer.On("Close").Return(nil).Run(func(args mock.Arguments) {
mc.Close()
})
saramaClusterConsumer.On("MarkPartitionOffset", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
return saramaClusterConsumer
}

func newConsumer(
t *testing.T,
metricsFactory metrics.Factory,
topic string,
processor processor.SpanProcessor,
consumer consumer.Consumer) *Consumer {

logger, _ := zap.NewDevelopment()
return &Consumer{
metricsFactory: metricsFactory,
logger: logger,
internalConsumer: consumer,
partitionIDToState: make(map[int32]*consumerState),
partitionsHeldGauge: partitionsHeldGauge(metricsFactory),
deadlockDetector: newDeadlockDetector(metricsFactory, logger, time.Second),

processorFactory: ProcessorFactory{
consumerParams := Params{
MetricsFactory: metricsFactory,
Logger: logger,
InternalConsumer: consumer,
ProcessorFactory: ProcessorFactory{
topic: topic,
consumer: consumer,
metricsFactory: metricsFactory,
Expand All @@ -106,6 +105,10 @@ func newConsumer(
parallelism: 1,
},
}

c, err := New(consumerParams)
require.NoError(t, err)
return c
}

func TestSaramaConsumerWrapper_MarkPartitionOffset(t *testing.T) {
Expand Down Expand Up @@ -136,7 +139,7 @@ func TestSaramaConsumerWrapper_start_Messages(t *testing.T) {
saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
require.NoError(t, e)

undertest := newConsumer(localFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer))
undertest := newConsumer(t, localFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer, mc))

undertest.partitionIDToState = map[int32]*consumerState{
partition: {
Expand Down Expand Up @@ -202,7 +205,7 @@ func TestSaramaConsumerWrapper_start_Errors(t *testing.T) {
saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
require.NoError(t, e)

undertest := newConsumer(localFactory, topic, &pmocks.SpanProcessor{}, newSaramaClusterConsumer(saramaPartitionConsumer))
undertest := newConsumer(t, localFactory, topic, &pmocks.SpanProcessor{}, newSaramaClusterConsumer(saramaPartitionConsumer, mc))

undertest.Start()
mc.YieldError(errors.New("Daisy, Daisy"))
Expand Down Expand Up @@ -238,7 +241,7 @@ func TestHandleClosePartition(t *testing.T) {
saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
require.NoError(t, e)

undertest := newConsumer(metricsFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer))
undertest := newConsumer(t, metricsFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer, mc))
undertest.deadlockDetector = newDeadlockDetector(metricsFactory, undertest.logger, 200*time.Millisecond)
undertest.Start()
defer undertest.Close()
Expand Down

0 comments on commit 5652c13

Please sign in to comment.