From 0530f9ae95bf4a28d41b14450e5608cdd74a5688 Mon Sep 17 00:00:00 2001 From: Dimitris Halatsis Date: Sun, 4 Aug 2024 23:41:51 +0300 Subject: [PATCH] feat(pubsub): make batch requests provide results independently --- pubsub/pubsub.go | 79 +++++++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index f68b17e03e..3d68e54646 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -562,20 +562,37 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) { // log.Printf("BATCH SIZE %d", batchSize) go func() { + defer func() { + close(s.waitc) + s.waitc = nil + }() if s.preReceiveBatchHook != nil { s.preReceiveBatchHook(batchSize) } - msgs, err := s.getNextBatch(batchSize) - s.mu.Lock() - defer s.mu.Unlock() - if err != nil { - // Non-retryable error from ReceiveBatch -> permanent error. - s.err = err - } else if len(msgs) > 0 { - s.q = append(s.q, msgs...) + resultChannel := s.getNextBatch(batchSize) + for { + select { + case msgs, ok := <-resultChannel.msgs: + if !ok { + // batch reception finished + return + } else if len(msgs) > 0 { + // messages received from channel + s.mu.Lock() + s.q = append(s.q, msgs...) + s.mu.Unlock() + } + case err := <-resultChannel.err: + // err can receive message only after batch group completes + if err != nil { + // Non-retryable error from ReceiveBatch -> permanent error. + s.mu.Lock() + s.err = err + s.mu.Unlock() + } + return + } } - close(s.waitc) - s.waitc = nil }() } if len(s.q) > 0 { @@ -623,30 +640,33 @@ func (s *Subscription) Receive(ctx context.Context) (_ *Message, err error) { }) return m2, nil } - // A call to ReceiveBatch must be in flight. Wait for it. - waitc := s.waitc - s.mu.Unlock() + s.mu.Unlock() // unlock to allow message or error processing from background goroutine select { - case <-waitc: - s.mu.Lock() - // Continue to top of loop. case <-ctx.Done(): s.mu.Lock() return nil, ctx.Err() + default: + // Continue to top of loop. + s.mu.Lock() } } } -// getNextBatch gets the next batch of messages from the server and returns it. -func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) { - var mu sync.Mutex - var q []*driver.Message +type batchChannelResult struct { + msgs chan []*driver.Message + err chan error +} +// getNextBatch gets the next batch of messages from the server and returns it. +func (s *Subscription) getNextBatch(nMessages int) *batchChannelResult { // Split nMessages into batches based on recvBatchOpts; we'll make a // separate ReceiveBatch call for each batch, and aggregate the results in // msgs. batches := batcher.Split(nMessages, s.recvBatchOpts) - + result := batchChannelResult{ + msgs: make(chan []*driver.Message, len(batches)), + err: make(chan error), + } g, ctx := errgroup.WithContext(s.backgroundCtx) for _, maxMessagesInBatch := range batches { // Make a copy of the loop variable since it will be used by a goroutine. @@ -663,16 +683,19 @@ func (s *Subscription) getNextBatch(nMessages int) ([]*driver.Message, error) { if err != nil { return wrapError(s.driver, err) } - mu.Lock() - defer mu.Unlock() - q = append(q, msgs...) + result.msgs <- msgs return nil }) } - if err := g.Wait(); err != nil { - return nil, err - } - return q, nil + go func() { + // wait on group completion on the background and proper channel closing + if err := g.Wait(); err != nil { + result.err <- err + } + close(result.err) + close(result.msgs) + }() + return &result } var errSubscriptionShutdown = gcerr.Newf(gcerr.FailedPrecondition, nil, "pubsub: Subscription has been Shutdown")