Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix possible deadlock when AckWithResponse is true because queueCh is full #1310

Closed
23 changes: 13 additions & 10 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ type partitionConsumer struct {

// the size of the queue channel for buffering messages
maxQueueSize int32
queueCh chan *message
queueCh *unboundedChannel[*message]
startMessageID atomicMessageID
lastDequeuedMsg *trackingMessageID

Expand Down Expand Up @@ -354,7 +354,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
partitionIdx: int32(options.partitionIdx),
eventsCh: make(chan interface{}, 10),
maxQueueSize: int32(options.receiverQueueSize),
queueCh: make(chan *message, options.receiverQueueSize),
queueCh: newUnboundedChannel[*message](),
startMessageID: atomicMessageID{msgID: options.startMessageID},
connectedCh: make(chan struct{}),
messageCh: messageCh,
Expand Down Expand Up @@ -949,6 +949,11 @@ func (pc *partitionConsumer) Close() {

// wait for request to finish
<-req.doneCh

// Stopping the queue closes `queueCh.inCh`. If `MessageReceived` were called after that, it would panic because
// the new message would be sent on a closed channel. Generally this cannot happen, because the broker will not
// dispatch messages to this consumer after it has received the close request.
pc.queueCh.stop()
BewareMyPower marked this conversation as resolved.
Show resolved Hide resolved
}

func (pc *partitionConsumer) Seek(msgID MessageID) error {
Expand Down Expand Up @@ -1171,7 +1176,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
pc.markScaleIfNeed()
}

pc.queueCh <- &message{
pc.queueCh.inCh <- &message{
publishTime: timeFromUnixTimestampMillis(msgMeta.GetPublishTime()),
eventTime: timeFromUnixTimestampMillis(msgMeta.GetEventTime()),
key: msgMeta.GetPartitionKey(),
Expand Down Expand Up @@ -1373,7 +1378,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
pc.markScaleIfNeed()
}

pc.queueCh <- msg
pc.queueCh.inCh <- msg
}

if skippedMessages > 0 {
Expand Down Expand Up @@ -1537,7 +1542,7 @@ func (pc *partitionConsumer) dispatcher() {
}()
var queueMsg *message
for {
var queueCh chan *message
var queueCh <-chan *message
var messageCh chan ConsumerMessage
var nextMessage ConsumerMessage
var nextMessageSize int
Expand All @@ -1564,7 +1569,7 @@ func (pc *partitionConsumer) dispatcher() {
pc.log.Debug("skip dispatching messages when seeking")
}
} else {
queueCh = pc.queueCh
queueCh = pc.queueCh.outCh
BewareMyPower marked this conversation as resolved.
Show resolved Hide resolved
}

select {
Expand Down Expand Up @@ -1625,11 +1630,9 @@ func (pc *partitionConsumer) dispatcher() {
// drain the message queue on any new connection by sending a
// special nil message to the channel so we know when to stop dropping messages
var nextMessageInQueue *trackingMessageID
go func() {
pc.queueCh <- nil
}()
pc.queueCh.inCh <- nil

for m := range pc.queueCh {
for m := range pc.queueCh.outCh {
// the queue has been drained
if m == nil {
break
Expand Down
23 changes: 14 additions & 9 deletions pulsar/consumer_partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package pulsar
import (
"sync"
"testing"
"time"

"github.com/apache/pulsar-client-go/pulsar/internal"
"github.com/apache/pulsar-client-go/pulsar/internal/crypto"
Expand All @@ -30,7 +31,7 @@ import (
func TestSingleMessageIDNoAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: make(chan *message, 1),
queueCh: newUnboundedChannel[*message](),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
Expand All @@ -47,7 +48,7 @@ func TestSingleMessageIDNoAckTracker(t *testing.T) {
}

// ensure the tracker was set on the message id
message := <-pc.queueCh
message := <-pc.queueCh.outCh
id := message.ID().(*trackingMessageID)
assert.Nil(t, id.tracker)

Expand All @@ -68,7 +69,7 @@ func newTestMetrics() *internal.LeveledMetrics {
func TestBatchMessageIDNoAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: make(chan *message, 1),
queueCh: newUnboundedChannel[*message](),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
Expand All @@ -85,7 +86,7 @@ func TestBatchMessageIDNoAckTracker(t *testing.T) {
}

// ensure the tracker was set on the message id
message := <-pc.queueCh
message := <-pc.queueCh.outCh
id := message.ID().(*trackingMessageID)
assert.Nil(t, id.tracker)

Expand All @@ -103,7 +104,7 @@ func TestBatchMessageIDNoAckTracker(t *testing.T) {
func TestBatchMessageIDWithAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: make(chan *message, 10),
queueCh: newUnboundedChannel[*message](),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
Expand All @@ -121,14 +122,18 @@ func TestBatchMessageIDWithAckTracker(t *testing.T) {

// ensure the tracker was set on the message id
var messageIDs []*trackingMessageID
for i := 0; i < 10; i++ {
running := true
for running {
select {
case m := <-pc.queueCh:
case m := <-pc.queueCh.outCh:
id := m.ID().(*trackingMessageID)
assert.NotNil(t, id.tracker)
messageIDs = append(messageIDs, id)
default:
break
if len(messageIDs) == 10 {
running = false
}
case <-time.After(5 * time.Second):
running = false
}
}

Expand Down
61 changes: 61 additions & 0 deletions pulsar/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4877,3 +4877,64 @@ func receiveMessages(t *testing.T, consumer Consumer, numMessages int) []Message
assert.Equal(t, numMessages, len(msgs))
return msgs
}

// TestAckResponseNotBlocked publishes 100 messages and then consumes them with
// AckWithResponse enabled and a receiver queue of only 10 entries,
// acknowledging in batches of 10. It fails if any Receive or AckIDList call
// returns an error.
func TestAckResponseNotBlocked(t *testing.T) {
	client, err := NewClient(ClientOptions{
		URL:              lookupURL,
		OperationTimeout: 5 * time.Second,
	})
	// Stop here on failure: continuing would only dereference a nil client.
	if !assert.Nil(t, err) {
		return
	}
	defer client.Close()

	topic := fmt.Sprintf("test-ack-response-not-blocked-%v", time.Now().Nanosecond())

	producer, err := client.CreateProducer(ProducerOptions{
		Topic: topic,
	})
	if !assert.Nil(t, err) {
		return
	}

	ctx := context.Background()
	numMessages := 100
	for i := 0; i < numMessages; i++ {
		producer.SendAsync(ctx, &ProducerMessage{
			Payload: []byte(fmt.Sprintf("value-%d", i)),
		}, func(_ MessageID, _ *ProducerMessage, err error) {
			// The callback does not run on the test goroutine, so t.Fatal
			// (FailNow) must not be used here; t.Error is safe to call from
			// any goroutine.
			if err != nil {
				t.Error(err)
			}
		})
		time.Sleep(1 * time.Millisecond)
	}
	producer.Flush()
	producer.Close()

	// Set a small receiver queue size to trigger ack response blocking if the internal `queueCh`
	// is a channel with the same size
	consumer, err := client.Subscribe(ConsumerOptions{
		Topic:                          topic,
		SubscriptionName:               "my-sub",
		SubscriptionInitialPosition:    SubscriptionPositionEarliest,
		Type:                           KeyShared,
		EnableBatchIndexAcknowledgment: true,
		AckWithResponse:                true,
		ReceiverQueueSize:              10,
	})
	if !assert.Nil(t, err) {
		return
	}
	// Release the subscription even when the test fails part-way through.
	defer consumer.Close()

	msgIDs := make([]MessageID, 0)
	for i := 0; i < numMessages; i++ {
		msg, err := consumer.Receive(context.Background())
		if err != nil {
			t.Fatal(err)
		}
		t.Log("Received message: ", msg.ID())
		msgIDs = append(msgIDs, msg.ID())
		if len(msgIDs) < 10 {
			continue
		}
		// Acknowledge a full batch of 10 and wait for the broker's response.
		if err := consumer.AckIDList(msgIDs); err != nil {
			t.Fatal("Failed to acked messages: ", msgIDs, " ", err)
		}
		t.Log("Acked messages: ", msgIDs)
		msgIDs = msgIDs[:0]
	}
}
68 changes: 68 additions & 0 deletions pulsar/unbounded_channel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package pulsar

import (
"container/list"
)

// unboundedChannel is a FIFO queue exposed through a pair of channels. A
// background goroutine started by newUnboundedChannel moves every value sent
// on inCh into an in-memory list and offers the oldest buffered value on
// outCh, so a send on inCh waits only for that goroutine and never for a
// reader of outCh.
//
// stop must be called to terminate the goroutine.
type unboundedChannel[T any] struct {
	// values buffers the elements that were received but not yet delivered.
	// Within this file it is accessed only by the forwarding goroutine.
	values *list.List
	// inCh is the send-only endpoint used by producers.
	inCh chan<- T
	// outCh is the receive-only endpoint used by consumers.
	outCh <-chan T
	// closeCh carries the stop request to the forwarding goroutine.
	closeCh chan struct{}
	// doneCh is closed by the forwarding goroutine right before it exits,
	// which lets stop return instead of blocking when called again.
	doneCh chan struct{}
}

// newUnboundedChannel creates an unboundedChannel and starts the goroutine
// that forwards values from inCh to outCh in the order they were sent.
func newUnboundedChannel[T any]() *unboundedChannel[T] {
	inCh := make(chan T)
	outCh := make(chan T)
	c := &unboundedChannel[T]{
		values:  list.New(),
		inCh:    inCh,
		outCh:   outCh,
		closeCh: make(chan struct{}),
		doneCh:  make(chan struct{}),
	}
	go func() {
		// Runs after both endpoints are closed, so a closed doneCh implies
		// that inCh and outCh are closed as well.
		defer close(c.doneCh)
		for {
			front := c.values.Front()
			var ch chan T
			var value T
			if front != nil {
				// The comma-ok form keeps a queued nil interface value (when T
				// is an interface type) from panicking: it yields T's zero
				// value, which is exactly the nil that was queued.
				value, _ = front.Value.(T)
				ch = outCh
			}
			// A send to a nil channel blocks forever so when no values are available,
			// it would never send a value to ch
			select {
			case v := <-inCh:
				c.values.PushBack(v)
			case ch <- value:
				c.values.Remove(front)
			case <-c.closeCh:
				// Values still buffered are dropped. Closing inCh means that
				// any later send on it panics.
				close(inCh)
				close(outCh)
				return
			}
		}
	}()
	return c
}

// stop terminates the forwarding goroutine and returns once inCh and outCh
// are closed. It may be called more than once and from several goroutines;
// calls made after the goroutine has exited return immediately.
func (c *unboundedChannel[T]) stop() {
	select {
	case c.closeCh <- struct{}{}:
		// The request was accepted; wait until the endpoints are closed.
		<-c.doneCh
	case <-c.doneCh:
		// Already stopped, or another caller's stop has just completed.
	}
}
60 changes: 60 additions & 0 deletions pulsar/unbounded_channel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package pulsar

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestUnboundedChannel checks three properties of unboundedChannel: values
// come out in the order they were sent, a receive blocks while the queue is
// empty, and sends complete without any concurrent reader.
func TestUnboundedChannel(t *testing.T) {
	c := newUnboundedChannel[int]()
	defer c.stop()
	go func() {
		for i := 0; i < 10; i++ {
			c.inCh <- i
		}
	}()

	for i := 0; i < 10; i++ {
		v := <-c.outCh
		assert.Equal(t, i, v)
	}

	go func() {
		time.Sleep(1 * time.Second)
		c.inCh <- -1
	}()
	start := time.Now()
	v := <-c.outCh
	elapsed := time.Since(start)
	// Expected value first, actual value second, as in the other assertions.
	assert.Equal(t, -1, v)
	// Verify the read blocks for at least 500ms
	assert.True(t, elapsed >= 500*time.Millisecond)

	// Verify the send values will not be blocked
	for i := 0; i < 10000; i++ {
		c.inCh <- i
	}
	for i := 0; i < 10000; i++ {
		v := <-c.outCh
		assert.Equal(t, i, v)
	}
}
Loading