diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index d7fa753d9..b58472914 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -29,6 +29,26 @@ import ( const ackWaitTimeout = 30 * time.Second +type epollState struct { + // connections is a map of fd -> connection to keep track of all active connections + connections map[int]*connection + hasConnections atomic.Bool + // triggers is a map of subscription id -> fd to easily look up the connection for a subscription id + triggers map[uint64]int + + // clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed + clientUnsubscribe chan uint64 + // addConn is a channel to signal to the epoll run loop that a new connection needs to be added + addConn chan *connection + // waitForEventsTicker is the ticker for the epoll run loop + // it is used to prevent busy waiting and to limit the CPU usage + // instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop + waitForEventsTicker *time.Ticker + + // waitForEventsTick is the channel to receive the tick from the waitForEventsTicker + waitForEventsTick <-chan time.Time +} + // subscriptionClient allows running multiple subscriptions via the same WebSocket either SSE connection // It takes care of de-duplicating connections to the same origin under certain circumstances // If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used @@ -47,23 +67,7 @@ type subscriptionClient struct { epoll epoller.Poller epollConfig EpollConfiguration - - // connections is a map of fd -> connection to keep track of all active connections - connections map[int]*connection - hasConnections atomic.Bool - // triggers is a map of subscription id -> fd to easily look up the connection for a subscription id - triggers map[uint64]int - - // clientUnsubscribe is a channel to signal to the epoll run loop that a client needs to be unsubscribed - clientUnsubscribe chan uint64 - // addConn is a channel to signal to the epoll run loop that a new connection needs to be added - addConn chan *connection - // waitForEventsTicker is the ticker for the epoll run loop - // it is used to prevent busy waiting and to limit the CPU usage - // instead of polling the epoll instance all the time, we wait until the next tick to throttle the epoll loop - waitForEventsTicker *time.Ticker - // waitForEventsTick is the channel to receive the tick from the waitForEventsTicker - waitForEventsTick <-chan time.Time + epollState *epollState } func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { @@ -81,7 +85,12 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt } func (c *subscriptionClient) Unsubscribe(id uint64) { - c.clientUnsubscribe <- id + // if we don't have epoll, we don't have a channel consumer of the clientUnsubscribe channel + // we have to return to prevent a deadlock + if c.epoll == nil { + return + } + c.epollState.clientUnsubscribe <- id } type InvalidWsSubprotocolError struct { @@ -195,16 +204,19 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi epollConfig: op.epollConfiguration, } if !op.epollConfiguration.Disable { - client.connections = make(map[int]*connection) - client.triggers = make(map[uint64]int) - client.clientUnsubscribe = make(chan uint64, op.epollConfiguration.BufferSize) - client.addConn = make(chan *connection, op.epollConfiguration.BufferSize) - // this is not needed, but we want to make it explicit that we're starting with nil as the tick channel - // reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting - // once we add the first connection, we start the ticker and set the tick channel - // after the last connection is removed, we set the tick channel to nil again - // this way we can start and stop the epoll loop dynamically - client.waitForEventsTick = nil + client.epollState = &epollState{ + connections: make(map[int]*connection), + triggers: make(map[uint64]int), + clientUnsubscribe: make(chan uint64, op.epollConfiguration.BufferSize), + addConn: make(chan *connection, op.epollConfiguration.BufferSize), + // this is not needed, but we want to make it explicit that we're starting with nil as the tick channel + // reading from nil channels blocks forever, which allows us to prevent the epoll loop from starting + // once we add the first connection, we start the ticker and set the tick channel + // after the last connection is removed, we set the tick channel to nil again + // this way we can start and stop the epoll loop dynamically + waitForEventsTick: nil, + } + // ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client epoll, _ := epoller.NewPoller(op.epollConfiguration.BufferSize, op.epollConfiguration.TickInterval) if epoll != nil { @@ -323,7 +335,7 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont fd := epoller.SocketFD(conn.conn) conn.id, conn.fd = id, fd // submit the connection to the epoll run loop - c.addConn <- conn + c.epollState.addConn <- conn return nil } @@ -636,16 +648,16 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { // if the engine context is done, we close the epoll loop case <-done: return - case conn := <-c.addConn: + case conn := <-c.epollState.addConn: c.handleAddConn(conn) - case id := <-c.clientUnsubscribe: + case id := <-c.epollState.clientUnsubscribe: c.handleClientUnsubscribe(id) // while len(c.connections) == 0, this channel is nil, so we will never try to wait for epoll events // this is important to prevent busy waiting // once we add the first connection, we start the ticker and set the tick channel // the ticker ensures that we don't poll the epoll instance all the time, // but at most every TickInterval - case <-c.waitForEventsTick: + case <-c.epollState.waitForEventsTick: events, err := c.epoll.Wait(c.epollConfig.WaitForNumEvents) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) @@ -656,7 +668,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { for i := range events { fd := epoller.SocketFD(events[i]) - conn, ok := c.connections[fd] + conn, ok := c.epollState.connections[fd] if !ok { // Should never happen panic(fmt.Sprintf("connection with fd %d not found", fd)) @@ -684,9 +696,9 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } // we decrease the number of events we're waiting for to eventually break the loop waitForEvents-- - case conn := <-c.addConn: + case conn := <-c.epollState.addConn: c.handleAddConn(conn) - case id := <-c.clientUnsubscribe: + case id := <-c.epollState.clientUnsubscribe: c.handleClientUnsubscribe(id) case <-done: return @@ -698,10 +710,10 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { func (c *subscriptionClient) close() { defer c.log.Debug("subscriptionClient.close", abstractlogger.String("reason", "epoll closed by context")) - if c.waitForEventsTicker != nil { - c.waitForEventsTicker.Stop() + if c.epollState.waitForEventsTicker != nil { + c.epollState.waitForEventsTicker.Stop() } - for _, conn := range c.connections { + for _, conn := range c.epollState.connections { _ = c.epoll.Remove(conn.conn) conn.handler.ServerClose() } @@ -719,52 +731,52 @@ func (c *subscriptionClient) handleAddConn(conn *connection) { conn.handler.ServerClose() return } - c.connections[conn.fd] = conn - c.triggers[conn.id] = conn.fd + c.epollState.connections[conn.fd] = conn + c.epollState.triggers[conn.id] = conn.fd // when we previously had 0 connections, we will have 1 connection now // this means we need to start the ticker so that we get epoll events - if len(c.connections) == 1 { - c.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval) - c.waitForEventsTick = c.waitForEventsTicker.C - c.hasConnections.Store(true) + if len(c.epollState.connections) == 1 { + c.epollState.waitForEventsTicker = time.NewTicker(c.epollConfig.TickInterval) + c.epollState.waitForEventsTick = c.epollState.waitForEventsTicker.C + c.epollState.hasConnections.Store(true) } } func (c *subscriptionClient) handleClientUnsubscribe(id uint64) { - fd, ok := c.triggers[id] + fd, ok := c.epollState.triggers[id] if !ok { return } - delete(c.triggers, id) - conn, ok := c.connections[fd] + delete(c.epollState.triggers, id) + conn, ok := c.epollState.connections[fd] if !ok { return } - delete(c.connections, fd) + delete(c.epollState.connections, fd) _ = c.epoll.Remove(conn.conn) conn.handler.ClientClose() // if we have no connections left, we stop the ticker - if len(c.connections) == 0 { - c.waitForEventsTicker.Stop() - c.waitForEventsTick = nil - c.hasConnections.Store(false) + if len(c.epollState.connections) == 0 { + c.epollState.waitForEventsTicker.Stop() + c.epollState.waitForEventsTick = nil + c.epollState.hasConnections.Store(false) } } func (c *subscriptionClient) handleServerUnsubscribe(fd int) { - conn, ok := c.connections[fd] + conn, ok := c.epollState.connections[fd] if !ok { return } - delete(c.connections, fd) - delete(c.triggers, conn.id) + delete(c.epollState.connections, fd) + delete(c.epollState.triggers, conn.id) _ = c.epoll.Remove(conn.conn) conn.handler.ServerClose() // if we have no connections left, we stop the ticker - if len(c.connections) == 0 { - c.waitForEventsTicker.Stop() - c.waitForEventsTick = nil - c.hasConnections.Store(false) + if len(c.epollState.connections) == 0 { + c.epollState.waitForEventsTicker.Stop() + c.epollState.waitForEventsTick = nil + c.epollState.hasConnections.Store(false) } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index ae3e65163..c743efe00 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -907,7 +907,7 @@ func TestAsyncSubscribe(t *testing.T) { return true }, time.Second*5, time.Millisecond*10, "server did not close") time.Sleep(time.Second) - assert.Equal(t, false, client.hasConnections.Load()) + assert.Equal(t, false, client.epollState.hasConnections.Load()) }) t.Run("forever timeout", func(t *testing.T) { t.Parallel() @@ -1103,7 +1103,7 @@ func TestAsyncSubscribe(t *testing.T) { return true }, time.Second, time.Millisecond*10, "server did not close") serverCancel() - assert.Equal(t, false, client.hasConnections.Load()) + assert.Equal(t, false, client.epollState.hasConnections.Load()) }) t.Run("error object", func(t *testing.T) { t.Parallel()