@@ -14,7 +14,8 @@ import (
14
14
)
15
15
16
16
// Conn represents a WebSocket connection.
17
- // All methods except Reader can be used concurrently.
17
+ // All methods may be called concurrently.
18
+ //
18
19
// Please be sure to call Close on the connection when you
19
20
// are finished with it to release resources.
20
21
type Conn struct {
@@ -31,8 +32,10 @@ type Conn struct {
31
32
writeDataLock chan struct {}
32
33
writeFrameLock chan struct {}
33
34
34
- readData chan header
35
- readDone chan struct {}
35
+ readDataLock chan struct {}
36
+ readData chan header
37
+ readDone chan struct {}
38
+ readLoopDone chan struct {}
36
39
37
40
setReadTimeout chan context.Context
38
41
setWriteTimeout chan context.Context
@@ -44,7 +47,7 @@ type Conn struct {
44
47
// when the connection is closed.
45
48
// If the parent context is cancelled, the connection will be closed.
46
49
//
47
- // This is an experimental API meaning it may be remove in the future.
50
+ // This is an experimental API that may be remove in the future.
48
51
// Please let me know how you feel about it.
49
52
func (c * Conn ) Context (parent context.Context ) context.Context {
50
53
select {
@@ -77,6 +80,18 @@ func (c *Conn) close(err error) {
77
80
c .closeErr = xerrors .Errorf ("websocket closed: %w" , cerr )
78
81
79
82
close (c .closed )
83
+
84
+ // See comment in dial.go
85
+ if c .client {
86
+ go func () {
87
+ <- c .readLoopDone
88
+ c .readDataLock <- struct {}{}
89
+ c .writeFrameLock <- struct {}{}
90
+
91
+ returnBufioReader (c .br )
92
+ returnBufioWriter (c .bw )
93
+ }()
94
+ }
80
95
})
81
96
}
82
97
@@ -94,6 +109,8 @@ func (c *Conn) init() {
94
109
95
110
c .readData = make (chan header )
96
111
c .readDone = make (chan struct {})
112
+ c .readDataLock = make (chan struct {}, 1 )
113
+ c .readLoopDone = make (chan struct {})
97
114
98
115
c .setReadTimeout = make (chan context.Context )
99
116
c .setWriteTimeout = make (chan context.Context )
@@ -174,8 +191,8 @@ func (c *Conn) timeoutLoop() {
174
191
select {
175
192
case <- c .closed :
176
193
return
177
- case readCtx = <- c .setWriteTimeout :
178
- case writeCtx = <- c .setReadTimeout :
194
+ case writeCtx = <- c .setWriteTimeout :
195
+ case readCtx = <- c .setReadTimeout :
179
196
case <- readCtx .Done ():
180
197
c .close (xerrors .Errorf ("data read timed out: %w" , readCtx .Err ()))
181
198
case <- writeCtx .Done ():
@@ -276,6 +293,8 @@ func (c *Conn) readTillData() (header, error) {
276
293
}
277
294
278
295
func (c * Conn ) readLoop () {
296
+ defer close (c .readLoopDone )
297
+
279
298
for {
280
299
h , err := c .readTillData ()
281
300
if err != nil {
@@ -487,8 +506,7 @@ func (w *messageWriter) close() error {
487
506
//
488
507
// Your application must keep reading messages for the Conn to automatically respond to ping
489
508
// and close frames and not become stuck waiting for a data message to be read.
490
- // Please ensure to read the full message from io.Reader. If you do not read till
491
- // io.EOF, the connection will break unless the next read would have yielded io.EOF.
509
+ // Please ensure to read the full message from io.Reader.
492
510
//
493
511
// You can only read a single message at a time so do not call this method
494
512
// concurrently.
@@ -500,30 +518,10 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
500
518
return typ , r , nil
501
519
}
502
520
503
- func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
504
- // if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
505
- // // If the next read yields io.EOF we are good to go.
506
- // r := messageReader{
507
- // ctx: ctx,
508
- // c: c,
509
- // }
510
- // _, err := r.Read(nil)
511
- // if err == nil {
512
- // return 0, nil, xerrors.New("previous message not fully read")
513
- // }
514
- // if !xerrors.Is(err, io.EOF) {
515
- // return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
516
- // }
517
- //
518
- // atomic.StoreInt64(&c.activeReader, 1)
519
- // }
520
-
521
- select {
522
- case <- c .closed :
523
- return 0 , nil , c .closeErr
524
- case <- ctx .Done ():
525
- return 0 , nil , ctx .Err ()
526
- case c .setReadTimeout <- ctx :
521
+ func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
522
+ err = c .acquireLock (ctx , c .readDataLock )
523
+ if err != nil {
524
+ return 0 , nil , err
527
525
}
528
526
529
527
select {
@@ -533,25 +531,24 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
533
531
return 0 , nil , ctx .Err ()
534
532
case h := <- c .readData :
535
533
if h .opcode == opContinuation {
536
- if h .fin && h .payloadLength == 0 {
537
- select {
538
- case <- c .closed :
539
- return 0 , nil , c .closeErr
540
- case c .readDone <- struct {}{}:
541
- return c .reader (ctx )
542
- }
534
+ ce := CloseError {
535
+ Code : StatusProtocolError ,
536
+ Reason : "continuation frame not after data or text frame" ,
543
537
}
544
- return 0 , nil , xerrors .Errorf ("previous reader was not read to EOF" )
538
+ c .Close (ce .Code , ce .Reason )
539
+ return 0 , nil , ce
545
540
}
546
541
return MessageType (h .opcode ), & messageReader {
547
- h : & h ,
548
- c : c ,
542
+ ctx : ctx ,
543
+ h : & h ,
544
+ c : c ,
549
545
}, nil
550
546
}
551
547
}
552
548
553
549
// messageReader enables reading a data frame from the WebSocket connection.
554
550
type messageReader struct {
551
+ ctx context.Context
555
552
maskPos int
556
553
h * header
557
554
c * Conn
@@ -598,8 +595,20 @@ func (r *messageReader) read(p []byte) (int, error) {
598
595
p = p [:r .h .payloadLength ]
599
596
}
600
597
598
+ select {
599
+ case <- r .c .closed :
600
+ return 0 , r .c .closeErr
601
+ case r .c .setReadTimeout <- r .ctx :
602
+ }
603
+
601
604
n , err := io .ReadFull (r .c .br , p )
602
605
606
+ select {
607
+ case <- r .c .closed :
608
+ return 0 , r .c .closeErr
609
+ case r .c .setReadTimeout <- context .Background ():
610
+ }
611
+
603
612
r .h .payloadLength -= int64 (n )
604
613
if r .h .masked {
605
614
r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
@@ -618,12 +627,8 @@ func (r *messageReader) read(p []byte) (int, error) {
618
627
}
619
628
if r .h .fin {
620
629
r .eofed = true
621
- select {
622
- case <- r .c .closed :
623
- return n , r .c .closeErr
624
- case r .c .setReadTimeout <- context .Background ():
625
- return n , io .EOF
626
- }
630
+ r .c .releaseLock (r .c .readDataLock )
631
+ return n , io .EOF
627
632
}
628
633
r .maskPos = 0
629
634
r .h = nil
0 commit comments