Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ws): deadlock on unsubscribe when epoll disabled #982

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@ import (

const ackWaitTimeout = 30 * time.Second

type epollState struct {
// connections is a map of fd -> connection to keep track of all active connections
StarpTech marked this conversation as resolved.
Show resolved Hide resolved
connections map[int]*connection
hasConnections atomic.Bool
// triggers is a map of subscription id -> fd to easily look up the connection for a subscription id
triggers map[uint64]int

// clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed
clientUnsubscribe chan uint64
// addConn is a channel to signal to the epoll run loop that a new connection needs to be added
addConn chan *connection
// waitForEventsTicker is the ticker for the epoll run loop
// it is used to prevent busy waiting and to limit the CPU usage
// instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop
waitForEventsTicker *time.Ticker

// waitForEventsTick is the channel to receive the tick from the waitForEventsTicker
waitForEventsTick <-chan time.Time
}

// subscriptionClient allows running multiple subscriptions via the same WebSocket either SSE connection
// It takes care of de-duplicating connections to the same origin under certain circumstances
// If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used
Expand All @@ -47,23 +67,7 @@ type subscriptionClient struct {

epoll epoller.Poller
epollConfig EpollConfiguration

// connections is a map of fd -> connection to keep track of all active connections
connections map[int]*connection
hasConnections atomic.Bool
// triggers is a map of subscription id -> fd to easily look up the connection for a subscription id
triggers map[uint64]int

// clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed
clientUnsubscribe chan uint64
// addConn is a channel to signal to the epoll run loop that a new connection needs to be added
addConn chan *connection
// waitForEventsTicker is the ticker for the epoll run loop
// it is used to prevent busy waiting and to limit the CPU usage
// instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop
waitForEventsTicker *time.Ticker
// waitForEventsTick is the channel to receive the tick from the waitForEventsTicker
waitForEventsTick <-chan time.Time
epollState *epollState
}

func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error {
Expand All @@ -81,7 +85,12 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt
}

func (c *subscriptionClient) Unsubscribe(id uint64) {
c.clientUnsubscribe <- id
// if we don't have epoll, we don't have a channel consumer of the clientUnsubscribe channel
// we have to return to prevent a deadlock
if c.epoll == nil {
return
}
c.epollState.clientUnsubscribe <- id
}

type InvalidWsSubprotocolError struct {
Expand Down Expand Up @@ -195,16 +204,19 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
epollConfig: op.epollConfiguration,
}
if !op.epollConfiguration.Disable {
client.connections = make(map[int]*connection)
client.triggers = make(map[uint64]int)
client.clientUnsubscribe = make(chan uint64, op.epollConfiguration.BufferSize)
client.addConn = make(chan *connection, op.epollConfiguration.BufferSize)
// this is not needed, but we want to make it explicit that we're starting with nil as the tick channel
// reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting
// once we add the first connection, we start the ticker and set the tick channel
// after the last connection is removed, we set the tick channel to nil again
// this way we can start and stop the epoll loop dynamically
client.waitForEventsTick = nil
client.epollState = &epollState{
connections: make(map[int]*connection),
triggers: make(map[uint64]int),
clientUnsubscribe: make(chan uint64, op.epollConfiguration.BufferSize),
addConn: make(chan *connection, op.epollConfiguration.BufferSize),
// this is not needed, but we want to make it explicit that we're starting with nil as the tick channel
// reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting
// once we add the first connection, we start the ticker and set the tick channel
// after the last connection is removed, we set the tick channel to nil again
// this way we can start and stop the epoll loop dynamically
waitForEventsTick: nil,
}

// ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client
epoll, _ := epoller.NewPoller(op.epollConfiguration.BufferSize, op.epollConfiguration.TickInterval)
if epoll != nil {
Expand Down Expand Up @@ -323,7 +335,7 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont
fd := epoller.SocketFD(conn.conn)
conn.id, conn.fd = id, fd
// submit the connection to the epoll run loop
c.addConn <- conn
c.epollState.addConn <- conn
return nil
}

Expand Down Expand Up @@ -636,16 +648,16 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {
// if the engine context is done, we close the epoll loop
case <-done:
return
case conn := <-c.addConn:
case conn := <-c.epollState.addConn:
c.handleAddConn(conn)
case id := <-c.clientUnsubscribe:
case id := <-c.epollState.clientUnsubscribe:
c.handleClientUnsubscribe(id)
// while len(c.connections) == 0, this channel is nil, so we will never try to wait for epoll events
// this is important to prevent busy waiting
// once we add the first connection, we start the ticker and set the tick channel
// the ticker ensures that we don't poll the epoll instance all the time,
// but at most every TickInterval
case <-c.waitForEventsTick:
case <-c.epollState.waitForEventsTick:
events, err := c.epoll.Wait(c.epollConfig.WaitForNumEvents)
if err != nil {
c.log.Error("epoll.Wait", abstractlogger.Error(err))
Expand All @@ -656,7 +668,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {

for i := range events {
fd := epoller.SocketFD(events[i])
conn, ok := c.connections[fd]
conn, ok := c.epollState.connections[fd]
if !ok {
// Should never happen
panic(fmt.Sprintf("connection with fd %d not found", fd))
Expand Down Expand Up @@ -684,9 +696,9 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {
}
// we decrease the number of events we're waiting for to eventually break the loop
waitForEvents--
case conn := <-c.addConn:
case conn := <-c.epollState.addConn:
c.handleAddConn(conn)
case id := <-c.clientUnsubscribe:
case id := <-c.epollState.clientUnsubscribe:
c.handleClientUnsubscribe(id)
case <-done:
return
Expand All @@ -698,10 +710,10 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) {

func (c *subscriptionClient) close() {
defer c.log.Debug("subscriptionClient.close", abstractlogger.String("reason", "epoll closed by context"))
if c.waitForEventsTicker != nil {
c.waitForEventsTicker.Stop()
if c.epollState.waitForEventsTicker != nil {
c.epollState.waitForEventsTicker.Stop()
}
for _, conn := range c.connections {
for _, conn := range c.epollState.connections {
_ = c.epoll.Remove(conn.conn)
conn.handler.ServerClose()
}
Expand All @@ -719,52 +731,52 @@ func (c *subscriptionClient) handleAddConn(conn *connection) {
conn.handler.ServerClose()
return
}
c.connections[conn.fd] = conn
c.triggers[conn.id] = conn.fd
c.epollState.connections[conn.fd] = conn
c.epollState.triggers[conn.id] = conn.fd
// when we previously had 0 connections, we will have 1 connection now
// this means we need to start the ticker so that we get epoll events
if len(c.connections) == 1 {
c.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval)
c.waitForEventsTick = c.waitForEventsTicker.C
c.hasConnections.Store(true)
if len(c.epollState.connections) == 1 {
c.epollState.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval)
c.epollState.waitForEventsTick = c.epollState.waitForEventsTicker.C
c.epollState.hasConnections.Store(true)
}
}

func (c *subscriptionClient) handleClientUnsubscribe(id uint64) {
fd, ok := c.triggers[id]
fd, ok := c.epollState.triggers[id]
if !ok {
return
}
delete(c.triggers, id)
conn, ok := c.connections[fd]
delete(c.epollState.triggers, id)
conn, ok := c.epollState.connections[fd]
if !ok {
return
}
delete(c.connections, fd)
delete(c.epollState.connections, fd)
_ = c.epoll.Remove(conn.conn)
conn.handler.ClientClose()
// if we have no connections left, we stop the ticker
if len(c.connections) == 0 {
c.waitForEventsTicker.Stop()
c.waitForEventsTick = nil
c.hasConnections.Store(false)
if len(c.epollState.connections) == 0 {
c.epollState.waitForEventsTicker.Stop()
c.epollState.waitForEventsTick = nil
c.epollState.hasConnections.Store(false)
}
}

func (c *subscriptionClient) handleServerUnsubscribe(fd int) {
conn, ok := c.connections[fd]
conn, ok := c.epollState.connections[fd]
if !ok {
return
}
delete(c.connections, fd)
delete(c.triggers, conn.id)
delete(c.epollState.connections, fd)
delete(c.epollState.triggers, conn.id)
_ = c.epoll.Remove(conn.conn)
conn.handler.ServerClose()
// if we have no connections left, we stop the ticker
if len(c.connections) == 0 {
c.waitForEventsTicker.Stop()
c.waitForEventsTick = nil
c.hasConnections.Store(false)
if len(c.epollState.connections) == 0 {
c.epollState.waitForEventsTicker.Stop()
c.epollState.waitForEventsTick = nil
c.epollState.hasConnections.Store(false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ func TestAsyncSubscribe(t *testing.T) {
return true
}, time.Second*5, time.Millisecond*10, "server did not close")
time.Sleep(time.Second)
assert.Equal(t, false, client.hasConnections.Load())
assert.Equal(t, false, client.epollState.hasConnections.Load())
})
t.Run("forever timeout", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1103,7 +1103,7 @@ func TestAsyncSubscribe(t *testing.T) {
return true
}, time.Second, time.Millisecond*10, "server did not close")
serverCancel()
assert.Equal(t, false, client.hasConnections.Load())
assert.Equal(t, false, client.epollState.hasConnections.Load())
})
t.Run("error object", func(t *testing.T) {
t.Parallel()
Expand Down
Loading