Skip to content

Commit

Permalink
Moved initPayload to wsConnection member, changed wsConnection.init t…
Browse files Browse the repository at this point in the history
…o return false on invalid payload
  • Loading branch information
gissleh committed Sep 20, 2018
1 parent 25268ef commit 01923de
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type wsConnection struct {
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config

initPayload graphql.InitPayload
}

func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
Expand All @@ -63,42 +65,41 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req
cfg: cfg,
}

initPayload, ok := conn.init()
if !ok {
if !conn.init() {
return
}

conn.run(initPayload)
conn.run()
}

func (c *wsConnection) init() (initPayload graphql.InitPayload, ok bool) {
func (c *wsConnection) init() bool {
message := c.readOp()
if message == nil {
c.close(websocket.CloseProtocolError, "decoding error")
return nil, false
return false
}

initPayload = make(graphql.InitPayload)

switch message.Type {
case connectionInitMsg:
err := json.Unmarshal(message.Payload, &initPayload)
if err != nil {
// Treat an invalid payload as no payload
initPayload = nil
if len(message.Payload) > 0 {
c.initPayload = make(graphql.InitPayload)
err := json.Unmarshal(message.Payload, &c.initPayload)
if err != nil {
return false
}
}

c.write(&operationMessage{Type: connectionAckMsg})
case connectionTerminateMsg:
c.close(websocket.CloseNormalClosure, "terminated")
return nil, false
return false
default:
c.sendConnectionError("unexpected message %s", message.Type)
c.close(websocket.CloseProtocolError, "unexpected message")
return nil, false
return false
}

return initPayload, true
return true
}

func (c *wsConnection) write(msg *operationMessage) {
Expand All @@ -107,7 +108,7 @@ func (c *wsConnection) write(msg *operationMessage) {
c.mu.Unlock()
}

func (c *wsConnection) run(initPayload graphql.InitPayload) {
func (c *wsConnection) run() {
for {
message := c.readOp()
if message == nil {
Expand All @@ -116,7 +117,7 @@ func (c *wsConnection) run(initPayload graphql.InitPayload) {

switch message.Type {
case startMsg:
if !c.subscribe(message, initPayload) {
if !c.subscribe(message) {
return
}
case stopMsg:
Expand All @@ -140,7 +141,7 @@ func (c *wsConnection) run(initPayload graphql.InitPayload) {
}
}

func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql.InitPayload) bool {
func (c *wsConnection) subscribe(message *operationMessage) bool {
var reqParams params
if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
c.sendConnectionError("invalid json")
Expand All @@ -167,8 +168,8 @@ func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql.
reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars)
ctx := graphql.WithRequestContext(c.ctx, reqCtx)

if initPayload != nil {
ctx = graphql.WithInitPayload(ctx, initPayload)
if c.initPayload != nil {
ctx = graphql.WithInitPayload(ctx, c.initPayload)
}

if op.Operation != ast.Subscription {
Expand Down

0 comments on commit 01923de

Please sign in to comment.