From 1b4e7588b297b230db77917361b2fb6c23372adf Mon Sep 17 00:00:00 2001 From: Tomasz Pietrek Date: Fri, 1 Mar 2024 14:26:07 +0100 Subject: [PATCH] Add consumer Owner ID Signed-off-by: Tomasz Pietrek --- server/consumer.go | 25 ++++--- server/jetstream_api.go | 1 + server/jetstream_consumer_test.go | 120 ++++++++++++++++++++++++++++++ server/jetstream_test.go | 2 +- 4 files changed, 138 insertions(+), 10 deletions(-) diff --git a/server/consumer.go b/server/consumer.go index d33a7bffde8..f1b6baf9504 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -109,6 +109,7 @@ type ConsumerConfig struct { // PauseUntil is for suspending the consumer until the deadline. PauseUntil *time.Time `json:"pause_until,omitempty"` + OwnerID string `json:"owner_id,omitempty"` } // SequenceInfo has both the consumer and the stream sequence and last activity. @@ -2976,36 +2977,36 @@ func (o *consumer) needAck(sseq uint64, subj string) bool { } // Helper for the next message requests. -func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, error) { +func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, string, error) { req := bytes.TrimSpace(msg) switch { case len(req) == 0: - return time.Time{}, 1, 0, false, 0, time.Time{}, nil + return time.Time{}, 1, 0, false, 0, time.Time{}, "", nil case req[0] == '{': var cr JSApiConsumerGetNextRequest if err := json.Unmarshal(req, &cr); err != nil { - return time.Time{}, -1, 0, false, 0, time.Time{}, err + return time.Time{}, -1, 0, false, 0, time.Time{}, "", err } var hbt time.Time if cr.Heartbeat > 0 { if cr.Heartbeat*2 > cr.Expires { - return time.Time{}, 1, 0, false, 0, time.Time{}, errors.New("heartbeat value too large") + return time.Time{}, 1, 0, false, 0, time.Time{}, "", errors.New("heartbeat value too large") } hbt = time.Now().Add(cr.Heartbeat) } if cr.Expires == time.Duration(0) { - return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil + return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, cr.OwnerID, nil } - return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil + return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, cr.OwnerID, nil default: if n, err := strconv.Atoi(string(req)); err == nil { - return time.Time{}, n, 0, false, 0, time.Time{}, nil + return time.Time{}, n, 0, false, 0, time.Time{}, "", nil } } - return time.Time{}, 1, 0, false, 0, time.Time{}, nil + return time.Time{}, 1, 0, false, 0, time.Time{}, "", nil } // Represents a request that is on the internal waiting queue @@ -3321,12 +3322,18 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) { } // Check payload here to see if they sent in batch size or a formal request. - expires, batchSize, maxBytes, noWait, hb, hbt, err := nextReqFromMsg(msg) + expires, batchSize, maxBytes, noWait, hb, hbt, ownerID, err := nextReqFromMsg(msg) if err != nil { sendErr(400, fmt.Sprintf("Bad Request - %v", err)) return } + // Check the owner for exclusive consumer. + if o.cfg.OwnerID != _EMPTY_ && ownerID != o.cfg.OwnerID { + sendErr(412, "Consumer is owned by another client") + return + } + // Check for request limits if o.cfg.MaxRequestBatch > 0 && batchSize > o.cfg.MaxRequestBatch { sendErr(409, fmt.Sprintf("Exceeded MaxRequestBatch of %d", o.cfg.MaxRequestBatch)) diff --git a/server/jetstream_api.go b/server/jetstream_api.go index af27eb3ebd4..2d0bd4168e0 100644 --- a/server/jetstream_api.go +++ b/server/jetstream_api.go @@ -730,6 +730,7 @@ type JSApiConsumerGetNextRequest struct { MaxBytes int `json:"max_bytes,omitempty"` NoWait bool `json:"no_wait,omitempty"` Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` + OwnerID string `json:"owner_id,omitempty"` } // JSApiStreamTemplateCreateResponse for creating templates. diff --git a/server/jetstream_consumer_test.go b/server/jetstream_consumer_test.go index ec8fdf12cff..17319e0a0a7 100644 --- a/server/jetstream_consumer_test.go +++ b/server/jetstream_consumer_test.go @@ -31,6 +31,126 @@ import ( "github.com/nats-io/nuid" ) +func TestJetStreamConsumerExclusive(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + acc := s.GlobalAccount() + + mset, err := acc.addStream(&StreamConfig{ + Name: "TEST", + Retention: LimitsPolicy, + Subjects: []string{"events.>"}, + MaxAge: time.Second * 90, + }) + require_NoError(t, err) + + _, err = mset.addConsumer(&ConsumerConfig{ + Durable: "consumer", + AckPolicy: AckExplicit, + DeliverPolicy: DeliverAll, + FilterSubject: "events.>", + OwnerID: "me", + }) + require_NoError(t, err) + + for i := 0; i < 10; i++ { + _, err = js.Publish("events.1", []byte("hello")) + require_NoError(t, err) + } + + // set ID that is not owned by us. + cr := JSApiConsumerGetNextRequest{ + Batch: 1, + OwnerID: "notMe", + } + crBytes, err := json.Marshal(cr) + require_NoError(t, err) + + inbox := nats.NewInbox() + err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes) + require_NoError(t, err) + + consumerSub, err := nc.SubscribeSync(inbox) + require_NoError(t, err) + + msg, err := consumerSub.NextMsg(time.Second) + require_NoError(t, err) + + // check if message header contains error "Consumer is owned by another client" + if !strings.Contains(string(msg.Header.Get("Status")), "412") { + t.Fatalf("Expected exclusive consumer error, got %q", msg.Header.Get("Description")) + } + + // now set our ID + cr = JSApiConsumerGetNextRequest{ + Batch: 2, + OwnerID: "me", + } + crBytes, err = json.Marshal(cr) + require_NoError(t, err) + + err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes) + require_NoError(t, err) + + msg, err = consumerSub.NextMsg(time.Second) + require_NoError(t, err) + require_Equal(t, string(msg.Data), "hello") + + // update the consumer to different ID + _, err = mset.addConsumer(&ConsumerConfig{ + Durable: "consumer", + AckPolicy: AckExplicit, + DeliverPolicy: DeliverAll, + FilterSubject: "events.>", + OwnerID: "differentMe", + }) + require_NoError(t, err) + + // we should still get messages from the pending pull requests + msg, err = consumerSub.NextMsg(time.Second) + require_NoError(t, err) + require_Equal(t, string(msg.Data), "hello") + + // check if the previous ID works. It should not + cr = JSApiConsumerGetNextRequest{ + Batch: 1, + OwnerID: "me", + } + crBytes, err = json.Marshal(cr) + require_NoError(t, err) + + err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes) + require_NoError(t, err) + + msg, err = consumerSub.NextMsg(time.Second) + require_NoError(t, err) + + // we should now get an error + if !strings.Contains(string(msg.Header.Get("Status")), "412") { + t.Fatalf("Expected exclusive consumer error, got %q", msg.Header.Get("Description")) + } + + // and this should work now + + cr = JSApiConsumerGetNextRequest{ + Batch: 1, + OwnerID: "differentMe", + } + crBytes, err = json.Marshal(cr) + require_NoError(t, err) + + err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes) + require_NoError(t, err) + + msg, err = consumerSub.NextMsg(time.Second) + require_NoError(t, err) + require_Equal(t, string(msg.Data), "hello") + +} + func TestJetStreamConsumerMultipleFiltersRemoveFilters(t *testing.T) { s := RunBasicJetStreamServer(t) diff --git a/server/jetstream_test.go b/server/jetstream_test.go index 8e445555d91..2084d013307 100644 --- a/server/jetstream_test.go +++ b/server/jetstream_test.go @@ -639,7 +639,7 @@ func TestJetStreamConsumerMaxDeliveries(t *testing.T) { func TestJetStreamNextReqFromMsg(t *testing.T) { bef := time.Now() - expires, _, _, _, _, _, err := nextReqFromMsg([]byte(`{"expires":5000000000}`)) // nanoseconds + expires, _, _, _, _, _, _, err := nextReqFromMsg([]byte(`{"expires":5000000000}`)) // nanoseconds require_NoError(t, err) now := time.Now() if expires.Before(bef.Add(5*time.Second)) || expires.After(now.Add(5*time.Second)) {