
Commit

Do some refactoring
BewareMyPower committed Nov 25, 2024
1 parent 5f0e32d commit dcd2b4a
Showing 5 changed files with 72 additions and 155 deletions.
74 changes: 57 additions & 17 deletions pulsar/consumer_partition.go
@@ -155,8 +155,17 @@ type partitionConsumer struct {
availablePermits *availablePermits

// the size of the queue channel for buffering messages
maxQueueSize int32
queueCh *unboundedChannel[*message]
maxQueueSize int32

// pendingMessages buffers all messages received from the broker that have not yet been delivered to the user
// via the Chan() or Receive() methods.
// A background goroutine receives messages from the connection via `queueInCh`, appends them to
// `pendingMessages`, and forwards them to `queueOutCh`, from which the `dispatcher` goroutine reads.
pendingMessages *list.List
queueInCh chan *message
queueOutCh chan *message

startMessageID atomicMessageID
lastDequeuedMsg *trackingMessageID
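
For context, the new fields above implement a simple unbounded-queue pattern: a background goroutine owns `pendingMessages` and bridges the writer-facing `queueInCh` to the reader-facing `queueOutCh`. Below is a minimal, self-contained sketch of the same idea with hypothetical names (it is not the pulsar types); the actual implementation added by this commit is `startQueueMessagesFromBroker`, further down in this file's diff.

```go
// Standalone sketch of the buffering pattern described in the comment above.
// All names here are illustrative; the real code lives in startQueueMessagesFromBroker.
package main

import (
	"container/list"
	"fmt"
)

type unboundedQueue struct {
	in      chan string   // writers push here; the goroutine accepts promptly
	out     chan string   // readers pull from here
	closeCh chan struct{} // signals the background goroutine to stop
	pending *list.List    // unbounded in-memory buffer owned by the goroutine
}

func newUnboundedQueue() *unboundedQueue {
	q := &unboundedQueue{
		in:      make(chan string),
		out:     make(chan string),
		closeCh: make(chan struct{}),
		pending: list.New(),
	}
	go q.run()
	return q
}

func (q *unboundedQueue) run() {
	for {
		front := q.pending.Front()
		if front == nil {
			// Nothing buffered: only accept new items or stop.
			select {
			case v := <-q.in:
				q.pending.PushBack(v)
			case <-q.closeCh:
				return
			}
		} else {
			// Something buffered: try to hand out the head while still accepting new items.
			select {
			case q.out <- front.Value.(string):
				q.pending.Remove(front)
			case v := <-q.in:
				q.pending.PushBack(v)
			case <-q.closeCh:
				return
			}
		}
	}
}

func main() {
	q := newUnboundedQueue()
	q.in <- "a"
	q.in <- "b"
	fmt.Println(<-q.out, <-q.out) // a b
	close(q.closeCh)              // stops the background goroutine
}
```

Keeping the buffer inside a single goroutine avoids any locking, and shutdown is handled in one place via the close channel.
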

@@ -354,7 +363,6 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
partitionIdx: int32(options.partitionIdx),
eventsCh: make(chan interface{}, 10),
maxQueueSize: int32(options.receiverQueueSize),
queueCh: newUnboundedChannel[*message](),
startMessageID: atomicMessageID{msgID: options.startMessageID},
connectedCh: make(chan struct{}),
messageCh: messageCh,
@@ -419,6 +427,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
}
pc.log.Info("Created consumer")
pc.setConsumerState(consumerReady)
pc.startQueueMessagesFromBroker()

startingMessageID := pc.startMessageID.get()
if pc.options.startMessageIDInclusive && startingMessageID != nil && startingMessageID.equal(latestMessageID) {
@@ -949,11 +958,6 @@ func (pc *partitionConsumer) Close() {

// wait for request to finish
<-req.doneCh

// It will close `queueCh.in`. If `MessageReceived` was called after that, it will panic because new messages
// will be sent to a closed channel. However, generally it's impossible because the broker will not be able to
// dispatch messages to this consumer after receiving the close request.
pc.queueCh.stop()
}

func (pc *partitionConsumer) Seek(msgID MessageID) error {
@@ -1176,7 +1180,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
pc.markScaleIfNeed()
}

pc.queueCh.inCh <- &message{
pc.queueInCh <- &message{
publishTime: timeFromUnixTimestampMillis(msgMeta.GetPublishTime()),
eventTime: timeFromUnixTimestampMillis(msgMeta.GetEventTime()),
key: msgMeta.GetPartitionKey(),
@@ -1378,7 +1382,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
pc.markScaleIfNeed()
}

pc.queueCh.inCh <- msg
pc.queueInCh <- msg
}

if skippedMessages > 0 {
@@ -1542,12 +1546,14 @@ func (pc *partitionConsumer) dispatcher() {
}()
var queueMsg *message
for {
var queueCh <-chan *message
queueMsgCh := pc.queueOutCh
var messageCh chan ConsumerMessage
var nextMessage ConsumerMessage
var nextMessageSize int

if queueMsg != nil {
// Do not read from the queued message channel since there is already a message polled in the last loop
queueMsgCh = nil
nextMessage = ConsumerMessage{
Consumer: pc.parentConsumer,
Message: queueMsg,
@@ -1568,8 +1574,6 @@ } else {
} else {
pc.log.Debug("skip dispatching messages when seeking")
}
} else {
queueCh = pc.queueCh.outCh
}

select {
@@ -1607,7 +1611,7 @@ func (pc *partitionConsumer) dispatcher() {
pc.log.Debug("received dispatcherSeekingControlCh, set isSeek to true")
pc.isSeeking.Store(true)

case msg, ok := <-queueCh:
case msg, ok := <-queueMsgCh:
if !ok {
return
}
@@ -1630,9 +1634,9 @@
// drain the message queue on any new connection by sending a
// special nil message to the channel so we know when to stop dropping messages
var nextMessageInQueue *trackingMessageID
pc.queueCh.inCh <- nil
pc.queueInCh <- nil

for m := range pc.queueCh.outCh {
for m := range pc.queueOutCh {
// the queue has been drained
if m == nil {
break
@@ -2080,7 +2084,7 @@ func (pc *partitionConsumer) expectMoreIncomingMessages() {
}

func (pc *partitionConsumer) markScaleIfNeed() {
// availablePermits + incomingMessages (messages in queueCh) is the number of prefetched messages
// availablePermits + incomingMessages (messages in pendingMessages) is the number of prefetched messages
// The expected result of auto-scaling is that currentQueueSize ends up slightly larger than the number of prefetched messages
prev := pc.scaleReceiverQueueHint.Swap(pc.availablePermits.get()+pc.incomingMessages.Load() >=
pc.currentQueueSize.Load())
@@ -2220,6 +2224,42 @@ func (pc *partitionConsumer) _getConn() internal.Connection {
return *pc.conn.Load()
}

func (pc *partitionConsumer) startQueueMessagesFromBroker() {
pc.queueInCh = make(chan *message)
pc.queueOutCh = make(chan *message)
pc.pendingMessages = list.New()

go func() {
defer func() {
close(pc.queueInCh)
close(pc.queueOutCh)
pc.log.Debug("exiting queueMessagesFromBroker")
}()

for {
front := pc.pendingMessages.Front()
if front == nil {
select {
case msg := <-pc.queueInCh:
pc.pendingMessages.PushBack(msg)
case <-pc.closeCh:
return
}
} else {
msg := front.Value.(*message)
select {
case pc.queueOutCh <- msg:
pc.pendingMessages.Remove(front)
case msg := <-pc.queueInCh:
pc.pendingMessages.PushBack(msg)
case <-pc.closeCh:
return
}
}
}
}()
}

func convertToMessageIDData(msgID *trackingMessageID) *pb.MessageIdData {
if msgID == nil {
return nil
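A note on two Go idioms the updated `dispatcher` relies on in the hunks above: a `select` case can be disabled for one iteration by setting its channel variable to `nil` (a nil channel never becomes ready), which is how `queueMsgCh` is skipped while a previously polled message is still undelivered; and on a new connection the buffered queue is drained by pushing a `nil` sentinel into `queueInCh` and reading from `queueOutCh` until that sentinel comes back. A minimal, hypothetical sketch of the nil-channel trick (not the pulsar code):

```go
// Demonstrates disabling select cases by nil-ing out channel variables,
// mirroring the dispatcher's "poll one message, then deliver it" loop.
package main

import "fmt"

func main() {
	in := make(chan string, 2)  // messages arriving from a source
	out := make(chan string, 2) // messages delivered to a consumer
	in <- "a"
	in <- "b"

	var pending string
	for delivered := 0; delivered < 2; {
		inCh := in
		var outCh chan string // nil: delivery case disabled by default
		var next string
		if pending != "" {
			// A message is already buffered: stop polling (inCh = nil disables
			// that select case) and enable delivery instead.
			inCh = nil
			outCh = out
			next = pending
		}
		select {
		case m := <-inCh:
			pending = m
		case outCh <- next:
			pending = ""
			delivered++
		}
	}
	fmt.Println(<-out, <-out) // a b
}
```
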
15 changes: 9 additions & 6 deletions pulsar/consumer_partition_test.go
@@ -31,7 +31,7 @@ import (
func TestSingleMessageIDNoAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: newUnboundedChannel[*message](),
closeCh: make(chan struct{}),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
@@ -41,14 +41,15 @@ func TestSingleMessageIDNoAckTracker(t *testing.T) {
pc.availablePermits = &availablePermits{pc: &pc}
pc.ackGroupingTracker = newAckGroupingTracker(&AckGroupingOptions{MaxSize: 0},
func(id MessageID) { pc.sendIndividualAck(id) }, nil, nil)
pc.startQueueMessagesFromBroker()

headersAndPayload := internal.NewBufferWrapper(rawCompatSingleMessage)
if err := pc.MessageReceived(nil, headersAndPayload); err != nil {
t.Fatal(err)
}

// ensure the tracker was set on the message id
message := <-pc.queueCh.outCh
message := <-pc.queueOutCh
id := message.ID().(*trackingMessageID)
assert.Nil(t, id.tracker)

@@ -69,7 +70,7 @@ func newTestMetrics() *internal.LeveledMetrics {
func TestBatchMessageIDNoAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: newUnboundedChannel[*message](),
closeCh: make(chan struct{}),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
@@ -79,14 +80,15 @@ func TestBatchMessageIDNoAckTracker(t *testing.T) {
pc.availablePermits = &availablePermits{pc: &pc}
pc.ackGroupingTracker = newAckGroupingTracker(&AckGroupingOptions{MaxSize: 0},
func(id MessageID) { pc.sendIndividualAck(id) }, nil, nil)
pc.startQueueMessagesFromBroker()

headersAndPayload := internal.NewBufferWrapper(rawBatchMessage1)
if err := pc.MessageReceived(nil, headersAndPayload); err != nil {
t.Fatal(err)
}

// ensure the tracker was set on the message id
message := <-pc.queueCh.outCh
message := <-pc.queueOutCh
id := message.ID().(*trackingMessageID)
assert.Nil(t, id.tracker)

@@ -104,7 +106,7 @@ func TestBatchMessageIDNoAckTracker(t *testing.T) {
func TestBatchMessageIDWithAckTracker(t *testing.T) {
eventsCh := make(chan interface{}, 1)
pc := partitionConsumer{
queueCh: newUnboundedChannel[*message](),
closeCh: make(chan struct{}),
eventsCh: eventsCh,
compressionProviders: sync.Map{},
options: &partitionConsumerOpts{},
@@ -114,6 +116,7 @@ func TestBatchMessageIDWithAckTracker(t *testing.T) {
pc.availablePermits = &availablePermits{pc: &pc}
pc.ackGroupingTracker = newAckGroupingTracker(&AckGroupingOptions{MaxSize: 0},
func(id MessageID) { pc.sendIndividualAck(id) }, nil, nil)
pc.startQueueMessagesFromBroker()

headersAndPayload := internal.NewBufferWrapper(rawBatchMessage10)
if err := pc.MessageReceived(nil, headersAndPayload); err != nil {
@@ -125,7 +128,7 @@ func TestBatchMessageIDWithAckTracker(t *testing.T) {
running := true
for running {
select {
case m := <-pc.queueCh.outCh:
case m := <-pc.queueOutCh:
id := m.ID().(*trackingMessageID)
assert.NotNil(t, id.tracker)
messageIDs = append(messageIDs, id)
10 changes: 6 additions & 4 deletions pulsar/consumer_test.go
@@ -4887,14 +4887,15 @@ func TestAckResponseNotBlocked(t *testing.T) {
defer client.Close()

topic := fmt.Sprintf("test-ack-response-not-blocked-%v", time.Now().Nanosecond())
assert.Nil(t, createPartitionedTopic(topic, 10))

producer, err := client.CreateProducer(ProducerOptions{
Topic: topic,
})
assert.Nil(t, err)

ctx := context.Background()
numMessages := 100
numMessages := 1000
for i := 0; i < numMessages; i++ {
producer.SendAsync(ctx, &ProducerMessage{
Payload: []byte(fmt.Sprintf("value-%d", i)),
@@ -4903,7 +4904,9 @@ func TestAckResponseNotBlocked(t *testing.T) {
t.Fatal(err)
}
})
time.Sleep(1 * time.Millisecond)
if i%100 == 99 {
assert.Nil(t, producer.Flush())
}
}
producer.Flush()
producer.Close()
@@ -4917,15 +4920,14 @@ func TestAckResponseNotBlocked(t *testing.T) {
Type: KeyShared,
EnableBatchIndexAcknowledgment: true,
AckWithResponse: true,
ReceiverQueueSize: 10,
ReceiverQueueSize: 5,
})
assert.Nil(t, err)
msgIDs := make([]MessageID, 0)
for i := 0; i < numMessages; i++ {
if msg, err := consumer.Receive(context.Background()); err != nil {
t.Fatal(err)
} else {
t.Log("Received message: ", msg.ID())
msgIDs = append(msgIDs, msg.ID())
if len(msgIDs) >= 10 {
if err := consumer.AckIDList(msgIDs); err != nil {
68 changes: 0 additions & 68 deletions pulsar/unbounded_channel.go

This file was deleted.

