diff --git a/pulsar/ack_grouping_tracker.go b/pulsar/ack_grouping_tracker.go
index c4ecc00389..bb9059ac06 100644
--- a/pulsar/ack_grouping_tracker.go
+++ b/pulsar/ack_grouping_tracker.go
@@ -62,7 +62,7 @@ func newAckGroupingTracker(options *AckGroupingOptions,
 		maxNumAcks:        int(options.MaxSize),
 		ackCumulative:     ackCumulative,
 		ackList:           ackList,
-		pendingAcks:       make(map[[2]uint64]*bitset.BitSet),
+		pendingAcks:       make(map[position]*bitset.BitSet),
 		lastCumulativeAck: EarliestMessageID(),
 	}
 
@@ -110,6 +110,15 @@ func (i *immediateAckGroupingTracker) flushAndClean() {
 func (i *immediateAckGroupingTracker) close() {
 }
 
+type position struct {
+	ledgerID uint64
+	entryID  uint64
+}
+
+func newPosition(msgID MessageID) position {
+	return position{ledgerID: uint64(msgID.LedgerID()), entryID: uint64(msgID.EntryID())}
+}
+
 type timedAckGroupingTracker struct {
 	sync.RWMutex
 
@@ -124,7 +133,7 @@ type timedAckGroupingTracker struct {
 	// in the batch whose batch size is 3 are not acknowledged.
 	// After the 1st message (i.e. batch index is 0) is acknowledged, the bits will become "011".
 	// Value is nil if the entry represents a single message.
-	pendingAcks map[[2]uint64]*bitset.BitSet
+	pendingAcks map[position]*bitset.BitSet
 
 	lastCumulativeAck     MessageID
 	cumulativeAckRequired int32
@@ -138,35 +147,36 @@ func (t *timedAckGroupingTracker) add(id MessageID) {
 	}
 }
 
-func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[[2]uint64]*bitset.BitSet {
-	t.Lock()
-	defer t.Unlock()
-	key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
-
+func addMsgIDToPendingAcks(pendingAcks map[position]*bitset.BitSet, id MessageID) {
+	key := newPosition(id)
 	batchIdx := id.BatchIdx()
 	batchSize := id.BatchSize()
 	if batchIdx >= 0 && batchSize > 0 {
-		bs, found := t.pendingAcks[key]
+		bs, found := pendingAcks[key]
 		if !found {
-			if batchSize > 1 {
-				bs = bitset.New(uint(batchSize))
-				for i := uint(0); i < uint(batchSize); i++ {
-					bs.Set(i)
-				}
+			bs = bitset.New(uint(batchSize))
+			for i := uint(0); i < uint(batchSize); i++ {
+				bs.Set(i)
 			}
-			t.pendingAcks[key] = bs
+			pendingAcks[key] = bs
 		}
 		if bs != nil {
 			bs.Clear(uint(batchIdx))
 		}
 	} else {
-		t.pendingAcks[key] = nil
+		pendingAcks[key] = nil
 	}
+}
 
+func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[position]*bitset.BitSet {
+	t.Lock()
+	defer t.Unlock()
+
+	addMsgIDToPendingAcks(t.pendingAcks, id)
 	if len(t.pendingAcks) >= t.maxNumAcks {
 		pendingAcks := t.pendingAcks
-		t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+		t.pendingAcks = make(map[position]*bitset.BitSet)
 		return pendingAcks
 	}
 	return nil
@@ -195,7 +205,7 @@ func (t *timedAckGroupingTracker) isDuplicate(id MessageID) bool {
 	if messageIDCompare(t.lastCumulativeAck, id) >= 0 {
 		return true
 	}
-	key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
+	key := newPosition(id)
 	if bs, found := t.pendingAcks[key]; found {
 		if bs == nil {
 			return true
@@ -232,11 +242,11 @@ func (t *timedAckGroupingTracker) flushAndClean() {
 	}
 }
 
-func (t *timedAckGroupingTracker) clearPendingAcks() map[[2]uint64]*bitset.BitSet {
+func (t *timedAckGroupingTracker) clearPendingAcks() map[position]*bitset.BitSet {
 	t.Lock()
 	defer t.Unlock()
 	pendingAcks := t.pendingAcks
-	t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
+	t.pendingAcks = make(map[position]*bitset.BitSet)
 	return pendingAcks
 }
 
@@ -250,12 +260,10 @@ func (t *timedAckGroupingTracker) close() {
 	}
 }
 
-func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bitset.BitSet) {
+func toMsgIDDataList(pendingAcks map[position]*bitset.BitSet) []*pb.MessageIdData {
 	msgIDs := make([]*pb.MessageIdData, 0, len(pendingAcks))
 	for k, v := range pendingAcks {
-		ledgerID := k[0]
-		entryID := k[1]
+		// Copy into locals so the pointers below don't alias the reused loop variable.
+		ledgerID := k.ledgerID
+		entryID := k.entryID
 		msgID := &pb.MessageIdData{LedgerId: &ledgerID, EntryId: &entryID}
 		if v != nil && !v.None() {
 			bytes := v.Bytes()
 			msgID.AckSet = make([]int64, len(bytes))
@@ -265,5 +273,9 @@ func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bit
 		}
 		msgIDs = append(msgIDs, msgID)
 	}
-	t.ackList(msgIDs)
+	return msgIDs
+}
+
+func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[position]*bitset.BitSet) {
+	t.ackList(toMsgIDDataList(pendingAcks))
 }
diff --git a/pulsar/consumer.go b/pulsar/consumer.go
index 880cad563e..7aee9645c6 100644
--- a/pulsar/consumer.go
+++ b/pulsar/consumer.go
@@ -19,6 +19,8 @@ package pulsar
 
 import (
 	"context"
+	"fmt"
+	"strings"
 	"time"
 
 	"github.com/apache/pulsar-client-go/pulsar/backoff"
@@ -266,6 +268,23 @@ type ConsumerOptions struct {
 	startMessageID *trackingMessageID
 }
 
+// AckError is returned when `AckIDList` fails and `AckWithResponse` is true.
+// It only contains the valid message IDs that failed to be acknowledged in the `AckIDList` call;
+// invalid message IDs are excluded, and users should ignore them rather than acknowledge them again.
+type AckError map[MessageID]error
+
+func (e AckError) Error() string {
+	builder := strings.Builder{}
+	errorMap := make(map[string][]MessageID)
+	for id, err := range e {
+		errorMap[err.Error()] = append(errorMap[err.Error()], id)
+	}
+	for err, msgIDs := range errorMap {
+		builder.WriteString(fmt.Sprintf("error: %s, failed message IDs: %v\n", err, msgIDs))
+	}
+	return builder.String()
+}
+
 // Consumer is an interface that abstracts behavior of Pulsar's consumer
 type Consumer interface {
 	// Subscription get a subscription for the consumer
@@ -305,8 +324,20 @@ type Consumer interface {
 	Ack(Message) error
 
 	// AckID the consumption of a single message, identified by its MessageID
+	// When `EnableBatchIndexAcknowledgment` is false, if a message ID represents a message in a batch,
+	// it will not actually be acknowledged by the broker until all messages in that batch are
+	// acknowledged via `AckID` or `AckIDList`.
 	AckID(MessageID) error
 
+	// AckIDList the consumption of a list of messages, identified by their MessageIDs
+	//
+	// This method should be used when `AckWithResponse` is true. Otherwise, it is equivalent to calling
+	// `AckID` on each message ID in the list.
+	//
+	// When `AckWithResponse` is true, the returned error can be an `AckError`, which maps each failed
+	// message ID to its corresponding error.
+	AckIDList([]MessageID) error
+
 	// AckWithTxn the consumption of a single message with a transaction
 	AckWithTxn(Message, Transaction) error
 
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index 740a7df97d..eafa4b47d8 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -41,6 +41,7 @@ const defaultNackRedeliveryDelay = 1 * time.Minute
 type acker interface {
 	// AckID does not handle errors returned by the Broker side, so no need to wait for doneCh to finish.
 	AckID(id MessageID) error
+	AckIDList(msgIDs []MessageID) error
 	AckIDWithResponse(id MessageID) error
 	AckIDWithTxn(msgID MessageID, txn Transaction) error
 	AckIDCumulative(msgID MessageID) error
@@ -559,6 +560,15 @@ func (c *consumer) AckID(msgID MessageID) error {
 	return c.consumers[msgID.PartitionIdx()].AckID(msgID)
 }
 
+func (c *consumer) AckIDList(msgIDs []MessageID) error {
+	return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+		if err := c.checkMsgIDPartition(msgID); err != nil {
+			return nil, err
+		}
+		return c.consumers[msgID.PartitionIdx()], nil
+	})
+}
+
 // AckCumulative the reception of all the messages in the stream up to (and including)
 // the provided message, identified by its MessageID
 func (c *consumer) AckCumulative(msg Message) error {
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index 3030bda327..26430add00 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -167,6 +167,49 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error {
 	return mid.consumer.AckID(msgID)
 }
 
+func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error {
+	return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+		if !checkMessageIDType(msgID) {
+			return nil, fmt.Errorf("invalid message id type %T", msgID)
+		}
+		if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
+			return mid.consumer, nil
+		}
+		return nil, errors.New("consumer is nil")
+	})
+}
+
+func ackIDListFromMultiTopics(log log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error {
+	consumerToMsgIDs := make(map[acker][]MessageID)
+	for _, msgID := range msgIDs {
+		if consumer, err := findConsumer(msgID); err == nil {
+			consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID)
+		} else {
+			log.Warnf("Cannot find consumer for %v", msgID)
+		}
+	}
+
+	ackError := AckError{}
+	for consumer, ids := range consumerToMsgIDs {
+		if err := consumer.AckIDList(ids); err != nil {
+			if topicAckError, ok := err.(AckError); ok {
+				for id, err := range topicAckError {
+					ackError[id] = err
+				}
+			} else {
+				// It should not reach here
+				for _, id := range ids {
+					ackError[id] = err
+				}
+			}
+		}
+	}
+	if len(ackError) == 0 {
+		return nil
+	}
+	return ackError
+}
+
 // AckWithTxn the consumption of a single message with a transaction
 func (c *multiTopicConsumer) AckWithTxn(msg Message, txn Transaction) error {
 	msgID := msg.ID()
diff --git a/pulsar/consumer_multitopic_test.go b/pulsar/consumer_multitopic_test.go
index 7c6b898a99..4b3ec90855 100644
--- a/pulsar/consumer_multitopic_test.go
+++ b/pulsar/consumer_multitopic_test.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/apache/pulsar-client-go/pulsaradmin"
 	"github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
@@ -218,3 +219,101 @@ func TestMultiTopicGetLastMessageIDs(t *testing.T) {
 	}
 
 }
+
+func TestMultiTopicAckIDList(t *testing.T) {
+	for _, params := range []bool{true, false} {
+		t.Run(fmt.Sprintf("TestMultiTopicConsumerAckIDList%v", params), func(t *testing.T) {
+			runMultiTopicAckIDList(t, params)
+		})
+	}
+}
+
+func runMultiTopicAckIDList(t *testing.T, regex bool) {
+	topicPrefix := fmt.Sprintf("multiTopicAckIDList%v", time.Now().UnixNano())
+	topic1 := "persistent://public/default/" + topicPrefix + "1"
+	topic2 := "persistent://public/default/" + topicPrefix + "2"
+
+	client, err := NewClient(ClientOptions{URL: "pulsar://localhost:6650"})
+	assert.Nil(t, err)
+	defer client.Close()
+
+	if regex {
+		admin, err := pulsaradmin.NewClient(&config.Config{})
+		assert.Nil(t, err)
+		for _, topic := range []string{topic1, topic2} {
+			topicName, err := utils.GetTopicName(topic)
+			assert.Nil(t, err)
+			admin.Topics().Create(*topicName, 0)
+		}
+	}
+
+	createConsumer := func() Consumer {
+		options := ConsumerOptions{
+			SubscriptionName: "sub",
+			Type:             Shared,
+			AckWithResponse:  true,
+		}
+		if regex {
+			options.TopicsPattern = topicPrefix + ".*"
+		} else {
+			options.Topics = []string{topic1, topic2}
+		}
+		consumer, err := client.Subscribe(options)
+		assert.Nil(t, err)
+		return consumer
+	}
+	consumer := createConsumer()
+
+	sendMessages(t, client, topic1, 0, 3, false)
+	sendMessages(t, client, topic2, 0, 2, false)
+
+	receiveMessageMap := func(consumer Consumer, numMessages int) map[string][]Message {
+		msgs := receiveMessages(t, consumer, numMessages)
+		topicToMsgs := make(map[string][]Message)
+		for _, msg := range msgs {
+			topicToMsgs[msg.Topic()] = append(topicToMsgs[msg.Topic()], msg)
+		}
+		return topicToMsgs
+	}
+
+	topicToMsgs := receiveMessageMap(consumer, 5)
+	assert.Equal(t, 3, len(topicToMsgs[topic1]))
+	for i := 0; i < 3; i++ {
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic1][i].Payload()))
+	}
+	assert.Equal(t, 2, len(topicToMsgs[topic2]))
+	for i := 0; i < 2; i++ {
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic2][i].Payload()))
+	}
+
+	assert.Nil(t, consumer.AckIDList([]MessageID{
+		topicToMsgs[topic1][0].ID(),
+		topicToMsgs[topic1][2].ID(),
+		topicToMsgs[topic2][1].ID(),
+	}))
+
+	consumer.Close()
+	consumer = createConsumer()
+	topicToMsgs = receiveMessageMap(consumer, 2)
+	assert.Equal(t, 1, len(topicToMsgs[topic1]))
+	assert.Equal(t, "msg-1", string(topicToMsgs[topic1][0].Payload()))
+	assert.Equal(t, 1, len(topicToMsgs[topic2]))
+	assert.Equal(t, "msg-0", string(topicToMsgs[topic2][0].Payload()))
+	consumer.Close()
+
+	msgID0 := topicToMsgs[topic1][0].ID()
+	err = consumer.AckIDList([]MessageID{msgID0})
+	assert.NotNil(t, err)
+	t.Logf("AckIDList error: %v", err)
+
+	msgID1 := topicToMsgs[topic2][0].ID()
+	if ackError, ok := consumer.AckIDList([]MessageID{msgID0, msgID1}).(AckError); ok {
+		assert.Equal(t, 2, len(ackError))
+		assert.Contains(t, ackError, msgID0)
+		assert.Equal(t, "consumer state is closed", ackError[msgID0].Error())
+		assert.Contains(t, ackError, msgID1)
+		assert.Equal(t, "consumer state is closed", ackError[msgID1].Error())
+	} else {
+		assert.Fail(t, "AckIDList should return AckError")
+	}
+}
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 4e8fba5a52..471d45a3a4 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -198,6 +198,10 @@ func (pc *partitionConsumer) pauseDispatchMessage() {
 	pc.dispatcherSeekingControlCh <- struct{}{}
 }
 
+func (pc *partitionConsumer) Topic() string {
+	return pc.topic
+}
+
 func (pc *partitionConsumer) ActiveConsumerChanged(isActive bool) {
 	listener := pc.options.consumerEventListener
 	if listener == nil {
@@ -375,7 +379,12 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
 	pc.ackGroupingTracker = newAckGroupingTracker(options.ackGroupingOptions,
 		func(id MessageID) { pc.sendIndividualAck(id) },
 		func(id MessageID) { pc.sendCumulativeAck(id) },
-		func(ids []*pb.MessageIdData) { pc.eventsCh <- ids })
+		func(ids []*pb.MessageIdData) {
+			pc.eventsCh <- &ackListRequest{
+				errCh:  nil, // ignore the error
+				msgIDs: ids,
+			}
+		})
 	pc.setConsumerState(consumerInit)
 	pc.log = client.log.SubLogger(log.Fields{
 		"name": pc.name,
@@ -695,6 +704,86 @@ func (pc *partitionConsumer) AckID(msgID MessageID) error {
 	return pc.ackID(msgID, false)
 }
 
+func (pc *partitionConsumer) AckIDList(msgIDs []MessageID) error {
+	if !pc.options.ackWithResponse {
+		for _, msgID := range msgIDs {
+			if err := pc.AckID(msgID); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	chunkedMsgIDs := make([]*chunkMessageID, 0) // we need to remove them after acknowledging
+	pendingAcks := make(map[position]*bitset.BitSet)
+	validMsgIDs := make([]MessageID, 0, len(msgIDs))
+
+	// A batch entry might only become complete (fully acknowledged) after this loop
+	// has processed every ID in the list.
+	for _, msgID := range msgIDs {
+		if msgID.PartitionIdx() != pc.partitionIdx {
+			pc.log.Errorf("%v inconsistent partition index %v (current: %v)", msgID, msgID.PartitionIdx(), pc.partitionIdx)
+		} else if msgID.BatchIdx() >= 0 && msgID.BatchSize() > 0 &&
+			msgID.BatchIdx() >= msgID.BatchSize() {
+			pc.log.Errorf("%v invalid batch index %v (size: %v)", msgID, msgID.BatchIdx(), msgID.BatchSize())
+		} else {
+			valid := true
+			switch convertedMsgID := msgID.(type) {
+			case *trackingMessageID:
+				position := newPosition(msgID)
+				if convertedMsgID.ack() {
+					pendingAcks[position] = nil
+				} else if pc.options.enableBatchIndexAck {
+					pendingAcks[position] = convertedMsgID.tracker.getAckBitSet()
+				}
+			case *chunkMessageID:
+				for _, id := range pc.unAckChunksTracker.get(convertedMsgID) {
+					pendingAcks[newPosition(id)] = nil
+				}
+				chunkedMsgIDs = append(chunkedMsgIDs, convertedMsgID)
+			case *messageID:
+				pendingAcks[newPosition(msgID)] = nil
+			default:
+				pc.log.Errorf("invalid message id type %T: %v", msgID, msgID)
+				valid = false
+			}
+			if valid {
+				validMsgIDs = append(validMsgIDs, msgID)
+			}
+		}
+	}
+
+	if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
+		pc.log.WithField("state", state).Error("Failed to ack: consumer is closing or closed")
+		return toAckError(map[error][]MessageID{errors.New("consumer state is closed"): validMsgIDs})
+	}
+
+	req := &ackListRequest{
+		errCh:  make(chan error),
+		msgIDs: toMsgIDDataList(pendingAcks),
+	}
+	pc.eventsCh <- req
+	if err := <-req.errCh; err != nil {
+		return toAckError(map[error][]MessageID{err: validMsgIDs})
+	}
+	for _, id := range chunkedMsgIDs {
+		pc.unAckChunksTracker.remove(id)
+	}
+	for _, id := range msgIDs {
+		pc.options.interceptors.OnAcknowledge(pc.parentConsumer, id)
+	}
+	return nil
+}
+
+func toAckError(errorMap map[error][]MessageID) AckError {
+	e := AckError{}
+	for err, ids := range errorMap {
+		for _, id := range ids {
+			e[id] = err
+		}
+	}
+	return e
+}
+
 func (pc *partitionConsumer) AckIDCumulative(msgID MessageID) error {
 	if !checkMessageIDType(msgID) {
 		pc.log.Errorf("invalid message id type %T", msgID)
@@ -1027,11 +1116,22 @@ func (pc *partitionConsumer) internalAck(req *ackRequest) {
 	}
 }
 
-func (pc *partitionConsumer) internalAckList(msgIDs []*pb.MessageIdData) {
+func (pc *partitionConsumer) internalAckList(request *ackListRequest) {
+	if request.errCh != nil {
+		reqID := pc.client.rpcClient.NewRequestID()
+		_, err := pc.client.rpcClient.RequestOnCnx(pc._getConn(), reqID, pb.BaseCommand_ACK, &pb.CommandAck{
+			AckType:    pb.CommandAck_Individual.Enum(),
+			ConsumerId: proto.Uint64(pc.consumerID),
+			MessageId:  request.msgIDs,
+			RequestId:  &reqID,
+		})
+		request.errCh <- err
+		return
+	}
 	pc.client.rpcClient.RequestOnCnxNoWait(pc._getConn(), pb.BaseCommand_ACK, &pb.CommandAck{
 		AckType:    pb.CommandAck_Individual.Enum(),
 		ConsumerId: proto.Uint64(pc.consumerID),
-		MessageId:  msgIDs,
+		MessageId:  request.msgIDs,
 	})
 }
 
@@ -1563,6 +1663,11 @@ type ackRequest struct {
 	err    error
 }
 
+type ackListRequest struct {
+	errCh  chan error
+	msgIDs []*pb.MessageIdData
+}
+
 type ackWithTxnRequest struct {
 	doneCh chan struct{}
 	msgID  trackingMessageID
@@ -1623,7 +1728,7 @@ func (pc *partitionConsumer) runEventsLoop() {
 				pc.internalAck(v)
 			case *ackWithTxnRequest:
 				pc.internalAckWithTxn(v)
-			case []*pb.MessageIdData:
+			case *ackListRequest:
 				pc.internalAckList(v)
 			case *redeliveryRequest:
 				pc.internalRedeliver(v)
diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go
index 58cfa80fa3..ced770e996 100644
--- a/pulsar/consumer_regex.go
+++ b/pulsar/consumer_regex.go
@@ -215,6 +215,18 @@ func (c *regexConsumer) AckID(msgID MessageID) error {
 	return mid.consumer.AckID(msgID)
 }
 
+func (c *regexConsumer) AckIDList(msgIDs []MessageID) error {
+	return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
+		if !checkMessageIDType(msgID) {
+			return nil, fmt.Errorf("invalid message id type %T", msgID)
+		}
+		if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
+			return mid.consumer, nil
+		}
+		return nil, errors.New("consumer is nil in consumer_regex")
+	})
+}
+
 // AckID the consumption of a single message, identified by its MessageID
 func (c *regexConsumer) AckWithTxn(msg Message, txn Transaction) error {
 	msgID := msg.ID()
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index d8f31458f1..2524f6816a 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -4745,3 +4745,135 @@ func TestLookupConsumer(t *testing.T) {
 		consumer.Ack(msg)
 	}
 }
+
+func TestAckIDList(t *testing.T) {
+	for _, params := range []bool{true, false} {
+		t.Run(fmt.Sprintf("TestAckIDList_%v", params), func(t *testing.T) {
+			runAckIDListTest(t, params)
+		})
+	}
+}
+
+func runAckIDListTest(t *testing.T, enableBatchIndexAck bool) {
+	client, err := NewClient(ClientOptions{URL: lookupURL})
+	assert.Nil(t, err)
+	defer client.Close()
+
+	topic := fmt.Sprintf("test-ack-id-list-%v", time.Now().Nanosecond())
+
+	consumer := createSharedConsumer(t, client, topic, enableBatchIndexAck)
+	sendMessages(t, client, topic, 0, 5, true)  // entry 0: [0, 1, 2, 3, 4]
+	sendMessages(t, client, topic, 5, 3, false) // entry 2: [5], 3: [6], 4: [7]
+	sendMessages(t, client, topic, 8, 2, true)  // entry 5: [8, 9]
+
+	msgs := receiveMessages(t, consumer, 10)
+	originalMsgIDs := make([]MessageID, 0)
+	for i := 0; i < 10; i++ {
+		originalMsgIDs = append(originalMsgIDs, msgs[i].ID())
+		assert.Equal(t, fmt.Sprintf("msg-%d", i), string(msgs[i].Payload()))
+	}
+
+	ackedIndexes := []int{0, 2, 3, 6, 8, 9}
+	unackedIndexes := []int{1, 4, 5, 7}
+	if !enableBatchIndexAck {
+		// [0, 4] is the first batch range but only part of it is acked
+		unackedIndexes = []int{0, 1, 2, 3, 4, 5, 7}
+	}
+	msgIDs := make([]MessageID, len(ackedIndexes))
+	for i := 0; i < len(ackedIndexes); i++ {
+		msgIDs[i] = msgs[ackedIndexes[i]].ID()
+	}
+	assert.Nil(t, consumer.AckIDList(msgIDs))
+	consumer.Close()
+
+	consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+	msgs = receiveMessages(t, consumer, len(unackedIndexes))
+	for i := 0; i < len(unackedIndexes); i++ {
+		assert.Equal(t, fmt.Sprintf("msg-%d", unackedIndexes[i]), string(msgs[i].Payload()))
+	}
+
+	if !enableBatchIndexAck {
+		msgIDs = make([]MessageID, 0)
+		for i := 0; i < 5; i++ {
+			msgIDs = append(msgIDs, originalMsgIDs[i])
+		}
+		assert.Nil(t, consumer.AckIDList(msgIDs))
+		consumer.Close()
+
+		consumer = createSharedConsumer(t, client, topic, enableBatchIndexAck)
+		msgs = receiveMessages(t, consumer, 2)
+		assert.Equal(t, "msg-5", string(msgs[0].Payload()))
+		assert.Equal(t, "msg-7", string(msgs[1].Payload()))
+		consumer.Close()
+	}
+	consumer.Close()
+	err = consumer.AckIDList(msgIDs)
+	assert.NotNil(t, err)
+	if ackError, ok := err.(AckError); ok {
+		assert.Equal(t, len(msgIDs), len(ackError))
+		for _, id := range msgIDs {
+			assert.Contains(t, ackError, id)
+			assert.Equal(t, "consumer state is closed", ackError[id].Error())
+		}
+	} else {
+		assert.Fail(t, "AckIDList should return AckError")
+	}
+}
+
+func createSharedConsumer(t *testing.T, client Client, topic string, enableBatchIndexAck bool) Consumer {
+	consumer, err := client.Subscribe(ConsumerOptions{
+		Topic:                          topic,
+		SubscriptionName:               "my-sub",
+		SubscriptionInitialPosition:    SubscriptionPositionEarliest,
+		Type:                           Shared,
+		EnableBatchIndexAcknowledgment: enableBatchIndexAck,
+		AckWithResponse:                true,
+	})
+	assert.Nil(t, err)
+	return consumer
+}
+
+func sendMessages(t *testing.T, client Client, topic string, startIndex int, numMessages int, batching bool) {
+	producer, err := client.CreateProducer(ProducerOptions{
+		Topic:                   topic,
+		DisableBatching:         !batching,
+		BatchingMaxMessages:     uint(numMessages),
+		BatchingMaxSize:         1024 * 1024 * 10,
+		BatchingMaxPublishDelay: 1 * time.Hour,
+	})
+	assert.Nil(t, err)
+	defer producer.Close()
+
+	ctx := context.Background()
+	for i := 0; i < numMessages; i++ {
+		msg := &ProducerMessage{Payload: []byte(fmt.Sprintf("msg-%d", startIndex+i))}
+		if batching {
+			producer.SendAsync(ctx, msg, func(_ MessageID, _ *ProducerMessage, err error) {
+				if err != nil {
+					t.Logf("Failed to send message: %v", err)
+				}
+			})
+		} else {
+			if _, err := producer.Send(ctx, msg); err != nil {
+				assert.Fail(t, "Failed to send message: %v", err)
+			}
+		}
+	}
+	assert.Nil(t, producer.Flush())
+}
+
+func receiveMessages(t *testing.T, consumer Consumer, numMessages int) []Message {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	msgs := make([]Message, 0)
+	for i := 0; i < numMessages; i++ {
+		if msg, err := consumer.Receive(ctx); err == nil {
+			msgs = append(msgs, msg)
+		} else {
+			t.Logf("Failed to receive message: %v", err)
+			break
+		}
+	}
+	assert.Equal(t, numMessages, len(msgs))
+	return msgs
+}
diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go
index 8117127286..3f2862da2d 100644
--- a/pulsar/consumer_zero_queue.go
+++ b/pulsar/consumer_zero_queue.go
@@ -171,6 +171,10 @@ func (z *zeroQueueConsumer) AckID(msgID MessageID) error {
 	return z.pc.AckID(msgID)
 }
 
+func (z *zeroQueueConsumer) AckIDList(msgIDs []MessageID) error {
+	return z.pc.AckIDList(msgIDs)
+}
+
 func (z *zeroQueueConsumer) AckWithTxn(msg Message, txn Transaction) error {
 	msgID := msg.ID()
 	if err := z.checkMsgIDPartition(msgID); err != nil {
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 478b1af27c..0acd782b80 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -404,6 +404,12 @@ type ackTracker struct {
 	prevBatchAcked uint32
 }
 
+func (t *ackTracker) getAckBitSet() *bitset.BitSet {
+	t.Lock()
+	defer t.Unlock()
+	return t.batchIDs.Clone()
+}
+
 func (t *ackTracker) ack(batchID int) bool {
 	if batchID < 0 {
 		return true
diff --git a/pulsar/internal/pulsartracing/consumer_interceptor_test.go b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
index 1fa1bf0d17..e7712356f5 100644
--- a/pulsar/internal/pulsartracing/consumer_interceptor_test.go
+++ b/pulsar/internal/pulsartracing/consumer_interceptor_test.go
@@ -79,6 +79,10 @@ func (c *mockConsumer) AckID(_ pulsar.MessageID) error {
 	return nil
 }
 
+func (c *mockConsumer) AckIDList(_ []pulsar.MessageID) error {
+	return nil
+}
+
 func (c *mockConsumer) AckCumulative(_ pulsar.Message) error {
 	return nil
}
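
---

For reviewers, a minimal usage sketch of the `AckIDList` API this patch adds. The broker URL matches the one used in the tests above; the topic name is a placeholder. With `AckWithResponse` enabled, a failed call can be inspected as an `AckError`, which maps each failed (valid) message ID to its cause:

```go
package main

import (
	"context"
	"log"

	"github.com/apache/pulsar-client-go/pulsar"
)

func main() {
	client, err := pulsar.NewClient(pulsar.ClientOptions{URL: "pulsar://localhost:6650"})
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	consumer, err := client.Subscribe(pulsar.ConsumerOptions{
		Topic:            "my-topic", // placeholder topic
		SubscriptionName: "my-sub",
		Type:             pulsar.Shared,
		// AckWithResponse makes AckIDList wait for the broker's response and
		// report per-message failures via AckError.
		AckWithResponse: true,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer consumer.Close()

	// Collect a few message IDs, then acknowledge them in a single call.
	msgIDs := make([]pulsar.MessageID, 0, 3)
	for i := 0; i < 3; i++ {
		msg, err := consumer.Receive(context.Background())
		if err != nil {
			log.Fatal(err)
		}
		msgIDs = append(msgIDs, msg.ID())
	}

	if err := consumer.AckIDList(msgIDs); err != nil {
		if ackErr, ok := err.(pulsar.AckError); ok {
			// Only valid message IDs that failed are present; invalid IDs
			// should be ignored rather than acknowledged again.
			for id, cause := range ackErr {
				log.Printf("failed to ack %v: %v", id, cause)
			}
		} else {
			log.Printf("ack failed: %v", err)
		}
	}
}
```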
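A self-contained sketch of the bitset bookkeeping that `pendingAcks` relies on in `ack_grouping_tracker.go`, using the same `github.com/bits-and-blooms/bitset` package. It illustrates the "111" → "011" example from the comment there: a set bit means the batch message is still unacknowledged, and an entry is only sent to the broker once every bit is cleared:

```go
package main

import (
	"fmt"

	"github.com/bits-and-blooms/bitset"
)

func main() {
	// A batch entry of size 3 starts with all bits set ("111" in the
	// tracker's comment notation): no message in the batch is acked yet.
	const batchSize = 3
	bs := bitset.New(batchSize)
	for i := uint(0); i < batchSize; i++ {
		bs.Set(i)
	}

	// Acknowledging the 1st message (batch index 0) clears its bit,
	// leaving "011".
	bs.Clear(0)
	fmt.Println(bs.Test(0), bs.Test(1), bs.Test(2)) // false true true

	// Only when every bit is cleared does None() report true, at which
	// point the whole entry can be acknowledged to the broker.
	bs.Clear(1)
	bs.Clear(2)
	fmt.Println(bs.None()) // true
}
```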
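And a standalone sketch of the `AckSet` wire encoding performed by `toMsgIDDataList`: `bitset.Bytes()` exposes the underlying 64-bit words, which are copied into the `[]int64` slice carried by `pb.MessageIdData.AckSet`. The helper name `toAckSet` is hypothetical; it just isolates the loop from the patch:

```go
package main

import (
	"fmt"

	"github.com/bits-and-blooms/bitset"
)

// toAckSet mirrors the conversion in toMsgIDDataList: each underlying
// uint64 word of the bitset becomes one int64 element of AckSet.
func toAckSet(bs *bitset.BitSet) []int64 {
	words := bs.Bytes() // []uint64, one word per 64 bits
	ackSet := make([]int64, len(words))
	for i, w := range words {
		ackSet[i] = int64(w)
	}
	return ackSet
}

func main() {
	// Batch of size 3 where only batch index 1 is still unacknowledged.
	bs := bitset.New(3)
	bs.Set(1)
	fmt.Println(toAckSet(bs)) // [2]: only bit 1 of the first word is set
}
```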