diff --git a/common/membership/hashring.go b/common/membership/hashring.go
index 1b173b75d28..7c1b6b54759 100644
--- a/common/membership/hashring.go
+++ b/common/membership/hashring.go
@@ -90,7 +90,7 @@ func newHashring(
 	logger log.Logger,
 	scope metrics.Scope,
 ) *ring {
-	ring := &ring{
+	r := &ring{
 		status:       common.DaemonStatusInitialized,
 		service:      service,
 		peerProvider: provider,
@@ -100,11 +100,11 @@ func newHashring(
 		scope:        scope,
 	}
 
-	ring.members.keys = make(map[string]HostInfo)
-	ring.subscribers.keys = make(map[string]chan<- *ChangedEvent)
+	r.members.keys = make(map[string]HostInfo)
+	r.subscribers.keys = make(map[string]chan<- *ChangedEvent)
 
-	ring.value.Store(emptyHashring())
-	return ring
+	r.value.Store(emptyHashring())
+	return r
 }
 
 func emptyHashring() *hashring.HashRing {
@@ -124,7 +124,7 @@ func (r *ring) Start() {
 		r.logger.Fatal("subscribing to peer provider", tag.Error(err))
 	}
 
-	if err := r.refresh(); err != nil {
+	if _, err := r.refresh(); err != nil {
 		r.logger.Fatal("failed to start service resolver", tag.Error(err))
 	}
 
@@ -176,24 +176,32 @@ func (r *ring) Lookup(
 	return host, nil
 }
 
-// Subscribe registers callback watcher.
-// All subscribers are notified about ring change.
-func (r *ring) Subscribe(
-	service string,
-	notifyChannel chan<- *ChangedEvent,
-) error {
+// Subscribe registers callback watcher. Services can use this to be informed about membership changes.
+func (r *ring) Subscribe(watcher string, notifyChannel chan<- *ChangedEvent) error {
 	r.subscribers.Lock()
 	defer r.subscribers.Unlock()
-	_, ok := r.subscribers.keys[service]
+	_, ok := r.subscribers.keys[watcher]
 	if ok {
-		return fmt.Errorf("service %q already subscribed", service)
+		return fmt.Errorf("watcher %q is already subscribed", watcher)
 	}
 
-	r.subscribers.keys[service] = notifyChannel
+	r.subscribers.keys[watcher] = notifyChannel
 	return nil
 }
 
+func (r *ring) notifySubscribers(msg *ChangedEvent) {
+	r.subscribers.Lock()
+	defer r.subscribers.Unlock()
+	for name, ch := range r.subscribers.keys {
+		select {
+		case ch <- msg:
+		default:
+			r.logger.Error("subscriber notification failed", tag.Name(name))
+		}
+	}
+}
+
 // Unsubscribe removes subscriber
 func (r *ring) Unsubscribe(
 	name string,
 ) error {
@@ -227,22 +235,22 @@ func (r *ring) Members() []HostInfo {
 	return hosts
 }
 
-func (r *ring) refresh() error {
+func (r *ring) refresh() (refreshed bool, err error) {
 	if r.members.refreshed.After(time.Now().Add(-minRefreshInternal)) {
 		// refreshed too frequently
-		return nil
+		return false, nil
 	}
 
 	members, err := r.peerProvider.GetMembers(r.service)
 	if err != nil {
-		return fmt.Errorf("getting members from peer provider: %w", err)
+		return false, fmt.Errorf("getting members from peer provider: %w", err)
 	}
 
 	r.members.Lock()
 	defer r.members.Unlock()
 	newMembersMap, changed := r.compareMembers(members)
 	if !changed {
-		return nil
+		return false, nil
 	}
 
 	ring := emptyHashring()
@@ -253,7 +261,7 @@
 	r.value.Store(ring)
 	r.logger.Info("refreshed ring members", tag.Value(members))
 
-	return nil
+	return true, nil
 }
 
 func (r *ring) refreshRingWorker() {
@@ -265,13 +273,17 @@ func (r *ring) refreshRingWorker() {
 		select {
 		case <-r.shutdownCh:
 			return
-		case <-r.refreshChan: // local signal or signal from provider
-			if err := r.refresh(); err != nil {
+		case event := <-r.refreshChan: // local signal or signal from provider
+			refreshed, err := r.refresh()
+			if err != nil {
 				r.logger.Error("refreshing ring", tag.Error(err))
 			}
+			if refreshed {
+				r.notifySubscribers(event)
+			}
 		case <-refreshTicker.C: // periodically refresh membership
 			r.emitHashIdentifier()
-			if err := r.refresh(); err != nil {
+			if _, err := r.refresh(); err != nil {
 				r.logger.Error("periodically refreshing ring", tag.Error(err))
 			}
 		}
diff --git a/common/membership/hashring_test.go b/common/membership/hashring_test.go
index 01be0b2e833..77be7c0a436 100644
--- a/common/membership/hashring_test.go
+++ b/common/membership/hashring_test.go
@@ -48,9 +48,9 @@ func randSeq(n int) string {
 }
 
 func randomHostInfo(n int) []HostInfo {
-	res := make([]HostInfo, n)
+	res := make([]HostInfo, 0, n)
 	for i := 0; i < n; i++ {
-		res = append(res, NewHostInfo(randSeq(5)))
+		res = append(res, NewDetailedHostInfo(randSeq(5), randSeq(12), PortMap{randSeq(3): 666}))
 	}
 	return res
 }
@@ -116,27 +116,73 @@ func TestRefreshUpdatesRingOnlyWhenRingHasChanged(t *testing.T) {
 	ctrl := gomock.NewController(t)
 
 	pp := NewMockPeerProvider(ctrl)
 	pp.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1)
-	pp.EXPECT().GetMembers("test-service").Times(3)
+	pp.EXPECT().GetMembers("test-service").Times(1).Return(randomHostInfo(3), nil)
 
 	hr := newHashring("test-service", pp, log.NewNoop(), metrics.NoopScope(0))
+	// Start will also call .refresh()
 	hr.Start()
-
-	hr.refresh()
 	updatedAt := hr.members.refreshed
 	hr.refresh()
+	refreshed, err := hr.refresh()
+
+	assert.NoError(t, err)
+	assert.False(t, refreshed)
 	assert.Equal(t, updatedAt, hr.members.refreshed)
 }
 
+func TestRefreshWillNotifySubscribers(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	pp := NewMockPeerProvider(ctrl)
+	var hostsToReturn []HostInfo
+	pp.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1)
+	pp.EXPECT().GetMembers("test-service").Times(2).DoAndReturn(func(service string) ([]HostInfo, error) {
+		hostsToReturn = randomHostInfo(5)
+		time.Sleep(time.Millisecond * 70)
+		return hostsToReturn, nil
+	})
+
+	changed := &ChangedEvent{
+		HostsAdded:   []string{"a"},
+		HostsUpdated: []string{"b"},
+		HostsRemoved: []string{"c"},
+	}
+
+	hr := newHashring("test-service", pp, log.NewNoop(), metrics.NoopScope(0))
+	hr.Start()
+
+	var changeCh = make(chan *ChangedEvent, 2)
+	// Check if multiple subscribers will get notified
+	assert.NoError(t, hr.Subscribe("subscriber1", changeCh))
+	assert.NoError(t, hr.Subscribe("subscriber2", changeCh))
+
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		changedEvent := <-changeCh
+		changedEvent2 := <-changeCh
+		assert.Equal(t, changed, changedEvent)
+		assert.Equal(t, changed, changedEvent2)
+	}()
+
+	// to bypass internal check
+	hr.members.refreshed = time.Now().AddDate(0, 0, -1)
+	hr.refreshChan <- changed
+	wg.Wait() // wait until both subscribers will get notification
+	// Test if internal members are updated
+	assert.ElementsMatch(t, hr.Members(), hostsToReturn, "members should contain just-added nodes")
+}
+
 func TestSubscribeIgnoresDuplicates(t *testing.T) {
 	var changeCh = make(chan *ChangedEvent)
 	ctrl := gomock.NewController(t)
 	pp := NewMockPeerProvider(ctrl)
-	hr := newHashring("test-service", pp, log.NewNoop(), metrics.NoopScope(0))
+	hr := newHashring("test-watcher", pp, log.NewNoop(), metrics.NoopScope(0))
 
-	assert.NoError(t, hr.Subscribe("test-service", changeCh))
-	assert.Error(t, hr.Subscribe("test-service", changeCh))
+	assert.NoError(t, hr.Subscribe("test-watcher", changeCh))
+	assert.Error(t, hr.Subscribe("test-watcher", changeCh))
 
 	assert.Equal(t, 1, len(hr.subscribers.keys))
 }
 
@@ -190,7 +236,8 @@ func TestErrorIsPropagatedWhenProviderFails(t *testing.T) {
 	pp.EXPECT().GetMembers(gomock.Any()).Return(nil, errors.New("error"))
 
 	hr := newHashring("test-service", pp, log.NewNoop(), metrics.NoopScope(0))
-	assert.Error(t, hr.refresh())
+	_, err := hr.refresh()
+	assert.Error(t, err)
 }
 
 func TestStopWillStopProvider(t *testing.T) {
@@ -227,7 +274,8 @@
 		for i := 0; i < 50; i++ {
 			// to bypass internal check
 			hr.members.refreshed = time.Now().AddDate(0, 0, -1)
-			assert.NoError(t, hr.refresh())
+			_, err := hr.refresh()
+			assert.NoError(t, err)
 		}
 		wg.Done()
 	}()