@@ -5,9 +5,13 @@ import (
5
5
"context"
6
6
"fmt"
7
7
"io"
8
+ "io/ioutil"
9
+ "math/rand"
8
10
"os"
9
11
"runtime"
12
+ "strconv"
10
13
"sync"
14
+ "sync/atomic"
11
15
"time"
12
16
13
17
"golang.org/x/xerrors"
@@ -25,6 +29,8 @@ type Conn struct {
25
29
closer io.Closer
26
30
client bool
27
31
32
+ msgReadLimit int64
33
+
28
34
closeOnce sync.Once
29
35
closeErr error
30
36
closed chan struct {}
@@ -41,14 +47,16 @@ type Conn struct {
41
47
setWriteTimeout chan context.Context
42
48
setConnContext chan context.Context
43
49
getConnContext chan context.Context
50
+
51
+ pingListener map [string ]chan <- struct {}
44
52
}
45
53
46
54
// Context returns a context derived from parent that will be cancelled
47
- // when the connection is closed.
55
+ // when the connection is closed or broken .
48
56
// If the parent context is cancelled, the connection will be closed.
49
57
//
50
- // This is an experimental API that may be remove in the future.
51
- // Please let me know how you feel about it.
58
+ // This is an experimental API that may be removed in the future.
59
+ // Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79
52
60
func (c * Conn ) Context (parent context.Context ) context.Context {
53
61
select {
54
62
case <- c .closed :
@@ -105,6 +113,8 @@ func (c *Conn) Subprotocol() string {
105
113
func (c * Conn ) init () {
106
114
c .closed = make (chan struct {})
107
115
116
+ c .msgReadLimit = 32768
117
+
108
118
c .writeDataLock = make (chan struct {}, 1 )
109
119
c .writeFrameLock = make (chan struct {}, 1 )
110
120
@@ -118,6 +128,8 @@ func (c *Conn) init() {
118
128
c .setConnContext = make (chan context.Context )
119
129
c .getConnContext = make (chan context.Context )
120
130
131
+ c .pingListener = make (map [string ]chan struct {})
132
+
121
133
runtime .SetFinalizer (c , func (c * Conn ) {
122
134
c .close (xerrors .New ("connection garbage collected" ))
123
135
})
@@ -242,6 +254,10 @@ func (c *Conn) handleControl(h header) {
242
254
case opPing :
243
255
c .writePong (b )
244
256
case opPong :
257
+ listener , ok := c .pingListener [string (b )]
258
+ if ok {
259
+ close (listener )
260
+ }
245
261
case opClose :
246
262
ce , err := parseClosePayload (b )
247
263
if err != nil {
@@ -321,7 +337,7 @@ func (c *Conn) writePong(p []byte) error {
321
337
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
322
338
defer cancel ()
323
339
324
- err := c .writeCompleteMessage (ctx , opPong , p )
340
+ err := c .writeMessage (ctx , opPong , p )
325
341
return err
326
342
}
327
343
@@ -369,7 +385,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
369
385
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
370
386
defer cancel ()
371
387
372
- err := c .writeCompleteMessage (ctx , opClose , p )
388
+ err := c .writeMessage (ctx , opClose , p )
373
389
374
390
c .close (cerr )
375
391
@@ -399,7 +415,7 @@ func (c *Conn) releaseLock(lock chan struct{}) {
399
415
<- lock
400
416
}
401
417
402
- func (c * Conn ) writeCompleteMessage (ctx context.Context , opcode opcode , p []byte ) error {
418
+ func (c * Conn ) writeMessage (ctx context.Context , opcode opcode , p []byte ) error {
403
419
if ! opcode .controlOp () {
404
420
err := c .acquireLock (ctx , c .writeDataLock )
405
421
if err != nil {
@@ -445,6 +461,30 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
445
461
}, nil
446
462
}
447
463
464
+ // Read is a convenience method to read a single message from the connection.
465
+ //
466
+ // See the Reader method if you want to be able to reuse buffers or want to stream a message.
467
+ func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
468
+ typ , r , err := c .Reader (ctx )
469
+ if err != nil {
470
+ return 0 , nil , err
471
+ }
472
+
473
+ b , err := ioutil .ReadAll (r )
474
+ if err != nil {
475
+ return typ , b , err
476
+ }
477
+
478
+ return typ , b , nil
479
+ }
480
+
481
+ // Write is a convenience method to write a message to the connection.
482
+ //
483
+ // See the Writer method if you want to stream a message.
484
+ func (c * Conn ) Write (ctx context.Context , typ MessageType , p []byte ) error {
485
+ return c .writeMessage (ctx , opcode (typ ), p )
486
+ }
487
+
448
488
// messageWriter enables writing to a WebSocket connection.
449
489
type messageWriter struct {
450
490
ctx context.Context
@@ -519,7 +559,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
519
559
if err != nil {
520
560
return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
521
561
}
522
- return typ , r , nil
562
+ return typ , io . LimitReader ( r , c . msgReadLimit ) , nil
523
563
}
524
564
525
565
func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
@@ -640,3 +680,48 @@ func (r *messageReader) read(p []byte) (int, error) {
640
680
641
681
return n , nil
642
682
}
683
+
684
+ // SetReadLimit sets the max number of bytes to read for a single message.
685
+ // It applies to the Reader and Read methods.
686
+ //
687
+ // By default, the connection has a message read limit of 32768 bytes.
688
+ func (c * Conn ) SetReadLimit (n int64 ) {
689
+ atomic .StoreInt64 (& c .msgReadLimit , n )
690
+ }
691
+
692
+ func init () {
693
+ rand .Seed (time .Now ().UnixNano ())
694
+ }
695
+
696
+ // Ping sends a ping to the peer and waits for a pong.
697
+ // Use this to measure latency or ensure the peer is responsive.
698
+ //
699
+ // This API is experimental and subject to change.
700
+ // Please provide feedback in https://github.com/nhooyr/websocket/issues/1.
701
+ func (c * Conn ) Ping (ctx context.Context ) error {
702
+ err := c .ping (ctx )
703
+ if err != nil {
704
+ return xerrors .Errorf ("failed to ping: %w" , err )
705
+ }
706
+ return nil
707
+ }
708
+
709
+ func (c * Conn ) ping (ctx context.Context ) error {
710
+ id := rand .Uint64 ()
711
+ p := strconv .FormatUint (id , 10 )
712
+
713
+ pong := make (chan struct {})
714
+ c .pingListener [p ] = pong
715
+
716
+ err := c .writeMessage (ctx , opPing , []byte (p ))
717
+ if err != nil {
718
+ return err
719
+ }
720
+
721
+ select {
722
+ case <- ctx .Done ():
723
+ return ctx .Err ()
724
+ case <- pong :
725
+ return nil
726
+ }
727
+ }
0 commit comments