From 0186bc343214c67dd13afc0ccbeda44f557724a7 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Fri, 2 Aug 2024 10:49:36 +0200 Subject: [PATCH 1/2] [ADDED] Closed method to wait for Consume to complete Signed-off-by: Piotr Piotrowski --- jetstream/ordered.go | 57 +++++++++---- jetstream/pull.go | 30 +++++++ jetstream/test/ordered_test.go | 129 ++++++++++++++++++++++++++++ jetstream/test/pull_test.go | 148 +++++++++++++++++++++++++++++++++ 4 files changed, 349 insertions(+), 15 deletions(-) diff --git a/jetstream/ordered.go b/jetstream/ordered.go index 85b7ea9e9..c15434284 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -32,13 +32,13 @@ type ( cfg *OrderedConsumerConfig stream string currentConsumer *pullConsumer - currentSub ConsumeContext + currentSub *pullSubscription cursor cursor namePrefix string serial int consumerType consumerType doReset chan struct{} - resetInProgress uint32 + resetInProgress atomic.Uint32 userErrHandler ConsumeErrHandlerFunc stopAfter int stopAfterMsgsLeft chan int @@ -52,7 +52,7 @@ type ( consumer *orderedConsumer opts []PullMessagesOpt done chan struct{} - closed uint32 + closed atomic.Uint32 } cursor struct { @@ -138,7 +138,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt if err != nil { return nil, err } - c.currentSub = cc + c.currentSub = cc.(*pullSubscription) go func() { for { @@ -175,7 +175,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt c.errHandler(c.serial)(cc, err) } else { c.Lock() - c.currentSub = cc + c.currentSub = cc.(*pullSubscription) c.Unlock() } case <-sub.done: @@ -210,8 +210,8 @@ func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err err errors.Is(err, ErrConsumerDeleted) || errors.Is(err, errConnected) { // only reset if serial matches the current consumer serial and there is no reset in progress - if serial == c.serial && atomic.LoadUint32(&c.resetInProgress) == 0 { - atomic.StoreUint32(&c.resetInProgress, 1) + if serial == c.serial && c.resetInProgress.Load() == 0 { + c.resetInProgress.Store(1) c.doReset <- struct{}{} } } @@ -256,7 +256,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er if err != nil { return nil, err } - c.currentSub = cc + c.currentSub = cc.(*pullSubscription) sub := &orderedSubscription{ consumer: c, @@ -270,7 +270,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er func (s *orderedSubscription) Next() (Msg, error) { for { - msg, err := s.consumer.currentSub.(*pullSubscription).Next() + msg, err := s.consumer.currentSub.Next() if err != nil { if errors.Is(err, ErrMsgIteratorClosed) { s.Stop() @@ -297,7 +297,7 @@ func (s *orderedSubscription) Next() (Msg, error) { if err != nil { return nil, err } - s.consumer.currentSub = cc + s.consumer.currentSub = cc.(*pullSubscription) continue } @@ -321,7 +321,7 @@ func (s *orderedSubscription) Next() (Msg, error) { if err != nil { return nil, err } - s.consumer.currentSub = cc + s.consumer.currentSub = cc.(*pullSubscription) continue } s.consumer.cursor.deliverSeq = dseq @@ -331,7 +331,7 @@ func (s *orderedSubscription) Next() (Msg, error) { } func (s *orderedSubscription) Stop() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } s.consumer.Lock() @@ -343,7 +343,7 @@ func (s *orderedSubscription) Stop() { } func (s *orderedSubscription) Drain() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } if s.consumer.currentSub != nil { @@ -354,6 +354,33 @@ func (s *orderedSubscription) Drain() { close(s.done) } +// Closed returns a channel that is closed when the consuming is +// fully stopped/drained. When the channel is closed, no more messages +// will be received and processing is complete. +func (s *orderedSubscription) Closed() <-chan struct{} { + s.consumer.Lock() + defer s.consumer.Unlock() + closedCh := make(chan struct{}) + + go func() { + for { + s.consumer.Lock() + if s.consumer.currentSub == nil { + return + } + closed := s.consumer.currentSub.Closed() + s.consumer.Unlock() + + <-closed + if s.closed.Load() == 1 { + close(closedCh) + return + } + } + }() + return closedCh +} + // Fetch is used to retrieve up to a provided number of messages from a // stream. This method will always send a single request and wait until // either all messages are retrieved or request times out. @@ -495,7 +522,7 @@ func serialNumberFromConsumer(name string) int { func (c *orderedConsumer) reset() error { c.Lock() defer c.Unlock() - defer atomic.StoreUint32(&c.resetInProgress, 0) + defer c.resetInProgress.Store(0) if c.currentConsumer != nil { c.currentConsumer.Lock() if c.currentSub != nil { @@ -524,7 +551,7 @@ func (c *orderedConsumer) reset() error { cancel: c.subscription.done, } err = retryWithBackoff(func(attempt int) (bool, error) { - isClosed := atomic.LoadUint32(&c.subscription.closed) == 1 + isClosed := c.subscription.closed.Load() == 1 if isClosed { return false, errOrderedConsumerClosed } diff --git a/jetstream/pull.go b/jetstream/pull.go index a510c7c8c..540282877 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -59,6 +59,11 @@ type ( // Drain unsubscribes from the stream and cancels subscription. // All messages that are already in the buffer will be processed in callback function. Drain() + + // Closed returns a channel that is closed when the consuming is + // fully stopped/drained. When the channel is closed, no more messages + // will be received and processing is complete. + Closed() <-chan struct{} } // MessageHandler is a handler function used as callback in [Consume]. @@ -125,6 +130,7 @@ type ( fetchNext chan *pullRequest consumeOpts *consumeOpts delivered int + closedCh chan struct{} } pendingMsgs struct { @@ -257,6 +263,12 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( return func(subject string) { p.subs.Delete(sid) sub.draining.CompareAndSwap(1, 0) + sub.Lock() + if sub.closedCh != nil { + close(sub.closedCh) + sub.closedCh = nil + } + sub.Unlock() } }(sub.id)) @@ -649,6 +661,24 @@ func (s *pullSubscription) Drain() { } } +// Closed returns a channel that is closed when consuming is +// fully stopped/drained. When the channel is closed, no more messages +// will be received and processing is complete. +func (s *pullSubscription) Closed() <-chan struct{} { + s.Lock() + defer s.Unlock() + closedCh := s.closedCh + if closedCh == nil { + closedCh = make(chan struct{}) + s.closedCh = closedCh + } + if !s.subscription.IsValid() { + close(s.closedCh) + s.closedCh = nil + } + return closedCh +} + // Fetch sends a single request to retrieve given number of messages. // It will wait up to provided expiry time if not all messages are available. func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index 522c92196..1558b27c0 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -578,6 +578,135 @@ func TestOrderedConsumerConsume(t *testing.T) { time.Sleep(50 * time.Millisecond) } }) + + t.Run("wait for closed after drain", func(t *testing.T) { + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + closed := cc.Closed() + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, js) + + // wait for the consumer to be recreated before calling drain + for i := 0; i < 5; i++ { + _, err = c.Info(ctx) + if err != nil { + if errors.Is(err, jetstream.ErrConsumerNotFound) { + time.Sleep(100 * time.Millisecond) + continue + } + t.Fatalf("Unexpected error: %v", err) + } + break + } + + cc.Drain() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs)) + } + }) + } + }) + + t.Run("wait for closed on already closed consume", func(t *testing.T) { + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + cc.Stop() + + time.Sleep(100 * time.Millisecond) + + select { + case <-cc.Closed(): + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + }) + } + }) } func TestOrderedConsumerMessages(t *testing.T) { diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index f35aae315..4042e52f5 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -2584,6 +2584,154 @@ func TestPullConsumerConsume(t *testing.T) { cc.Drain() wg.Wait() }) + + t.Run("wait for closed after drain", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + closed := cc.Closed() + time.Sleep(100 * time.Millisecond) + + cc.Drain() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", len(testMsgs), len(msgs)) + } + }) + + t.Run("wait for closed after stop", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + closed := cc.Closed() + + cc.Stop() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) < 1 || len(msgs) > 3 { + t.Fatalf("Unexpected received message count after consume closed; want 1-3; got %d", len(msgs)) + } + }) + + t.Run("wait for closed on already closed consume", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + + cc.Stop() + + time.Sleep(100 * time.Millisecond) + + select { + case <-cc.Closed(): + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + }) } func TestPullConsumerConsume_WithCluster(t *testing.T) { From 19996e3ae0347070f30409c93895604a25770e5c Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 12 Aug 2024 15:03:51 +0200 Subject: [PATCH 2/2] Add a comment Signed-off-by: Piotr Piotrowski --- jetstream/ordered.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jetstream/ordered.go b/jetstream/ordered.go index c15434284..1f78568fc 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -368,10 +368,14 @@ func (s *orderedSubscription) Closed() <-chan struct{} { if s.consumer.currentSub == nil { return } + closed := s.consumer.currentSub.Closed() s.consumer.Unlock() + // wait until the underlying pull consumer is closed <-closed + // if the subscription is closed and ordered consumer is closed as well, + // send a signal that the Consume() is fully stopped if s.closed.Load() == 1 { close(closedCh) return