diff --git a/bolt_transport_test.go b/bolt_transport_test.go index 61e5d1ca..7cca0b05 100644 --- a/bolt_transport_test.go +++ b/bolt_transport_test.go @@ -227,7 +227,7 @@ func TestBoltTransportClosed(t *testing.T) { assert.Equal(t, transport.Dispatch(&Update{Topics: s.Topics}), ErrClosedTransport) - _, ok := <-s.disconnected + _, ok := <-s.out assert.False(t, ok) } diff --git a/local_transport_bench_test.go b/local_transport_bench_test.go index a591ac89..23c80d9a 100644 --- a/local_transport_bench_test.go +++ b/local_transport_bench_test.go @@ -39,7 +39,6 @@ func subBenchLocalTransport(b *testing.B, topics, concurrency, matchPct int, tes } } out := make(chan *Update, 50000) - once := &sync.Once{} for i := 0; i < concurrency; i++ { s := NewSubscriber("", zap.NewNop()) if i%100 < matchPct { @@ -48,7 +47,6 @@ func subBenchLocalTransport(b *testing.B, topics, concurrency, matchPct int, tes s.SetTopics(tsNoMatch, nil) } s.out = out - s.disconnectedOnce = once tr.AddSubscriber(s) } ctx, done := context.WithCancel(context.Background()) diff --git a/subscribe_test.go b/subscribe_test.go index 4e440ae7..878fcd4f 100644 --- a/subscribe_test.go +++ b/subscribe_test.go @@ -272,7 +272,7 @@ func TestUnsubscribe(t *testing.T) { hub.SubscribeHandler(httptest.NewRecorder(), req) assert.Equal(t, 0, s.subscribers.Len()) s.subscribers.Walk(0, func(s *Subscriber) bool { - _, ok := <-s.disconnected + _, ok := <-s.out assert.False(t, ok) return true diff --git a/subscriber.go b/subscriber.go index 806be5ea..2603ccc5 100644 --- a/subscriber.go +++ b/subscriber.go @@ -26,9 +26,9 @@ type Subscriber struct { PrivateRegexps []*regexp.Regexp Debug bool - disconnectedOnce *sync.Once + disconnected int32 out chan *Update - disconnected chan struct{} + outMutex sync.RWMutex responseLastEventID chan string logger Logger ready int32 @@ -45,9 +45,7 @@ func NewSubscriber(lastEventID string, logger Logger) *Subscriber { RequestLastEventID: lastEventID, responseLastEventID: make(chan string, 1), out: make(chan *Update, 1000), - disconnected: make(chan struct{}), logger: logger, - disconnectedOnce: &sync.Once{}, } return s @@ -55,10 +53,8 @@ func NewSubscriber(lastEventID string, logger Logger) *Subscriber { // Dispatch an update to the subscriber. func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { - select { - case <-s.disconnected: + if atomic.LoadInt32(&s.disconnected) > 0 { return false - default: } if !fromHistory && atomic.LoadInt32(&s.ready) < 1 { @@ -72,12 +68,12 @@ func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { s.liveMutex.Unlock() } - select { - case <-s.disconnected: + s.outMutex.RLock() + defer s.outMutex.RUnlock() + if atomic.LoadInt32(&s.disconnected) > 0 { return false - - default: } + s.out <- u return true @@ -109,10 +105,15 @@ func (s *Subscriber) HistoryDispatched(responseLastEventID string) { // Disconnect disconnects the subscriber. func (s *Subscriber) Disconnect() { - s.disconnectedOnce.Do(func() { - close(s.disconnected) - close(s.out) - }) + if atomic.LoadInt32(&s.disconnected) > 0 { + return + } + + s.outMutex.Lock() + defer s.outMutex.Unlock() + + atomic.StoreInt32(&s.disconnected, 1) + close(s.out) } // SetTopics compiles topic selector regexps. diff --git a/transport_test.go b/transport_test.go index 9c1bdb4f..e9ccadab 100644 --- a/transport_test.go +++ b/transport_test.go @@ -62,7 +62,7 @@ func TestLocalTransportClosed(t *testing.T) { assert.Equal(t, transport.AddSubscriber(NewSubscriber("", zap.NewNop())), ErrClosedTransport) assert.Equal(t, transport.Dispatch(&Update{}), ErrClosedTransport) - _, ok := <-s.disconnected + _, ok := <-s.out assert.False(t, ok) }