Skip to content

Commit

Permalink
Improve protocol error messages
Browse files Browse the repository at this point in the history
To aid protocol error debugging, report all errors found in the first two bytes of a message header.
  • Loading branch information
garyburd authored Jan 2, 2022
1 parent 2d6ee4c commit f0643a3
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
Expand Down Expand Up @@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) {
}

// 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.

var errors []string

p, err := c.read(2)
if err != nil {
return noFrame, err
}

final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))

c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true
} else {
errors = append(errors, "RSV1 set")
}
}

if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
if rsv2 {
errors = append(errors, "RSV2 set")
}

if rsv3 {
errors = append(errors, "RSV3 set")
}

switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
errors = append(errors, "len > 125 for control")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}

if mask != c.isServer {
errors = append(errors, "bad MASK")
}

if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}

// 3. Read and parse frame length as per
Expand Down Expand Up @@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) {

// 4. Handle frame masking.

if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}

if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
Expand Down Expand Up @@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
Expand All @@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) {
}

func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

Expand Down

0 comments on commit f0643a3

Please sign in to comment.