Skip to content

Commit e3dc7a3

Browse files
committed
Add WebSocket masking and correctly use Sec-WebSocket-Key in client
Closes #88
1 parent 4130a30 commit e3dc7a3

File tree

6 files changed

+127
-39
lines changed

6 files changed

+127
-39
lines changed

accept.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,15 @@ var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
165165

166166
func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) {
167167
key := r.Header.Get("Sec-WebSocket-Key")
168+
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
169+
}
170+
171+
func secWebSocketAccept(secWebSocketKey string) string {
168172
h := sha1.New()
169-
h.Write([]byte(key))
173+
h.Write([]byte(secWebSocketKey))
170174
h.Write(keyGUID)
171175

172-
responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
173-
w.Header().Set("Sec-WebSocket-Accept", responseKey)
176+
return base64.StdEncoding.EncodeToString(h.Sum(nil))
174177
}
175178

176179
func authenticateOrigin(r *http.Request) error {

dial.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/base64"
88
"io"
99
"io/ioutil"
10+
"math/rand"
1011
"net/http"
1112
"net/url"
1213
"strings"
@@ -30,11 +31,6 @@ type DialOptions struct {
3031
Subprotocols []string
3132
}
3233

33-
// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything.
34-
// See https://stackoverflow.com/a/37074398/4283659.
35-
// We also use the same mask key for every message as it too does not make a difference.
36-
var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16))
37-
3834
// Dial performs a WebSocket handshake on the given url with the given options.
3935
// The response is the WebSocket handshake response from the server.
4036
// If an error occurs, the returned response may be non nil. However, you can only
@@ -82,7 +78,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
8278
req.Header.Set("Connection", "Upgrade")
8379
req.Header.Set("Upgrade", "websocket")
8480
req.Header.Set("Sec-WebSocket-Version", "13")
85-
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
81+
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
8682
if len(opts.Subprotocols) > 0 {
8783
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
8884
}
@@ -118,12 +114,13 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
118114
closer: rwc,
119115
client: true,
120116
}
117+
c.extractBufioWriterBuf(rwc)
121118
c.init()
122119

123120
return c, resp, nil
124121
}
125122

126-
func verifyServerResponse(resp *http.Response) error {
123+
func verifyServerResponse(r *http.Request, resp *http.Response) error {
127124
if resp.StatusCode != http.StatusSwitchingProtocols {
128125
return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
129126
}
@@ -136,8 +133,9 @@ func verifyServerResponse(resp *http.Response) error {
136133
return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
137134
}
138135

139-
// We do not care about Sec-WebSocket-Accept because it does not matter.
140-
// See the secWebSocketKey global variable.
136+
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
137+
return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept", resp.Header.Get("Sec-WebSocket-Accept"))
138+
}
141139

142140
return nil
143141
}
@@ -176,3 +174,9 @@ func getBufioWriter(w io.Writer) *bufio.Writer {
176174
func returnBufioWriter(bw *bufio.Writer) {
177175
bufioWriterPool.Put(bw)
178176
}
177+
178+
func makeSecWebSocketKey() string {
179+
b := make([]byte, 16)
180+
rand.Read(b)
181+
return base64.StdEncoding.EncodeToString(b)
182+
}

dial_test.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ func Test_verifyServerHandshake(t *testing.T) {
3838
},
3939
success: false,
4040
},
41+
{
42+
name: "badSecWebSocketAccept",
43+
response: func(w http.ResponseWriter) {
44+
w.Header().Set("Connection", "Upgrade")
45+
w.Header().Set("Sec-WebSocket-Accept", "xd")
46+
w.WriteHeader(http.StatusSwitchingProtocols)
47+
},
48+
success: false,
49+
},
4150
{
4251
name: "success",
4352
response: func(w http.ResponseWriter) {
@@ -58,7 +67,15 @@ func Test_verifyServerHandshake(t *testing.T) {
5867
tc.response(w)
5968
resp := w.Result()
6069

61-
err := verifyServerResponse(resp)
70+
r := httptest.NewRequest("GET", "", nil)
71+
key := makeSecWebSocketKey()
72+
r.Header.Set("Sec-WebSocket-Key", key)
73+
74+
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
75+
r.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
76+
}
77+
78+
err := verifyServerResponse(r, resp)
6279
if (err == nil) != tc.success {
6380
t.Fatalf("unexpected error: %+v", err)
6481
}

export_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package websocket
2+
3+
var Compute = handleSecWebSocketKey

websocket.go

+87-25
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package websocket
33
import (
44
"bufio"
55
"context"
6+
cryptorand "crypto/rand"
67
"fmt"
78
"io"
89
"io/ioutil"
@@ -26,8 +27,11 @@ type Conn struct {
2627
subprotocol string
2728
br *bufio.Reader
2829
bw *bufio.Writer
29-
closer io.Closer
30-
client bool
30+
// writeBuf is used for masking, its the buffer in bufio.Writer.
31+
// Only used by the client.
32+
writeBuf []byte
33+
closer io.Closer
34+
client bool
3135

3236
// read limit for a message in bytes.
3337
msgReadLimit int64
@@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
581585
// See the Writer method if you want to stream a message. The docs on Writer
582586
// regarding concurrency also apply to this method.
583587
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
584-
err := c.write(ctx, typ, p)
588+
_, err := c.write(ctx, typ, p)
585589
if err != nil {
586590
return xerrors.Errorf("failed to write msg: %w", err)
587591
}
588592
return nil
589593
}
590594

591-
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
595+
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
592596
err := c.acquireLock(ctx, c.writeMsgLock)
593597
if err != nil {
594-
return err
598+
return 0, err
595599
}
596600
defer c.releaseLock(c.writeMsgLock)
597601

598-
err = c.writeFrame(ctx, true, opcode(typ), p)
599-
return err
602+
n, err := c.writeFrame(ctx, true, opcode(typ), p)
603+
return n, err
600604
}
601605

602606
// messageWriter enables writing to a WebSocket connection.
@@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) {
620624
if w.closed {
621625
return 0, xerrors.Errorf("cannot use closed writer")
622626
}
623-
err := w.c.writeFrame(w.ctx, false, w.opcode, p)
627+
n, err := w.c.writeFrame(w.ctx, false, w.opcode, p)
624628
if err != nil {
625-
return 0, xerrors.Errorf("failed to write data frame: %w", err)
629+
return n, xerrors.Errorf("failed to write data frame: %w", err)
626630
}
627631
w.opcode = opContinuation
628-
return len(p), nil
632+
return n, nil
629633
}
630634

631635
// Close flushes the frame to the connection.
@@ -644,7 +648,7 @@ func (w *messageWriter) close() error {
644648
}
645649
w.closed = true
646650

647-
err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
651+
_, err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
648652
if err != nil {
649653
return xerrors.Errorf("failed to write fin frame: %w", err)
650654
}
@@ -654,34 +658,40 @@ func (w *messageWriter) close() error {
654658
}
655659

656660
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
657-
err := c.writeFrame(ctx, true, opcode, p)
661+
_, err := c.writeFrame(ctx, true, opcode, p)
658662
if err != nil {
659663
return xerrors.Errorf("failed to write control frame: %w", err)
660664
}
661665
return nil
662666
}
663667

