From 92a21de56d88c9c693752b06246d5b0b1f7f5227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Fri, 12 Aug 2022 16:39:57 +0200 Subject: [PATCH] fix(hub): ensure that an update is dispatched if any of its topics is subscribed and allowed --- bolt_transport_test.go | 6 +- local_transport_test.go | 4 +- publish_test.go | 6 +- subscriber.go | 118 +++++++++++++++++++++------------------- subscriber_list.go | 27 ++++----- subscriber_test.go | 25 +++++++-- subscription_test.go | 2 +- 7 files changed, 101 insertions(+), 87 deletions(-) diff --git a/bolt_transport_test.go b/bolt_transport_test.go index 8f28c88b..8ca2a3f6 100644 --- a/bolt_transport_test.go +++ b/bolt_transport_test.go @@ -231,12 +231,12 @@ func TestBoltTransportDispatch(t *testing.T) { subscribedNotAuthorized := &Update{Topics: []string{"https://example.com/foo"}, Private: true} require.Nil(t, transport.Dispatch(subscribedNotAuthorized)) - public := &Update{Topics: s.Topics} + public := &Update{Topics: s.SubscribedTopics} require.Nil(t, transport.Dispatch(public)) assert.Equal(t, public, <-s.Receive()) - private := &Update{Topics: s.PrivateTopics, Private: true} + private := &Update{Topics: s.AllowedPrivateTopics, Private: true} require.Nil(t, transport.Dispatch(private)) assert.Equal(t, private, <-s.Receive()) @@ -256,7 +256,7 @@ func TestBoltTransportClosed(t *testing.T) { require.Nil(t, transport.Close()) require.NotNil(t, transport.AddSubscriber(s)) - assert.Equal(t, transport.Dispatch(&Update{Topics: s.Topics}), ErrClosedTransport) + assert.Equal(t, transport.Dispatch(&Update{Topics: s.SubscribedTopics}), ErrClosedTransport) _, ok := <-s.out assert.False(t, ok) diff --git a/local_transport_test.go b/local_transport_test.go index e9ccadab..11bc8dca 100644 --- a/local_transport_test.go +++ b/local_transport_test.go @@ -45,7 +45,7 @@ func TestLocalTransportDispatch(t *testing.T) { s.SetTopics([]string{"http://example.com/foo"}, nil) assert.Nil(t, transport.AddSubscriber(s)) - u := &Update{Topics: s.Topics} + u := &Update{Topics: s.SubscribedTopics} require.Nil(t, transport.Dispatch(u)) assert.Equal(t, u, <-s.Receive()) } @@ -97,7 +97,7 @@ func TestLiveReading(t *testing.T) { s.SetTopics([]string{"https://example.com"}, nil) require.Nil(t, transport.AddSubscriber(s)) - u := &Update{Topics: s.Topics} + u := &Update{Topics: s.SubscribedTopics} assert.Nil(t, transport.Dispatch(u)) receivedUpdate := <-s.Receive() diff --git a/publish_test.go b/publish_test.go index 528fc1a9..45a9082d 100644 --- a/publish_test.go +++ b/publish_test.go @@ -188,7 +188,7 @@ func TestPublishOK(t *testing.T) { assert.True(t, ok) require.NotNil(t, u) assert.Equal(t, "id", u.ID) - assert.Equal(t, s.Topics, u.Topics) + assert.Equal(t, s.SubscribedTopics, u.Topics) assert.Equal(t, "Hello!", u.Data) assert.True(t, u.Private) }(&wg) @@ -201,7 +201,7 @@ func TestPublishOK(t *testing.T) { req := httptest.NewRequest(http.MethodPost, defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, s.Topics)) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, s.SubscribedTopics)) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -239,7 +239,7 @@ func TestPublishGenerateUUID(t *testing.T) { h := createDummy() s := NewSubscriber("", zap.NewNop()) - s.SetTopics([]string{"http://example.com/books/1"}, s.Topics) + s.SetTopics([]string{"http://example.com/books/1"}, s.SubscribedTopics) require.Nil(t, h.transport.AddSubscriber(s)) diff --git a/subscriber.go b/subscriber.go index 0b0a34f1..49d76528 100644 --- a/subscriber.go +++ b/subscriber.go @@ -14,17 +14,17 @@ import ( // Subscriber represents a client subscribed to a list of topics. type Subscriber struct { - ID string - EscapedID string - Claims *claims - EscapedTopics []string - RequestLastEventID string - RemoteAddr string - Topics []string - TopicRegexps []*regexp.Regexp - PrivateTopics []string - PrivateRegexps []*regexp.Regexp - Debug bool + ID string + EscapedID string + Claims *claims + EscapedTopics []string + RequestLastEventID string + RemoteAddr string + SubscribedTopics []string + SubscribedTopicRegexps []*regexp.Regexp + AllowedPrivateTopics []string + AllowedPrivateRegexps []*regexp.Regexp + Debug bool disconnected int32 out chan *Update @@ -121,26 +121,26 @@ func (s *Subscriber) Disconnect() { } // SetTopics compiles topic selector regexps. -func (s *Subscriber) SetTopics(topics, privateTopics []string) { - s.Topics = topics - s.TopicRegexps = make([]*regexp.Regexp, len(topics)) - for i, ts := range topics { +func (s *Subscriber) SetTopics(subscribedTopics, allowedPrivateTopics []string) { + s.SubscribedTopics = subscribedTopics + s.SubscribedTopicRegexps = make([]*regexp.Regexp, len(subscribedTopics)) + for i, ts := range subscribedTopics { var r *regexp.Regexp if tpl, err := uritemplate.New(ts); err == nil { r = tpl.Regexp() } - s.TopicRegexps[i] = r + s.SubscribedTopicRegexps[i] = r } - s.PrivateTopics = privateTopics - s.PrivateRegexps = make([]*regexp.Regexp, len(privateTopics)) - for i, ts := range privateTopics { + s.AllowedPrivateTopics = allowedPrivateTopics + s.AllowedPrivateRegexps = make([]*regexp.Regexp, len(allowedPrivateTopics)) + for i, ts := range allowedPrivateTopics { var r *regexp.Regexp if tpl, err := uritemplate.New(ts); err == nil { r = tpl.Regexp() } - s.PrivateRegexps[i] = r + s.AllowedPrivateRegexps[i] = r } - s.EscapedTopics = escapeTopics(topics) + s.EscapedTopics = escapeTopics(subscribedTopics) } func escapeTopics(topics []string) []string { @@ -153,36 +153,48 @@ func escapeTopics(topics []string) []string { } // MatchTopic checks if the current subscriber can access to the given topic. -func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) { - for i, ts := range s.Topics { - if ts == "*" || ts == topic { - match = true +// +//nolint:gocognit +func (s *Subscriber) MatchTopics(topics []string, private bool) bool { + var subscribed bool + canAccess := !private - break - } + for _, topic := range topics { + if !subscribed { + for i, ts := range s.SubscribedTopics { + if ts == "*" || ts == topic { + subscribed = true - r := s.TopicRegexps[i] - if r != nil && r.MatchString(topic) { - match = true + break + } - break + r := s.SubscribedTopicRegexps[i] + if r != nil && r.MatchString(topic) { + subscribed = true + + break + } + } } - } - if !match { - return false - } - if !private { - return true - } + if !canAccess { + for i, ts := range s.AllowedPrivateTopics { + if ts == "*" || ts == topic { + canAccess = true - for i, ts := range s.PrivateTopics { - if ts == "*" || ts == topic { - return true + break + } + + r := s.AllowedPrivateRegexps[i] + if r != nil && r.MatchString(topic) { + canAccess = true + + break + } + } } - r := s.PrivateRegexps[i] - if r != nil && r.MatchString(topic) { + if subscribed && canAccess { return true } } @@ -192,20 +204,14 @@ func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) { // Match checks if the current subscriber can receive the given update. func (s *Subscriber) Match(u *Update) bool { - for _, t := range u.Topics { - if s.MatchTopic(t, u.Private) { - return true - } - } - - return false + return s.MatchTopics(u.Topics, u.Private) } // getSubscriptions return the list of subscriptions associated to this subscriber. func (s *Subscriber) getSubscriptions(topic, context string, active bool) []subscription { var subscriptions []subscription //nolint:prealloc - for k, t := range s.Topics { - if topic != "" && !s.MatchTopic(topic, false) { + for k, t := range s.SubscribedTopics { + if topic != "" && !s.MatchTopics([]string{topic}, false) { continue } @@ -233,13 +239,13 @@ func (s *Subscriber) MarshalLogObject(enc zapcore.ObjectEncoder) error { if s.RemoteAddr != "" { enc.AddString("remote_addr", s.RemoteAddr) } - if s.PrivateTopics != nil { - if err := enc.AddArray("topic_selectors", stringArray(s.PrivateTopics)); err != nil { + if s.AllowedPrivateTopics != nil { + if err := enc.AddArray("topic_selectors", stringArray(s.AllowedPrivateTopics)); err != nil { return fmt.Errorf("log error: %w", err) } } - if s.Topics != nil { - if err := enc.AddArray("topics", stringArray(s.Topics)); err != nil { + if s.SubscribedTopics != nil { + if err := enc.AddArray("topics", stringArray(s.SubscribedTopics)); err != nil { return fmt.Errorf("log error: %w", err) } } diff --git a/subscriber_list.go b/subscriber_list.go index a24884d8..e9b7238c 100644 --- a/subscriber_list.go +++ b/subscriber_list.go @@ -1,39 +1,32 @@ package mercure import ( - "strings" - "github.com/kevburnsjr/skipfilter" ) +type filter struct { + topics []string + private bool +} + type SubscriberList struct { skipfilter *skipfilter.SkipFilter } func NewSubscriberList(size int) *SubscriberList { return &SubscriberList{ - skipfilter: skipfilter.New(func(s interface{}, topic interface{}) bool { - p := strings.SplitN(topic.(string), "_", 2) - if len(p) < 2 { - return false - } + skipfilter: skipfilter.New(func(s interface{}, fil interface{}) bool { + f := fil.(*filter) - return s.(*Subscriber).MatchTopic(p[1], p[0] == "p") + return s.(*Subscriber).MatchTopics(f.topics, f.private) }, size), } } func (sc *SubscriberList) MatchAny(u *Update) (res []*Subscriber) { - scopedTopics := make([]interface{}, len(u.Topics)) - for i, t := range u.Topics { - if u.Private { - scopedTopics[i] = "p_" + t - } else { - scopedTopics[i] = "_" + t - } - } + f := &filter{u.Topics, u.Private} - for _, m := range sc.skipfilter.MatchAny(scopedTopics...) { + for _, m := range sc.skipfilter.MatchAny(f) { res = append(res, m.(*Subscriber)) } diff --git a/subscriber_test.go b/subscriber_test.go index 823bbc69..fbe5e4c4 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -10,15 +10,16 @@ import ( func TestDispatch(t *testing.T) { s := NewSubscriber("1", zap.NewNop()) - s.Topics = []string{"http://example.com"} + s.SubscribedTopics = []string{"http://example.com"} + s.SubscribedTopics = []string{"http://example.com"} defer s.Disconnect() // Dispatch must be non-blocking // Messages coming from the history can be sent after live messages, but must be received first - s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "3"}}, false) - s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "1"}}, true) - s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "4"}}, false) - s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "2"}}, true) + s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "3"}}, false) + s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "1"}}, true) + s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "4"}}, false) + s.Dispatch(&Update{Topics: s.SubscribedTopics, Event: Event{ID: "2"}}, true) s.HistoryDispatched("") s.Ready() @@ -56,3 +57,17 @@ func TestLogSubscriber(t *testing.T) { assert.Contains(t, log, `"topic_selectors":["https://example.com/foo"]`) assert.Contains(t, log, `"topics":["https://example.com/bar"]`) } + +func TestMatchTopic(t *testing.T) { + s := NewSubscriber("", zap.NewNop()) + s.SetTopics([]string{"https://example.com/no-match", "https://example.com/books/{id}"}, []string{"https://example.com/users/foo/{?topic}"}) + + assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}})) + assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}, Private: true})) + assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/no-match"}, Private: true})) + assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/books/1"}, Private: true})) + assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/books/1", "https://example.com/users/bar/?topic=https%3A%2F%2Fexample.com%2Fbooks%2F1"}, Private: true})) + + assert.True(t, s.Match(&Update{Topics: []string{"https://example.com/books/1"}})) + assert.True(t, s.Match(&Update{Topics: []string{"https://example.com/books/1", "https://example.com/users/foo/?topic=https%3A%2F%2Fexample.com%2Fbooks%2F1"}, Private: true})) +} diff --git a/subscription_test.go b/subscription_test.go index 2747fc49..65d5a311 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -191,7 +191,7 @@ func TestSubscriptionHandler(t *testing.T) { var subscription subscription json.Unmarshal(w.Body.Bytes(), &subscription) - expectedSub := s.getSubscriptions(s.Topics[1], "https://mercure.rocks/", true)[1] + expectedSub := s.getSubscriptions(s.SubscribedTopics[1], "https://mercure.rocks/", true)[1] expectedSub.LastEventID, _, _ = hub.transport.(TransportSubscribers).GetSubscribers() assert.Equal(t, expectedSub, subscription)