Skip to content

Commit

Permalink
Add ability to not fail when pong is not received. (#2815)
Browse files Browse the repository at this point in the history
I also changed how setting the read deadline works a little. The reason
is that the protocol allows a pong to be sent without a preceding ping.

So setting a read deadline on receiving a pong isn't great. Instead, we
should always set the read deadline when sending a ping. To do this,
though, we need to know whether we have received a pong: if we reset
the read deadline while the previous ping still hasn't received its
pong, the deadline will never be hit.
  • Loading branch information
Chris Pride authored Oct 1, 2023
1 parent 89ac736 commit 37f8e4e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
32 changes: 28 additions & 4 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ type (
KeepAlivePingInterval time.Duration
PongOnlyInterval time.Duration
PingPongInterval time.Duration
/* If PingPongInterval has a non-0 duration, then when the server sends a ping
* it sets a ReadDeadline of PingPongInterval*2 and if the client doesn't respond
* with pong before that deadline is reached then the connection will die with a
* 1006 error code.
*
* MissingPongOk if true, tells the server to not use a ReadDeadline such that a
* missing/slow pong response from the client doesn't kill the connection.
*/
MissingPongOk bool

didInjectSubprotocols bool
}
Expand All @@ -41,6 +50,7 @@ type (
keepAliveTicker *time.Ticker
pongOnlyTicker *time.Ticker
pingPongTicker *time.Ticker
receivedPong bool
exec graphql.GraphExecutor
closed bool

Expand Down Expand Up @@ -258,9 +268,11 @@ func (c *wsConnection) run() {
c.pingPongTicker = time.NewTicker(c.PingPongInterval)
c.mu.Unlock()

// Note: when the connection is closed by this deadline, the client
// will receive an "invalid close code"
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
if !c.MissingPongOk {
// Note: when the connection is closed by this deadline, the client
// will receive an "invalid close code"
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
}
go c.ping(ctx)
}

Expand Down Expand Up @@ -295,7 +307,11 @@ func (c *wsConnection) run() {
case pingMessageType:
c.write(&message{t: pongMessageType, payload: m.payload})
case pongMessageType:
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
c.mu.Lock()
c.receivedPong = true
c.mu.Unlock()
// Clear ReadTimeout -- 0 time val clears.
c.conn.SetReadDeadline(time.Time{})
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
Expand Down Expand Up @@ -336,6 +352,14 @@ func (c *wsConnection) ping(ctx context.Context) {
return
case <-c.pingPongTicker.C:
c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
// The initial deadline for this method is set in run()
// if we have not yet received a pong, don't reset the deadline.
c.mu.Lock()
if !c.MissingPongOk && c.receivedPong {
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
}
c.receivedPong = false
c.mu.Unlock()
}
}
}
Expand Down
68 changes: 67 additions & 1 deletion graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ func TestWebsocketWithPingPongInterval(t *testing.T) {
}

t.Run("client receives ping and responds with pong", func(t *testing.T) {
_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
_, srv := initialize(transport.Websocket{PingPongInterval: 20 * time.Millisecond})
defer srv.Close()

c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
Expand All @@ -652,6 +652,11 @@ func TestWebsocketWithPingPongInterval(t *testing.T) {
assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
})

t.Run("client sends ping and expects pong", func(t *testing.T) {
_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
defer srv.Close()
})

t.Run("client sends ping and expects pong", func(t *testing.T) {
_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
defer srv.Close()
Expand All @@ -666,6 +671,67 @@ func TestWebsocketWithPingPongInterval(t *testing.T) {
assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
})

// With MissingPongOk left false, a client that reads the server's ping but
// never answers it is expected to cause the transport's CloseFunc to run.
t.Run("server closes with error if client does not pong and !MissingPongOk", func(t *testing.T) {
h := testserver.New()
// Capacity 1 so that a single send from CloseFunc does not block even if
// the test is not yet receiving.
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
MissingPongOk: false, // default value but being explicit for test clarity.
PingPongInterval: 5 * time.Millisecond,
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
},
})

srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
defer c.Close()

// Complete the connection handshake: init, then wait for the ack.
require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)

// Read the server's ping and deliberately send no pong in response.
assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)

// Wait for whichever happens first: CloseFunc signalling, or the 20ms timer.
select {
case res := <-closeFuncCalled:
assert.True(t, res)
case <-time.NewTimer(time.Millisecond * 20).C:
// with a 5ms interval 10ms should be the timeout, double that to make the test less likely to flake under load
assert.Fail(t, "The close handler was not called in time")
}
})

// With MissingPongOk set, a client that reads the server's ping but never
// answers it must not cause CloseFunc to run within the observation window.
t.Run("server does not close with error if client does not pong and MissingPongOk", func(t *testing.T) {
h := testserver.New()
// Capacity 1 so that a single send from CloseFunc does not block even if
// the test is not yet receiving.
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
MissingPongOk: true,
PingPongInterval: 10 * time.Millisecond,
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
},
})

srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
defer c.Close()

// Complete the connection handshake: init, then wait for the ack.
require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)

// Read the server's ping and deliberately send no pong in response.
assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)

// Observe for 20ms (twice the configured PingPongInterval); any CloseFunc
// signal in that window fails the test.
select {
case <-closeFuncCalled:
assert.Fail(t, "The close handler was called even with MissingPongOk = true")
case _, ok := <-time.NewTimer(time.Millisecond * 20).C:
// The timer fired first; ok reports that a value was received from its channel.
assert.True(t, ok)
}
})

t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) {
// Regression test
// ---
Expand Down

0 comments on commit 37f8e4e

Please sign in to comment.