diff --git a/jetstream/errors.go b/jetstream/errors.go index 77ddf3d38..7125c120f 100644 --- a/jetstream/errors.go +++ b/jetstream/errors.go @@ -109,6 +109,9 @@ var ( // ErrInvalidStreamName is returned when the provided stream name is invalid (contains '.'). ErrInvalidStreamName JetStreamError = &jsError{message: "invalid stream name"} + // ErrInvalidSubject is returned when the provided subject name is invalid. + ErrInvalidSubject JetStreamError = &jsError{message: "invalid subject name"} + // ErrInvalidConsumerName is returned when the provided consumer name is invalid (contains '.'). ErrInvalidConsumerName JetStreamError = &jsError{message: "invalid consumer name"} diff --git a/jetstream/jetstream.go b/jetstream/jetstream.go index ed68d25c6..563699ac2 100644 --- a/jetstream/jetstream.go +++ b/jetstream/jetstream.go @@ -18,6 +18,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strings" "github.com/nats-io/nats.go" @@ -68,6 +69,8 @@ type ( UpdateStream(context.Context, StreamConfig) (Stream, error) // Stream returns a [Stream] hook for a given stream name Stream(context.Context, string) (Stream, error) + // StreamNameBySubject returns a stream name stream listening on given subject + StreamNameBySubject(context.Context, string) (string, error) // DeleteStream removes a stream with given name DeleteStream(context.Context, string) error // ListStreams returns StreamInfoLister enabling iterating over a channel of stream infos @@ -186,8 +189,14 @@ type ( apiPaged Streams []string `json:"streams"` } + + streamsRequest struct { + Subject string `json:"subject,omitempty"` + } ) +var subjectRegexp = regexp.MustCompile(`^[^ >]*[>]?$`) + // New returns a enw JetStream instance // // Available options: @@ -494,6 +503,16 @@ func validateStreamName(stream string) error { return nil } +func validateSubject(subject string) error { + if subject == "" { + return fmt.Errorf("%w: %s", ErrInvalidSubject, "subject cannot be empty") + } + if !subjectRegexp.MatchString(subject) { + return fmt.Errorf("%w: %s", ErrInvalidSubject, subject) + } + return nil +} + func (js *jetStream) AccountInfo(ctx context.Context) (*AccountInfo, error) { var resp accountInfoResponse @@ -591,6 +610,32 @@ func (js *jetStream) StreamNames(ctx context.Context) StreamNameLister { return l } +func (js *jetStream) StreamNameBySubject(ctx context.Context, subject string) (string, error) { + if err := validateSubject(subject); err != nil { + return "", err + } + streamsSubject := apiSubj(js.apiPrefix, apiStreams) + + r := &streamsRequest{Subject: subject} + req, err := json.Marshal(r) + if err != nil { + return "", err + } + var resp streamNamesResponse + _, err = js.apiRequestJSON(ctx, streamsSubject, &resp, req) + if err != nil { + return "", err + } + if resp.Error != nil { + return "", resp.Error + } + if len(resp.Streams) == 0 { + return "", ErrStreamNotFound + } + + return resp.Streams[0], nil +} + // Name returns a channel allowing retrieval of stream names returned by [StreamNames] func (s *streamLister) Name() <-chan string { return s.names diff --git a/jetstream/message_test.go b/jetstream/jetstream_test.go similarity index 75% rename from jetstream/message_test.go rename to jetstream/jetstream_test.go index 653ee07b1..e7f2317db 100644 --- a/jetstream/message_test.go +++ b/jetstream/jetstream_test.go @@ -15,6 +15,7 @@ package jetstream import ( "errors" + "fmt" "testing" "time" @@ -80,3 +81,38 @@ func TestMessageMetadata(t *testing.T) { }) } } + +func TestValidateSubject(t *testing.T) { + tests := []struct { + subject string + withError bool + }{ + {"test.A", false}, + {"test.*", false}, + {"*", false}, + {"*.*", false}, + {"test.*.A", false}, + {"test.>", false}, + {">", false}, + {">.", true}, + {"test.>.A", true}, + {"", true}, + {"test A", true}, + } + + for _, test := range tests { + tName := fmt.Sprintf("subj=%s,err=%t", test.subject, test.withError) + t.Run(tName, func(t *testing.T) { + err := validateSubject(test.subject) + if test.withError { + if err == nil { + t.Fatal("Expected error; got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + }) + } +} diff --git a/jetstream/test/jetstream_test.go b/jetstream/test/jetstream_test.go index a9323eca8..3416c1410 100644 --- a/jetstream/test/jetstream_test.go +++ b/jetstream/test/jetstream_test.go @@ -928,3 +928,81 @@ func TestJetStream_DeleteConsumer(t *testing.T) { }) } } + +func TestStreamNameBySubject(t *testing.T) { + tests := []struct { + name string + subject string + withError error + expected string + }{ + { + name: "get stream name by subject explicit", + subject: "FOO.123", + expected: "foo", + }, + { + name: "get stream name by subject with wildcard", + subject: "BAR.*", + expected: "bar", + }, + { + name: "match more than one stream, return the first one", + subject: ">", + expected: "", + }, + { + name: "stream not found", + subject: "BAR.XYZ", + withError: jetstream.ErrStreamNotFound, + }, + { + name: "invalid subject", + subject: "FOO.>.123", + withError: jetstream.ErrInvalidSubject, + }, + } + + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err = js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + _, err = js.CreateStream(ctx, jetstream.StreamConfig{Name: "bar", Subjects: []string{"BAR.ABC"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + name, err := js.StreamNameBySubject(ctx, test.subject) + if test.withError != nil { + if err == nil || !errors.Is(err, test.withError) { + t.Fatalf("Expected error: %v; got: %v", test.withError, err) + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if test.expected != "" && name != test.expected { + t.Fatalf("Unexpected stream name; want: %s; got: %s", test.expected, name) + } + + }) + } +} diff --git a/micro/service.go b/micro/service.go index 445c81fc4..6c01570d8 100644 --- a/micro/service.go +++ b/micro/service.go @@ -264,7 +264,7 @@ var ( // this regular expression is suggested regexp for semver validation: https://semver.org/ semVerRegexp = regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`) nameRegexp = regexp.MustCompile(`^[A-Za-z0-9\-_]+$`) - subjectRegexp = regexp.MustCompile(`^[^ >]+[>]?$`) + subjectRegexp = regexp.MustCompile(`^[^ >]*[>]?$`) ) // Common errors returned by the Service framework.