diff --git a/handler/websocket.go b/handler/websocket.go index 2af04332448..947b3ed9a4a 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -43,6 +43,8 @@ type wsConnection struct { active map[string]context.CancelFunc mu sync.Mutex cfg *Config + + initPayload graphql.InitPayload } func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) { @@ -63,42 +65,41 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req cfg: cfg, } - initPayload, ok := conn.init() - if !ok { + if !conn.init() { return } - conn.run(initPayload) + conn.run() } -func (c *wsConnection) init() (initPayload graphql.InitPayload, ok bool) { +func (c *wsConnection) init() bool { message := c.readOp() if message == nil { c.close(websocket.CloseProtocolError, "decoding error") - return nil, false + return false } - initPayload = make(graphql.InitPayload) - switch message.Type { case connectionInitMsg: - err := json.Unmarshal(message.Payload, &initPayload) - if err != nil { - // Treat an invalid payload as no payload - initPayload = nil + if len(message.Payload) > 0 { + c.initPayload = make(graphql.InitPayload) + err := json.Unmarshal(message.Payload, &c.initPayload) + if err != nil { + return false + } } c.write(&operationMessage{Type: connectionAckMsg}) case connectionTerminateMsg: c.close(websocket.CloseNormalClosure, "terminated") - return nil, false + return false default: c.sendConnectionError("unexpected message %s", message.Type) c.close(websocket.CloseProtocolError, "unexpected message") - return nil, false + return false } - return initPayload, true + return true } func (c *wsConnection) write(msg *operationMessage) { @@ -107,7 +108,7 @@ func (c *wsConnection) write(msg *operationMessage) { c.mu.Unlock() } -func (c *wsConnection) run(initPayload graphql.InitPayload) { +func (c *wsConnection) run() { for { message := c.readOp() if message == nil { @@ -116,7 +117,7 @@ func (c *wsConnection) run(initPayload graphql.InitPayload) { switch message.Type { case startMsg: - if !c.subscribe(message, initPayload) { + if !c.subscribe(message) { return } case stopMsg: @@ -140,7 +141,7 @@ func (c *wsConnection) run(initPayload graphql.InitPayload) { } } -func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql.InitPayload) bool { +func (c *wsConnection) subscribe(message *operationMessage) bool { var reqParams params if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil { c.sendConnectionError("invalid json") @@ -167,8 +168,8 @@ func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql. reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars) ctx := graphql.WithRequestContext(c.ctx, reqCtx) - if initPayload != nil { - ctx = graphql.WithInitPayload(ctx, initPayload) + if c.initPayload != nil { + ctx = graphql.WithInitPayload(ctx, c.initPayload) } if op.Operation != ast.Subscription {