Skip to content

Commit 6ad088f

Browse files
committed
Expand API
- Closes #1 (Ping API) - Closes #62 (Read/Write convienence methods) - Closes #83 (SetReadLimit)
1 parent 027e6af commit 6ad088f

File tree

6 files changed

+95
-51
lines changed

6 files changed

+95
-51
lines changed

Diff for: example_echo_test.go

-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
9494
// echo reads from the websocket connection and then writes
9595
// the received message back to it.
9696
// The entire function has 10s to complete.
97-
// The received message is limited to 32768 bytes.
9897
func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
9998
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
10099
defer cancel()
@@ -108,7 +107,6 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
108107
if err != nil {
109108
return err
110109
}
111-
r = io.LimitReader(r, 32768)
112110

113111
w, err := c.Writer(ctx, typ)
114112
if err != nil {

Diff for: export_test.go

-18
This file was deleted.

Diff for: websocket.go

+92-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@ import (
55
"context"
66
"fmt"
77
"io"
8+
"io/ioutil"
9+
"math/rand"
810
"os"
911
"runtime"
12+
"strconv"
1013
"sync"
14+
"sync/atomic"
1115
"time"
1216

1317
"golang.org/x/xerrors"
@@ -25,6 +29,8 @@ type Conn struct {
2529
closer io.Closer
2630
client bool
2731

32+
msgReadLimit int64
33+
2834
closeOnce sync.Once
2935
closeErr error
3036
closed chan struct{}
@@ -41,14 +47,16 @@ type Conn struct {
4147
setWriteTimeout chan context.Context
4248
setConnContext chan context.Context
4349
getConnContext chan context.Context
50+
51+
pingListener map[string]chan<- struct{}
4452
}
4553

4654
// Context returns a context derived from parent that will be cancelled
47-
// when the connection is closed.
55+
// when the connection is closed or broken.
4856
// If the parent context is cancelled, the connection will be closed.
4957
//
50-
// This is an experimental API that may be remove in the future.
51-
// Please let me know how you feel about it.
58+
// This is an experimental API that may be removed in the future.
59+
// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79
5260
func (c *Conn) Context(parent context.Context) context.Context {
5361
select {
5462
case <-c.closed:
@@ -105,6 +113,8 @@ func (c *Conn) Subprotocol() string {
105113
func (c *Conn) init() {
106114
c.closed = make(chan struct{})
107115

116+
c.msgReadLimit = 32768
117+
108118
c.writeDataLock = make(chan struct{}, 1)
109119
c.writeFrameLock = make(chan struct{}, 1)
110120

@@ -118,6 +128,8 @@ func (c *Conn) init() {
118128
c.setConnContext = make(chan context.Context)
119129
c.getConnContext = make(chan context.Context)
120130

131+
c.pingListener = make(map[string]chan struct{})
132+
121133
runtime.SetFinalizer(c, func(c *Conn) {
122134
c.close(xerrors.New("connection garbage collected"))
123135
})
@@ -242,6 +254,10 @@ func (c *Conn) handleControl(h header) {
242254
case opPing:
243255
c.writePong(b)
244256
case opPong:
257+
listener, ok := c.pingListener[string(b)]
258+
if ok {
259+
close(listener)
260+
}
245261
case opClose:
246262
ce, err := parseClosePayload(b)
247263
if err != nil {
@@ -321,7 +337,7 @@ func (c *Conn) writePong(p []byte) error {
321337
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
322338
defer cancel()
323339

324-
err := c.writeCompleteMessage(ctx, opPong, p)
340+
err := c.writeMessage(ctx, opPong, p)
325341
return err
326342
}
327343

@@ -369,7 +385,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
369385
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
370386
defer cancel()
371387

372-
err := c.writeCompleteMessage(ctx, opClose, p)
388+
err := c.writeMessage(ctx, opClose, p)
373389

374390
c.close(cerr)
375391

@@ -399,7 +415,7 @@ func (c *Conn) releaseLock(lock chan struct{}) {
399415
<-lock
400416
}
401417

402-
func (c *Conn) writeCompleteMessage(ctx context.Context, opcode opcode, p []byte) error {
418+
func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error {
403419
if !opcode.controlOp() {
404420
err := c.acquireLock(ctx, c.writeDataLock)
405421
if err != nil {
@@ -445,6 +461,30 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
445461
}, nil
446462
}
447463

464+
// Read is a convenience method to read a single message from the connection.
465+
//
466+
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
467+
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
468+
typ, r, err := c.Reader(ctx)
469+
if err != nil {
470+
return 0, nil, err
471+
}
472+
473+
b, err := ioutil.ReadAll(r)
474+
if err != nil {
475+
return typ, b, err
476+
}
477+
478+
return typ, b, nil
479+
}
480+
481+
// Write is a convenience method to write a message to the connection.
482+
//
483+
// See the Writer method if you want to stream a message.
484+
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
485+
return c.writeMessage(ctx, opcode(typ), p)
486+
}
487+
448488
// messageWriter enables writing to a WebSocket connection.
449489
type messageWriter struct {
450490
ctx context.Context
@@ -519,7 +559,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
519559
if err != nil {
520560
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
521561
}
522-
return typ, r, nil
562+
return typ, io.LimitReader(r, c.msgReadLimit), nil
523563
}
524564

525565
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
@@ -640,3 +680,48 @@ func (r *messageReader) read(p []byte) (int, error) {
640680

641681
return n, nil
642682
}
683+
684+
// SetReadLimit sets the max number of bytes to read for a single message.
685+
// It applies to the Reader and Read methods.
686+
//
687+
// By default, the connection has a message read limit of 32768 bytes.
688+
func (c *Conn) SetReadLimit(n int64) {
689+
atomic.StoreInt64(&c.msgReadLimit, n)
690+
}
691+
692+
func init() {
693+
rand.Seed(time.Now().UnixNano())
694+
}
695+
696+
// Ping sends a ping to the peer and waits for a pong.
697+
// Use this to measure latency or ensure the peer is responsive.
698+
//
699+
// This API is experimental and subject to change.
700+
// Please provide feedback in https://github.com/nhooyr/websocket/issues/1.
701+
func (c *Conn) Ping(ctx context.Context) error {
702+
err := c.ping(ctx)
703+
if err != nil {
704+
return xerrors.Errorf("failed to ping: %w", err)
705+
}
706+
return nil
707+
}
708+
709+
func (c *Conn) ping(ctx context.Context) error {
710+
id := rand.Uint64()
711+
p := strconv.FormatUint(id, 10)
712+
713+
pong := make(chan struct{})
714+
c.pingListener[p] = pong
715+
716+
err := c.writeMessage(ctx, opPing, []byte(p))
717+
if err != nil {
718+
return err
719+
}
720+
721+
select {
722+
case <-ctx.Done():
723+
return ctx.Err()
724+
case <-pong:
725+
return nil
726+
}
727+
}

Diff for: websocket_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ func TestAutobahnServer(t *testing.T) {
489489
func echoLoop(ctx context.Context, c *websocket.Conn) {
490490
defer c.Close(websocket.StatusInternalError, "")
491491

492+
c.SetReadLimit(1 << 30)
493+
492494
ctx, cancel := context.WithTimeout(ctx, time.Minute)
493495
defer cancel()
494496

Diff for: wsjson/wsjson.go

-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
)
1313

1414
// Read reads a json message from c into v.
15-
// For security reasons, it will not read messages
16-
// larger than 32768 bytes.
1715
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
1816
err := read(ctx, c, v)
1917
if err != nil {
@@ -33,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
3331
return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
3432
}
3533

36-
r = io.LimitReader(r, 32768)
37-
3834
d := json.NewDecoder(r)
3935
err = d.Decode(v)
4036
if err != nil {

Diff for: wspb/wspb.go

+1-20
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package wspb
33

44
import (
55
"context"
6-
"io"
76
"io/ioutil"
87

98
"github.com/golang/protobuf/proto"
@@ -13,8 +12,6 @@ import (
1312
)
1413

1514
// Read reads a protobuf message from c into v.
16-
// For security reasons, it will not read messages
17-
// larger than 32768 bytes.
1815
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
1916
err := read(ctx, c, v)
2017
if err != nil {
@@ -34,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
3431
return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
3532
}
3633

37-
r = io.LimitReader(r, 32768)
38-
3934
b, err := ioutil.ReadAll(r)
4035
if err != nil {
4136
return xerrors.Errorf("failed to read message: %w", err)
@@ -64,19 +59,5 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
6459
return xerrors.Errorf("failed to marshal protobuf: %w", err)
6560
}
6661

67-
w, err := c.Writer(ctx, websocket.MessageBinary)
68-
if err != nil {
69-
return err
70-
}
71-
72-
_, err = w.Write(b)
73-
if err != nil {
74-
return err
75-
}
76-
77-
err = w.Close()
78-
if err != nil {
79-
return err
80-
}
81-
return nil
62+
return c.Write(ctx, websocket.MessageBinary, b)
8263
}

0 commit comments

Comments
 (0)