Skip to content

Commit b39ca87

Browse files
committed
Fix bugs and improve docs
1 parent 4b724ae commit b39ca87

10 files changed

+128
-70
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ go get nhooyr.io/websocket@v0.2.0
2222
- Zero dependencies outside of the stdlib for the core library
2323
- JSON and ProtoBuf helpers in the wsjson and wspb subpackages
2424
- High performance
25-
- Concurrent writes
25+
- Concurrent reads and writes out of the box
2626

2727
## Roadmap
2828

@@ -122,8 +122,8 @@ also uses net/http's Client and ResponseWriter directly for WebSocket handshakes
122122
gorilla/websocket writes its handshakes to the underlying net.Conn which means
123123
it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2.
124124

125-
Some more advantages of nhooyr/websocket are that it supports concurrent writes and makes it
126-
very easy to close the connection with a status code and reason.
125+
Some more advantages of nhooyr/websocket are that it supports concurrent reads,
126+
writes and makes it very easy to close the connection with a status code and reason.
127127

128128
In terms of performance, there is no significant difference between the two. Will update
129129
with benchmarks soon ([#75](https://github.com/nhooyr/websocket/issues/75)).

accept.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package websocket
22

33
import (
4+
"bytes"
45
"crypto/sha1"
56
"encoding/base64"
7+
"io"
68
"net/http"
79
"net/textproto"
810
"net/url"
@@ -78,6 +80,9 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
7880
//
7981
// Accept will reject the handshake if the Origin domain is not the same as the Host unless
8082
// the InsecureSkipVerify option is set.
83+
//
84+
// The returned connection will be bound by r.Context(). Use c.Context() to change
85+
// the bounding context.
8186
func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) {
8287
c, err := accept(w, r, opts)
8388
if err != nil {
@@ -126,14 +131,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
126131
return nil, err
127132
}
128133

134+
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
135+
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
136+
129137
c := &Conn{
130138
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
131139
br: brw.Reader,
132140
bw: brw.Writer,
133141
closer: netConn,
134142
}
135143
c.init()
136-
// TODO document.
137144
c.Context(r.Context())
138145

139146
return c, nil

dial.go

+39-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
"bytes"
66
"context"
77
"encoding/base64"
8+
"golang.org/x/xerrors"
89
"io"
910
"io/ioutil"
1011
"net/http"
1112
"net/url"
1213
"strings"
13-
14-
"golang.org/x/xerrors"
14+
"sync"
1515
)
1616

1717
// DialOptions represents the options available to pass to Dial.
@@ -112,8 +112,8 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
112112

113113
c := &Conn{
114114
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
115-
br: bufio.NewReader(rwc),
116-
bw: bufio.NewWriter(rwc),
115+
br: getBufioReader(rwc),
116+
bw: getBufioWriter(rwc),
117117
closer: rwc,
118118
client: true,
119119
}
@@ -140,3 +140,38 @@ func verifyServerResponse(resp *http.Response) error {
140140

141141
return nil
142142
}
143+
144+
// The below pools can only be used by the client because http.Hijacker will always
145+
// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top.
146+
147+
var bufioReaderPool = sync.Pool{
148+
New: func() interface{} {
149+
return bufio.NewReader(nil)
150+
},
151+
}
152+
153+
func getBufioReader(r io.Reader) *bufio.Reader {
154+
br := bufioReaderPool.Get().(*bufio.Reader)
155+
br.Reset(r)
156+
return br
157+
}
158+
159+
func returnBufioReader(br *bufio.Reader) {
160+
bufioReaderPool.Put(br)
161+
}
162+
163+
var bufioWriterPool = sync.Pool{
164+
New: func() interface{} {
165+
return bufio.NewWriter(nil)
166+
},
167+
}
168+
169+
func getBufioWriter(w io.Writer) *bufio.Writer {
170+
bw := bufioWriterPool.Get().(*bufio.Writer)
171+
bw.Reset(w)
172+
return bw
173+
}
174+
175+
func returnBufioWriter(bw *bufio.Writer) {
176+
bufioWriterPool.Put(bw)
177+
}

example_echo_test.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func Example_echo() {
5151

5252
// Now we dial the server, send the messages and echo the responses.
5353
err = client("ws://" + l.Addr().String())
54+
time.Sleep(time.Second)
5455
if err != nil {
5556
log.Fatalf("client failed: %v", err)
5657
}
@@ -66,6 +67,8 @@ func Example_echo() {
6667
// It ensures the client speaks the echo subprotocol and
6768
// only allows one message every 100ms with a 10 message burst.
6869
func echoServer(w http.ResponseWriter, r *http.Request) error {
70+
log.Printf("serving %v", r.RemoteAddr)
71+
6972
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
7073
Subprotocols: []string{"echo"},
7174
})
@@ -83,7 +86,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
8386
for {
8487
err = echo(r.Context(), c, l)
8588
if err != nil {
86-
return xerrors.Errorf("failed to echo: %w", err)
89+
return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err)
8790
}
8891
}
8992
}

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ require (
1212
golang.org/x/text v0.3.2 // indirect
1313
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
1414
golang.org/x/tools v0.0.0-20190429184909-35c670923e21
15-
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18
15+
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
1616
mvdan.cc/sh v2.6.4+incompatible
1717
)

go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,7 @@ golang.org/x/tools v0.0.0-20190429184909-35c670923e21 h1:Kjcw+D2LTzLmxOHrMK9uvYP
3030
golang.org/x/tools v0.0.0-20190429184909-35c670923e21/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
3131
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 h1:1AGvnywFL1aB5KLRxyLseWJI6aSYPo3oF7HSpXdWQdU=
3232
golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
33+
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
34+
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
3335
mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM=
3436
mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8=

statuscode.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type CloseError struct {
4949
}
5050

5151
func (ce CloseError) Error() string {
52-
return fmt.Sprintf("websocket closed with status = %v and reason = %q", ce.Code, ce.Reason)
52+
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
5353
}
5454

5555
func parseClosePayload(p []byte) (CloseError, error) {

websocket.go

+53-48
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ import (
1414
)
1515

1616
// Conn represents a WebSocket connection.
17-
// All methods except Reader can be used concurrently.
17+
// All methods may be called concurrently.
18+
//
1819
// Please be sure to call Close on the connection when you
1920
// are finished with it to release resources.
2021
type Conn struct {
@@ -31,8 +32,10 @@ type Conn struct {
3132
writeDataLock chan struct{}
3233
writeFrameLock chan struct{}
3334

34-
readData chan header
35-
readDone chan struct{}
35+
readDataLock chan struct{}
36+
readData chan header
37+
readDone chan struct{}
38+
readLoopDone chan struct{}
3639

3740
setReadTimeout chan context.Context
3841
setWriteTimeout chan context.Context
@@ -44,7 +47,7 @@ type Conn struct {
4447
// when the connection is closed.
4548
// If the parent context is cancelled, the connection will be closed.
4649
//
47-
// This is an experimental API meaning it may be remove in the future.
50+
// This is an experimental API that may be remove in the future.
4851
// Please let me know how you feel about it.
4952
func (c *Conn) Context(parent context.Context) context.Context {
5053
select {
@@ -77,6 +80,18 @@ func (c *Conn) close(err error) {
7780
c.closeErr = xerrors.Errorf("websocket closed: %w", cerr)
7881

7982
close(c.closed)
83+
84+
// See comment in dial.go
85+
if c.client {
86+
go func() {
87+
<-c.readLoopDone
88+
c.readDataLock <- struct{}{}
89+
c.writeFrameLock <- struct{}{}
90+
91+
returnBufioReader(c.br)
92+
returnBufioWriter(c.bw)
93+
}()
94+
}
8095
})
8196
}
8297

@@ -94,6 +109,8 @@ func (c *Conn) init() {
94109

95110
c.readData = make(chan header)
96111
c.readDone = make(chan struct{})
112+
c.readDataLock = make(chan struct{}, 1)
113+
c.readLoopDone = make(chan struct{})
97114

98115
c.setReadTimeout = make(chan context.Context)
99116
c.setWriteTimeout = make(chan context.Context)
@@ -174,8 +191,8 @@ func (c *Conn) timeoutLoop() {
174191
select {
175192
case <-c.closed:
176193
return
177-
case readCtx = <-c.setWriteTimeout:
178-
case writeCtx = <-c.setReadTimeout:
194+
case writeCtx = <-c.setWriteTimeout:
195+
case readCtx = <-c.setReadTimeout:
179196
case <-readCtx.Done():
180197
c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
181198
case <-writeCtx.Done():
@@ -276,6 +293,8 @@ func (c *Conn) readTillData() (header, error) {
276293
}
277294

278295
func (c *Conn) readLoop() {
296+
defer close(c.readLoopDone)
297+
279298
for {
280299
h, err := c.readTillData()
281300
if err != nil {
@@ -487,8 +506,7 @@ func (w *messageWriter) close() error {
487506
//
488507
// Your application must keep reading messages for the Conn to automatically respond to ping
489508
// and close frames and not become stuck waiting for a data message to be read.
490-
// Please ensure to read the full message from io.Reader. If you do not read till
491-
// io.EOF, the connection will break unless the next read would have yielded io.EOF.
509+
// Please ensure to read the full message from io.Reader.
492510
//
493511
// You can only read a single message at a time so do not call this method
494512
// concurrently.
@@ -500,30 +518,10 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
500518
return typ, r, nil
501519
}
502520

503-
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
504-
// if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
505-
// // If the next read yields io.EOF we are good to go.
506-
// r := messageReader{
507-
// ctx: ctx,
508-
// c: c,
509-
// }
510-
// _, err := r.Read(nil)
511-
// if err == nil {
512-
// return 0, nil, xerrors.New("previous message not fully read")
513-
// }
514-
// if !xerrors.Is(err, io.EOF) {
515-
// return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
516-
// }
517-
//
518-
// atomic.StoreInt64(&c.activeReader, 1)
519-
// }
520-
521-
select {
522-
case <-c.closed:
523-
return 0, nil, c.closeErr
524-
case <-ctx.Done():
525-
return 0, nil, ctx.Err()
526-
case c.setReadTimeout <- ctx:
521+
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
522+
err = c.acquireLock(ctx, c.readDataLock)
523+
if err != nil {
524+
return 0, nil, err
527525
}
528526

529527
select {
@@ -533,25 +531,24 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
533531
return 0, nil, ctx.Err()
534532
case h := <-c.readData:
535533
if h.opcode == opContinuation {
536-
if h.fin && h.payloadLength == 0 {
537-
select {
538-
case <-c.closed:
539-
return 0, nil, c.closeErr
540-
case c.readDone <- struct{}{}:
541-
return c.reader(ctx)
542-
}
534+
ce := CloseError{
535+
Code: StatusProtocolError,
536+
Reason: "continuation frame not after data or text frame",
543537
}
544-
return 0, nil, xerrors.Errorf("previous reader was not read to EOF")
538+
c.Close(ce.Code, ce.Reason)
539+
return 0, nil, ce
545540
}
546541
return MessageType(h.opcode), &messageReader{
547-
h: &h,
548-
c: c,
542+
ctx: ctx,
543+
h: &h,
544+
c: c,
549545
}, nil
550546
}
551547
}
552548

553549
// messageReader enables reading a data frame from the WebSocket connection.
554550
type messageReader struct {
551+
ctx context.Context
555552
maskPos int
556553
h *header
557554
c *Conn
@@ -598,8 +595,20 @@ func (r *messageReader) read(p []byte) (int, error) {
598595
p = p[:r.h.payloadLength]
599596
}
600597

598+
select {
599+
case <-r.c.closed:
600+
return 0, r.c.closeErr
601+
case r.c.setReadTimeout <- r.ctx:
602+
}
603+
601604
n, err := io.ReadFull(r.c.br, p)
602605

606+
select {
607+
case <-r.c.closed:
608+
return 0, r.c.closeErr
609+
case r.c.setReadTimeout <- context.Background():
610+
}
611+
603612
r.h.payloadLength -= int64(n)
604613
if r.h.masked {
605614
r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p)
@@ -618,12 +627,8 @@ func (r *messageReader) read(p []byte) (int, error) {
618627
}
619628
if r.h.fin {
620629
r.eofed = true
621-
select {
622-
case <-r.c.closed:
623-
return n, r.c.closeErr
624-
case r.c.setReadTimeout <- context.Background():
625-
return n, io.EOF
626-
}
630+
r.c.releaseLock(r.c.readDataLock)
631+
return n, io.EOF
627632
}
628633
r.maskPos = 0
629634
r.h = nil

0 commit comments

Comments
 (0)