Skip to content

Commit

Permalink
Add support for graphql-transport-ws with duplex ping-pong (#1578)
Browse files Browse the repository at this point in the history
* Add support for graphql-transport-ws with duplex ping-pong

* Add tests for the duplex ping-pong
  • Loading branch information
zdraganov authored Nov 22, 2021
1 parent ae92c83 commit 213ecd9
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 9 deletions.
8 changes: 7 additions & 1 deletion example/chat/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,32 @@
Example app using subscriptions to build a chat room.

### Server

```bash
go run ./server/server.go
```

### Client

The react app uses two different implementation for the websocket link

- [apollo-link-ws](https://www.apollographql.com/docs/react/api/link/apollo-link-ws) which uses the deprecated [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws) library
- [graphql-ws](https://github.com/enisdenjo/graphql-ws)

First you need to install the dependencies

```bash
npm install
npm install
```

Then to run the app with the `apollo-link-ws` implementation do

```bash
npm run start
```

or to run the app with the `graphql-ws` implementation (and the newer `graphql-transport-ws` protocol) do

```bash
npm run start:graphql-transport-ws
```
32 changes: 32 additions & 0 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type (
Upgrader websocket.Upgrader
InitFunc WebsocketInitFunc
KeepAlivePingInterval time.Duration
PingPongInterval time.Duration

didInjectSubprotocols bool
}
Expand All @@ -33,6 +34,7 @@ type (
active map[string]context.CancelFunc
mu sync.Mutex
keepAliveTicker *time.Ticker
pingPongTicker *time.Ticker
exec graphql.GraphExecutor

initPayload InitPayload
Expand Down Expand Up @@ -157,6 +159,17 @@ func (c *wsConnection) run() {
go c.keepAlive(ctx)
}

// Create a timer that will fire every interval a ping message that should
// receive a pong (SetPongHandler in init() function)
if c.PingPongInterval != 0 {
c.mu.Lock()
c.pingPongTicker = time.NewTicker(c.PingPongInterval)
c.mu.Unlock()

c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
go c.ping(ctx)
}

for {
start := graphql.Now()
m, err := c.me.NextMessage()
Expand All @@ -178,6 +191,10 @@ func (c *wsConnection) run() {
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
return
case pingMesageType:
c.write(&message{t: pongMessageType, payload: m.payload})
case pongMessageType:
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
Expand All @@ -198,6 +215,18 @@ func (c *wsConnection) keepAlive(ctx context.Context) {
}
}

func (c *wsConnection) ping(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.pingPongTicker.Stop()
return
case <-c.pingPongTicker.C:
c.write(&message{t: pingMesageType, payload: json.RawMessage{}})
}
}
}

func (c *wsConnection) subscribe(start time.Time, msg *message) {
ctx := graphql.StartOperationTrace(c.ctx)
var params *graphql.RawParams
Expand Down Expand Up @@ -315,6 +344,9 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
func (c *wsConnection) close(closeCode int, message string) {
c.mu.Lock()
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
for _, closer := range c.active {
closer()
}
c.mu.Unlock()
_ = c.conn.Close()
}
30 changes: 22 additions & 8 deletions graphql/handler/transport/websocket_graphql_transport_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@ const (
graphqltransportwsNextMsg = graphqltransportwsMessageType("next")
graphqltransportwsErrorMsg = graphqltransportwsMessageType("error")
graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete")
graphqltransportwsPingMsg = graphqltransportwsMessageType("ping")
graphqltransportwsPongMsg = graphqltransportwsMessageType("pong")
)

var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{
graphqltransportwsConnectionInitMsg,
graphqltransportwsConnectionAckMsg,
graphqltransportwsSubscribeMsg,
graphqltransportwsNextMsg,
graphqltransportwsErrorMsg,
graphqltransportwsCompleteMsg,
}
var (
allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{
graphqltransportwsConnectionInitMsg,
graphqltransportwsConnectionAckMsg,
graphqltransportwsSubscribeMsg,
graphqltransportwsNextMsg,
graphqltransportwsErrorMsg,
graphqltransportwsCompleteMsg,
graphqltransportwsPingMsg,
graphqltransportwsPongMsg,
}
)

type (
graphqltransportwsMessageExchanger struct {
Expand Down Expand Up @@ -103,6 +109,10 @@ func (m graphqltransportwsMessage) toMessage() (message, error) {
t = startMessageType
case graphqltransportwsCompleteMsg:
t = stopMessageType
case graphqltransportwsPingMsg:
t = pingMesageType
case graphqltransportwsPongMsg:
t = pongMessageType
}

return message{
Expand Down Expand Up @@ -131,6 +141,10 @@ func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) {
m.Type = graphqltransportwsCompleteMsg
case errorMessageType:
m.Type = graphqltransportwsErrorMsg
case pingMesageType:
m.Type = graphqltransportwsPingMsg
case pongMessageType:
m.Type = graphqltransportwsPongMsg
}

return err
Expand Down
6 changes: 6 additions & 0 deletions graphql/handler/transport/websocket_subprotocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ const (
dataMessageType
completeMessageType
errorMessageType
pingMesageType
pongMessageType
)

var (
Expand Down Expand Up @@ -68,6 +70,10 @@ func (t messageType) String() string {
text = "complete"
case errorMessageType:
text = "error"
case pingMesageType:
text = "ping"
case pongMessageType:
text = "pong"
}
return text
}
Expand Down
35 changes: 35 additions & 0 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,39 @@ func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
})
}

func TestWebsocketWithPingPongInterval(t *testing.T) {
handler := testserver.New()
handler.AddTransport(transport.Websocket{
PingPongInterval: time.Second * 1,
})

srv := httptest.NewServer(handler)
defer srv.Close()

t.Run("client receives ping and responds with pong", func(t *testing.T) {
c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)

assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
})

t.Run("client sends ping and expects pong", func(t *testing.T) {
c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)

require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
})
}

func wsConnect(url string) *websocket.Conn {
return wsConnectWithSubprocotol(url, "")
}
Expand Down Expand Up @@ -374,6 +407,8 @@ const (
graphqltransportwsSubscribeMsg = "subscribe"
graphqltransportwsNextMsg = "next"
graphqltransportwsCompleteMsg = "complete"
graphqltransportwsPingMsg = "ping"
graphqltransportwsPongMsg = "pong"
)

type operationMessage struct {
Expand Down
8 changes: 8 additions & 0 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
Upgrader: cfg.upgrader,
InitFunc: cfg.websocketInitFunc,
KeepAlivePingInterval: cfg.connectionKeepAlivePingInterval,
PingPongInterval: cfg.connectionPingPongInterval,
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
Expand Down Expand Up @@ -77,6 +78,7 @@ type Config struct {
upgrader websocket.Upgrader
websocketInitFunc transport.WebsocketInitFunc
connectionKeepAlivePingInterval time.Duration
connectionPingPongInterval time.Duration
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
fieldHooks []graphql.FieldMiddleware
Expand Down Expand Up @@ -210,6 +212,12 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option {
}
}

func WebsocketPingPongDuration(duration time.Duration) Option {
return func(cfg *Config) {
cfg.connectionPingPongInterval = duration
}
}

// Add cache that will hold queries for automatic persisted queries (APQ)
// Deprecated: switch to graphql/handler.New
func EnablePersistedQueryCache(cache PersistedQueryCache) Option {
Expand Down

0 comments on commit 213ecd9

Please sign in to comment.