664668
// writeFrame handles all writes to the connection.
665-
// We never mask inside here because our mask key is always 0,0,0,0.
666-
// See comment on secWebSocketKey for why.
667-
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error {
669+
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
668670
h := header{
669671
fin: fin,
670672
opcode: opcode,
671673
masked: c.client,
672674
payloadLength: int64(len(p)),
673675
}
676+
677+
if c.client {
678+
_, err := io.ReadFull(cryptorand.Reader, h.maskKey[:])
679+
if err != nil {
680+
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
681+
}
682+
}
683+
674684
b2 := marshalHeader(h)
675685

676686
err := c.acquireLock(ctx, c.writeFrameLock)
677687
if err != nil {
678-
return err
688+
return 0, err
679689
}
680690
defer c.releaseLock(c.writeFrameLock)
681691

682692
select {
683693
case <-c.closed:
684-
return c.closeErr
694+
return 0, c.closeErr
685695
case c.setWriteTimeout <- ctx:
686696
}
687697

@@ -705,29 +715,61 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
705715

706716
_, err = c.bw.Write(b2)
707717
if err != nil {
708-
return writeErr(err)
709-
}
710-
_, err = c.bw.Write(p)
711-
if err != nil {
712-
return writeErr(err)
718+
return 0, writeErr(err)
719+
}
720+
721+
var n int
722+
if c.client {
723+
var keypos int
724+
for len(p) > 0 {
725+
if c.bw.Available() == 0 {
726+
err = c.bw.Flush()
727+
if err != nil {
728+
return n, writeErr(err)
729+
}
730+
}
731+
732+
// Start of next write in the buffer.
733+
i := c.bw.Buffered()
734+
735+
p2 := p
736+
if len(p) > c.bw.Available() {
737+
p2 = p[:c.bw.Available()]
738+
}
739+
740+
n2, err := c.bw.Write(p2)
741+
if err != nil {
742+
return n, writeErr(err)
743+
}
744+
745+
keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
746+
747+
p = p[n2:]
748+
n += n2
749+
}
750+
} else {
751+
n, err = c.bw.Write(p)
752+
if err != nil {
753+
return n, writeErr(err)
754+
}
713755
}
714756

715757
if fin {
716758
err = c.bw.Flush()
717759
if err != nil {
718-
return writeErr(err)
760+
return n, writeErr(err)
719761
}
720762
}
721763

722764
// We already finished writing, no need to potentially brick the connection if
723765
// the context expires.
724766
select {
725767
case <-c.closed:
726-
return c.closeErr
768+
return n, c.closeErr
727769
case c.setWriteTimeout <- context.Background():
728770
}
729771

730-
return nil
772+
return n, nil
731773
}
732774

733775
func (c *Conn) writePong(p []byte) error {
@@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error {
842884
return nil
843885
}
844886
}
887+
888+
type writerFunc func(p []byte) (int, error)
889+
890+
func (f writerFunc) Write(p []byte) (int, error) {
891+
return f(p)
892+
}
893+
894+
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
895+
// and stores it in c.writeBuf.
896+
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
897+
c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
898+
c.writeBuf = p2[:cap(p2)]
899+
return len(p2), nil
900+
}))
901+
902+
c.bw.WriteByte(0)
903+
c.bw.Flush()
904+
905+
c.bw.Reset(w)
906+
}

websocket_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ func TestHandshake(t *testing.T) {
6868

6969
checkHeader("Connection", "Upgrade")
7070
checkHeader("Upgrade", "websocket")
71-
checkHeader("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
7271
checkHeader("Sec-WebSocket-Protocol", "myproto")
7372

7473
c.Close(websocket.StatusNormalClosure, "")

0 commit comments

Comments
 (0)