Skip to content

Add WebSocket masking and correctly use Sec-WebSocket-Key in client #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) {
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
}

func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(key))
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)

responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
w.Header().Set("Sec-WebSocket-Accept", responseKey)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

func authenticateOrigin(r *http.Request) error {
Expand Down
27 changes: 17 additions & 10 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"strings"
Expand All @@ -30,11 +31,6 @@ type DialOptions struct {
Subprotocols []string
}

// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything.
// See https://stackoverflow.com/a/37074398/4283659.
// We also use the same mask key for every message as it too does not make a difference.
var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16))

// Dial performs a WebSocket handshake on the given url with the given options.
// The response is the WebSocket handshake response from the server.
// If an error occurs, the returned response may be non nil. However, you can only
Expand Down Expand Up @@ -82,7 +78,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
Expand All @@ -101,7 +97,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
}
}()

err = verifyServerResponse(resp)
err = verifyServerResponse(req, resp)
if err != nil {
return nil, resp, err
}
Expand All @@ -118,12 +114,13 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
closer: rwc,
client: true,
}
c.extractBufioWriterBuf(rwc)
c.init()

return c, resp, nil
}

func verifyServerResponse(resp *http.Response) error {
func verifyServerResponse(r *http.Request, resp *http.Response) error {
if resp.StatusCode != http.StatusSwitchingProtocols {
return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand All @@ -136,8 +133,12 @@ func verifyServerResponse(resp *http.Response) error {
return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}

// We do not care about Sec-WebSocket-Accept because it does not matter.
// See the secWebSocketKey global variable.
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
r.Header.Get("Sec-WebSocket-Key"),
)
}

return nil
}
Expand Down Expand Up @@ -176,3 +177,9 @@ func getBufioWriter(w io.Writer) *bufio.Writer {
func returnBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}

func makeSecWebSocketKey() string {
b := make([]byte, 16)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
20 changes: 19 additions & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ func Test_verifyServerHandshake(t *testing.T) {
},
success: false,
},
{
name: "badSecWebSocketAccept",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Accept", "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "success",
response: func(w http.ResponseWriter) {
Expand All @@ -58,7 +68,15 @@ func Test_verifyServerHandshake(t *testing.T) {
tc.response(w)
resp := w.Result()

err := verifyServerResponse(resp)
r := httptest.NewRequest("GET", "/", nil)
key := makeSecWebSocketKey()
r.Header.Set("Sec-WebSocket-Key", key)

if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
}

err := verifyServerResponse(r, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err)
}
Expand Down
3 changes: 3 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package websocket

var Compute = handleSecWebSocketKey
112 changes: 87 additions & 25 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocket
import (
"bufio"
"context"
cryptorand "crypto/rand"
"fmt"
"io"
"io/ioutil"
Expand All @@ -26,8 +27,11 @@ type Conn struct {
subprotocol string
br *bufio.Reader
bw *bufio.Writer
closer io.Closer
client bool
// writeBuf is used for masking, its the buffer in bufio.Writer.
// Only used by the client.
writeBuf []byte
closer io.Closer
client bool

// read limit for a message in bytes.
msgReadLimit int64
Expand Down Expand Up @@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
// See the Writer method if you want to stream a message. The docs on Writer
// regarding concurrency also apply to this method.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
_, err := c.write(ctx, typ, p)
if err != nil {
return xerrors.Errorf("failed to write msg: %w", err)
}
return nil
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
err := c.acquireLock(ctx, c.writeMsgLock)
if err != nil {
return err
return 0, err
}
defer c.releaseLock(c.writeMsgLock)

err = c.writeFrame(ctx, true, opcode(typ), p)
return err
n, err := c.writeFrame(ctx, true, opcode(typ), p)
return n, err
}

// messageWriter enables writing to a WebSocket connection.
Expand All @@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) {
if w.closed {
return 0, xerrors.Errorf("cannot use closed writer")
}
err := w.c.writeFrame(w.ctx, false, w.opcode, p)
n, err := w.c.writeFrame(w.ctx, false, w.opcode, p)
if err != nil {
return 0, xerrors.Errorf("failed to write data frame: %w", err)
return n, xerrors.Errorf("failed to write data frame: %w", err)
}
w.opcode = opContinuation
return len(p), nil
return n, nil
}

// Close flushes the frame to the connection.
Expand All @@ -644,7 +648,7 @@ func (w *messageWriter) close() error {
}
w.closed = true

err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
_, err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
if err != nil {
return xerrors.Errorf("failed to write fin frame: %w", err)
}
Expand All @@ -654,34 +658,40 @@ func (w *messageWriter) close() error {
}

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
err := c.writeFrame(ctx, true, opcode, p)
_, err := c.writeFrame(ctx, true, opcode, p)
if err != nil {
return xerrors.Errorf("failed to write control frame: %w", err)
}
return nil
}

// writeFrame handles all writes to the connection.
// We never mask inside here because our mask key is always 0,0,0,0.
// See comment on secWebSocketKey for why.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error {
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
h := header{
fin: fin,
opcode: opcode,
masked: c.client,
payloadLength: int64(len(p)),
}

if c.client {
_, err := io.ReadFull(cryptorand.Reader, h.maskKey[:])
if err != nil {
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
}
}

b2 := marshalHeader(h)

err := c.acquireLock(ctx, c.writeFrameLock)
if err != nil {
return err
return 0, err
}
defer c.releaseLock(c.writeFrameLock)

select {
case <-c.closed:
return c.closeErr
return 0, c.closeErr
case c.setWriteTimeout <- ctx:
}

Expand All @@ -705,29 +715,61 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte

_, err = c.bw.Write(b2)
if err != nil {
return writeErr(err)
}
_, err = c.bw.Write(p)
if err != nil {
return writeErr(err)
return 0, writeErr(err)
}

var n int
if c.client {
var keypos int
for len(p) > 0 {
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, writeErr(err)
}
}

// Start of next write in the buffer.
i := c.bw.Buffered()

p2 := p
if len(p) > c.bw.Available() {
p2 = p[:c.bw.Available()]
}

n2, err := c.bw.Write(p2)
if err != nil {
return n, writeErr(err)
}

keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])

p = p[n2:]
n += n2
}
} else {
n, err = c.bw.Write(p)
if err != nil {
return n, writeErr(err)
}
}

if fin {
err = c.bw.Flush()
if err != nil {
return writeErr(err)
return n, writeErr(err)
}
}

// We already finished writing, no need to potentially brick the connection if
// the context expires.
select {
case <-c.closed:
return c.closeErr
return n, c.closeErr
case c.setWriteTimeout <- context.Background():
}

return nil
return n, nil
}

func (c *Conn) writePong(p []byte) error {
Expand Down Expand Up @@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error {
return nil
}
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}

// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and stores it in c.writeBuf.
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
c.writeBuf = p2[:cap(p2)]
return len(p2), nil
}))

c.bw.WriteByte(0)
c.bw.Flush()

c.bw.Reset(w)
}
1 change: 0 additions & 1 deletion websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func TestHandshake(t *testing.T) {

checkHeader("Connection", "Upgrade")
checkHeader("Upgrade", "websocket")
checkHeader("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
checkHeader("Sec-WebSocket-Protocol", "myproto")

c.Close(websocket.StatusNormalClosure, "")
Expand Down