diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index af35ad13e98..1969f9e2855 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -28,6 +28,15 @@ type ( KeepAlivePingInterval time.Duration PongOnlyInterval time.Duration PingPongInterval time.Duration + /* If PingPongInterval has a non-0 duration, then when the server sends a ping + * it sets a ReadDeadline of PingPongInterval*2 and if the client doesn't respond + * with pong before that deadline is reached then the connection will die with a + * 1006 error code. + * + * MissingPongOk if true, tells the server to not use a ReadDeadline such that a + * missing/slow pong response from the client doesn't kill the connection. + */ + MissingPongOk bool didInjectSubprotocols bool } @@ -41,6 +50,7 @@ type ( keepAliveTicker *time.Ticker pongOnlyTicker *time.Ticker pingPongTicker *time.Ticker + receivedPong bool exec graphql.GraphExecutor closed bool @@ -258,9 +268,11 @@ func (c *wsConnection) run() { c.pingPongTicker = time.NewTicker(c.PingPongInterval) c.mu.Unlock() - // Note: when the connection is closed by this deadline, the client - // will receive an "invalid close code" - c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) + if !c.MissingPongOk { + // Note: when the connection is closed by this deadline, the client + // will receive an "invalid close code" + c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) + } go c.ping(ctx) } @@ -295,7 +307,11 @@ func (c *wsConnection) run() { case pingMessageType: c.write(&message{t: pongMessageType, payload: m.payload}) case pongMessageType: - c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) + c.mu.Lock() + c.receivedPong = true + c.mu.Unlock() + // Clear ReadTimeout -- 0 time val clears. + c.conn.SetReadDeadline(time.Time{}) default: c.sendConnectionError("unexpected message %s", m.t) c.close(websocket.CloseProtocolError, "unexpected message") @@ -336,6 +352,14 @@ func (c *wsConnection) ping(ctx context.Context) { return case <-c.pingPongTicker.C: c.write(&message{t: pingMessageType, payload: json.RawMessage{}}) + // The initial deadline for this method is set in run() + // if we have not yet received a pong, don't reset the deadline. + c.mu.Lock() + if !c.MissingPongOk && c.receivedPong { + c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval)) + } + c.receivedPong = false + c.mu.Unlock() } } } diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index cc4f42d1f8b..f05a9df1898 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -638,7 +638,7 @@ func TestWebsocketWithPingPongInterval(t *testing.T) { } t.Run("client receives ping and responds with pong", func(t *testing.T) { - _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) + _, srv := initialize(transport.Websocket{PingPongInterval: 20 * time.Millisecond}) defer srv.Close() c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) @@ -652,6 +652,11 @@ func TestWebsocketWithPingPongInterval(t *testing.T) { assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) }) + t.Run("client sends ping and expects pong", func(t *testing.T) { + _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) + defer srv.Close() + }) + t.Run("client sends ping and expects pong", func(t *testing.T) { _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond}) defer srv.Close() @@ -666,6 +671,67 @@ func TestWebsocketWithPingPongInterval(t *testing.T) { assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type) }) + t.Run("server closes with error if client does not pong and !MissingPongOk", func(t *testing.T) { + h := testserver.New() + closeFuncCalled := make(chan bool, 1) + h.AddTransport(transport.Websocket{ + MissingPongOk: false, // default value but beign explicit for test clarity. + PingPongInterval: 5 * time.Millisecond, + CloseFunc: func(_ context.Context, _closeCode int) { + closeFuncCalled <- true + }, + }) + + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) + assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) + + assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) + + select { + case res := <-closeFuncCalled: + assert.True(t, res) + case <-time.NewTimer(time.Millisecond * 20).C: + // with a 5ms interval 10ms should be the timeout, double that to make the test less likely to flake under load + assert.Fail(t, "The close handler was not called in time") + } + }) + + t.Run("server does not close with error if client does not pong and MissingPongOk", func(t *testing.T) { + h := testserver.New() + closeFuncCalled := make(chan bool, 1) + h.AddTransport(transport.Websocket{ + MissingPongOk: true, + PingPongInterval: 10 * time.Millisecond, + CloseFunc: func(_ context.Context, _closeCode int) { + closeFuncCalled <- true + }, + }) + + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) + assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) + + assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type) + + select { + case <-closeFuncCalled: + assert.Fail(t, "The close handler was called even with MissingPongOk = true") + case _, ok := <-time.NewTimer(time.Millisecond * 20).C: + assert.True(t, ok) + } + }) + t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) { // Regression test // ---