Skip to content

Commit

Permalink
Fix unsubscribe during inflight subscribe in goroutine (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored May 2, 2024
1 parent 90bf800 commit 2188a02
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 4 deletions.
40 changes: 36 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ const (
// ChannelContext contains extra context for channel connection subscribed to.
// Note: this struct is aligned to consume less memory.
type ChannelContext struct {
subscribingCh chan struct{}
info []byte
expireAt int64
positionCheckTime int64
Expand Down Expand Up @@ -1596,11 +1597,14 @@ func (c *Client) handleRefresh(req *protocol.RefreshRequest, cmd *protocol.Comma
// Channel kept in a map during subscribe request to check for duplicate subscription attempts.
func (c *Client) onSubscribeError(channel string) {
c.mu.Lock()
_, ok := c.channels[channel]
chCtx, ok := c.channels[channel]
delete(c.channels, channel)
c.mu.Unlock()
if ok {
_ = c.node.removeSubscription(channel, c)
if chCtx.subscribingCh != nil {
close(chCtx.subscribingCh)
}
}
}

Expand Down Expand Up @@ -2633,8 +2637,12 @@ func (c *Client) validateSubscribeRequest(cmd *protocol.SubscribeRequest) (*Erro
return ErrorLimitExceeded, nil
}
// Put channel to a map to track duplicate subscriptions. This channel should
// be removed from a map upon an error during subscribe.
c.channels[channel] = ChannelContext{}
// be removed from a map upon an error during subscribe. Also initialize subscribingCh
// which is used to sync unsubscribe requests with inflight subscriptions (useful when
// subscribe is performed in a separate goroutine).
c.channels[channel] = ChannelContext{
subscribingCh: make(chan struct{}),
}
c.mu.Unlock()

return nil, nil
Expand Down Expand Up @@ -2898,6 +2906,11 @@ func (c *Client) subscribeCmd(req *protocol.SubscribeRequest, reply SubscribeRep
if !serverSide {
// In case of server-side sub this will be done later by the caller.
c.mu.Lock()
if chCtx, ok := c.channels[channel]; ok {
subscribedCh := chCtx.subscribingCh
defer func() { close(subscribedCh) }()
channelContext.subscribingCh = subscribedCh
}
c.channels[channel] = channelContext
c.mu.Unlock()
// Stop syncing recovery and PUB/SUB.
Expand Down Expand Up @@ -3063,12 +3076,31 @@ func (c *Client) unsubscribe(channel string, unsubscribe Unsubscribe, disconnect
c.mu.RLock()
info := c.clientInfo(channel)
chCtx, ok := c.channels[channel]
subscribingCh := chCtx.subscribingCh
isSubscribed := channelHasFlag(chCtx.flags, flagSubscribed)
serverSide := channelHasFlag(chCtx.flags, flagServerSide)
c.mu.RUnlock()
if !ok {
return nil
}

serverSide := channelHasFlag(chCtx.flags, flagServerSide)
if !serverSide && !isSubscribed && subscribingCh != nil {
// If client is not yet subscribed on a client-side channel, and subscribe
// command is in progress - we need to wait for it to finish before proceeding.
// If client hits subscribe or unsubscribe timeouts – it reconnects, so we never
// hang long here.
select {
case <-subscribingCh:
c.mu.RLock()
chCtx, ok = c.channels[channel]
c.mu.RUnlock()
if !ok {
return nil
}
case <-c.Context().Done():
return nil
}
}

c.mu.Lock()
delete(c.channels, channel)
Expand Down
77 changes: 77 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3821,3 +3821,80 @@ func TestClient_HandleCommandV2(t *testing.T) {
}, 0)
require.False(t, ok)
}

// Not looking at subscribe result - just execute subscribe command.
func asyncSubscribeClient(t testing.TB, client *Client, ch string) {
rwWrapper := testReplyWriterWrapper()
err := client.handleSubscribe(&protocol.SubscribeRequest{
Channel: ch,
}, &protocol.Command{Id: 1}, time.Now(), rwWrapper.rw)
require.NoError(t, err)
}

func TestClientUnsubscribeDuringSubscribe(t *testing.T) {
t.Parallel()
node := defaultNodeNoHandlers()
subscribedCh := make(chan struct{}, 2)
unsubscribedCh := make(chan struct{}, 2)
node.OnConnect(func(client *Client) {
client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
go func() {
defer func() {
subscribedCh <- struct{}{}
}()
time.Sleep(200 * time.Millisecond)
cb(SubscribeReply{}, nil)
}()
})
client.OnUnsubscribe(func(e UnsubscribeEvent) {
<-subscribedCh
unsubscribedCh <- struct{}{}
})
})
defer func() { _ = node.Shutdown(context.Background()) }()
client := newTestClient(t, node, "42")
connectClientV2(t, client)
asyncSubscribeClient(t, client, "test")
client.Unsubscribe("test")
client.mu.Lock()
_, ok := client.channels["test"]
client.mu.Unlock()
require.False(t, ok)
waitWithTimeout(t, unsubscribedCh)
asyncSubscribeClient(t, client, "test")
err := client.close(DisconnectForceNoReconnect)
waitWithTimeout(t, unsubscribedCh)
require.NoError(t, err)
}

func TestClientUnsubscribeDuringSubscribeWithError(t *testing.T) {
t.Parallel()
node := defaultNodeNoHandlers()
subscribedCh := make(chan struct{}, 1)
node.OnConnect(func(client *Client) {
client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) {
go func() {
defer func() {
subscribedCh <- struct{}{}
}()
time.Sleep(200 * time.Millisecond)
cb(SubscribeReply{}, ErrorInternal)
}()
})
client.OnUnsubscribe(func(e UnsubscribeEvent) {
t.Fatal("unexpected unsubscribe")
})
})
defer func() { _ = node.Shutdown(context.Background()) }()
client := newTestClient(t, node, "42")
connectClientV2(t, client)
asyncSubscribeClient(t, client, "test")
client.Unsubscribe("test")
client.mu.Lock()
_, ok := client.channels["test"]
client.mu.Unlock()
require.False(t, ok)
waitWithTimeout(t, subscribedCh)
err := client.close(DisconnectForceNoReconnect)
require.NoError(t, err)
}

0 comments on commit 2188a02

Please sign in to comment.