@@ -3,6 +3,7 @@ package websocket
3
3
import (
4
4
"bufio"
5
5
"context"
6
+ cryptorand "crypto/rand"
6
7
"fmt"
7
8
"io"
8
9
"io/ioutil"
@@ -26,8 +27,11 @@ type Conn struct {
26
27
subprotocol string
27
28
br * bufio.Reader
28
29
bw * bufio.Writer
29
- closer io.Closer
30
- client bool
30
+ // writeBuf is used for masking, its the buffer in bufio.Writer.
31
+ // Only used by the client.
32
+ writeBuf []byte
33
+ closer io.Closer
34
+ client bool
31
35
32
36
// read limit for a message in bytes.
33
37
msgReadLimit int64
@@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
581
585
// See the Writer method if you want to stream a message. The docs on Writer
582
586
// regarding concurrency also apply to this method.
583
587
func (c * Conn ) Write (ctx context.Context , typ MessageType , p []byte ) error {
584
- err := c .write (ctx , typ , p )
588
+ _ , err := c .write (ctx , typ , p )
585
589
if err != nil {
586
590
return xerrors .Errorf ("failed to write msg: %w" , err )
587
591
}
588
592
return nil
589
593
}
590
594
591
- func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) error {
595
+ func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) ( int , error ) {
592
596
err := c .acquireLock (ctx , c .writeMsgLock )
593
597
if err != nil {
594
- return err
598
+ return 0 , err
595
599
}
596
600
defer c .releaseLock (c .writeMsgLock )
597
601
598
- err = c .writeFrame (ctx , true , opcode (typ ), p )
599
- return err
602
+ n , err : = c .writeFrame (ctx , true , opcode (typ ), p )
603
+ return n , err
600
604
}
601
605
602
606
// messageWriter enables writing to a WebSocket connection.
@@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) {
620
624
if w .closed {
621
625
return 0 , xerrors .Errorf ("cannot use closed writer" )
622
626
}
623
- err := w .c .writeFrame (w .ctx , false , w .opcode , p )
627
+ n , err := w .c .writeFrame (w .ctx , false , w .opcode , p )
624
628
if err != nil {
625
- return 0 , xerrors .Errorf ("failed to write data frame: %w" , err )
629
+ return n , xerrors .Errorf ("failed to write data frame: %w" , err )
626
630
}
627
631
w .opcode = opContinuation
628
- return len ( p ) , nil
632
+ return n , nil
629
633
}
630
634
631
635
// Close flushes the frame to the connection.
@@ -644,7 +648,7 @@ func (w *messageWriter) close() error {
644
648
}
645
649
w .closed = true
646
650
647
- err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
651
+ _ , err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
648
652
if err != nil {
649
653
return xerrors .Errorf ("failed to write fin frame: %w" , err )
650
654
}
@@ -654,34 +658,40 @@ func (w *messageWriter) close() error {
654
658
}
655
659
656
660
func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
657
- err := c .writeFrame (ctx , true , opcode , p )
661
+ _ , err := c .writeFrame (ctx , true , opcode , p )
658
662
if err != nil {
659
663
return xerrors .Errorf ("failed to write control frame: %w" , err )
660
664
}
661
665
return nil
662
666
}
663
667
664
668
// writeFrame handles all writes to the connection.
665
- // We never mask inside here because our mask key is always 0,0,0,0.
666
- // See comment on secWebSocketKey for why.
667
- func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) error {
669
+ func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) (int , error ) {
668
670
h := header {
669
671
fin : fin ,
670
672
opcode : opcode ,
671
673
masked : c .client ,
672
674
payloadLength : int64 (len (p )),
673
675
}
676
+
677
+ if c .client {
678
+ _ , err := io .ReadFull (cryptorand .Reader , h .maskKey [:])
679
+ if err != nil {
680
+ return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
681
+ }
682
+ }
683
+
674
684
b2 := marshalHeader (h )
675
685
676
686
err := c .acquireLock (ctx , c .writeFrameLock )
677
687
if err != nil {
678
- return err
688
+ return 0 , err
679
689
}
680
690
defer c .releaseLock (c .writeFrameLock )
681
691
682
692
select {
683
693
case <- c .closed :
684
- return c .closeErr
694
+ return 0 , c .closeErr
685
695
case c .setWriteTimeout <- ctx :
686
696
}
687
697
@@ -705,29 +715,61 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
705
715
706
716
_ , err = c .bw .Write (b2 )
707
717
if err != nil {
708
- return writeErr (err )
709
- }
710
- _ , err = c .bw .Write (p )
711
- if err != nil {
712
- return writeErr (err )
718
+ return 0 , writeErr (err )
719
+ }
720
+
721
+ var n int
722
+ if c .client {
723
+ var keypos int
724
+ for len (p ) > 0 {
725
+ if c .bw .Available () == 0 {
726
+ err = c .bw .Flush ()
727
+ if err != nil {
728
+ return n , writeErr (err )
729
+ }
730
+ }
731
+
732
+ // Start of next write in the buffer.
733
+ i := c .bw .Buffered ()
734
+
735
+ p2 := p
736
+ if len (p ) > c .bw .Available () {
737
+ p2 = p [:c .bw .Available ()]
738
+ }
739
+
740
+ n2 , err := c .bw .Write (p2 )
741
+ if err != nil {
742
+ return n , writeErr (err )
743
+ }
744
+
745
+ keypos = fastXOR (h .maskKey , keypos , c .writeBuf [i :i + n2 ])
746
+
747
+ p = p [n2 :]
748
+ n += n2
749
+ }
750
+ } else {
751
+ n , err = c .bw .Write (p )
752
+ if err != nil {
753
+ return n , writeErr (err )
754
+ }
713
755
}
714
756
715
757
if fin {
716
758
err = c .bw .Flush ()
717
759
if err != nil {
718
- return writeErr (err )
760
+ return n , writeErr (err )
719
761
}
720
762
}
721
763
722
764
// We already finished writing, no need to potentially brick the connection if
723
765
// the context expires.
724
766
select {
725
767
case <- c .closed :
726
- return c .closeErr
768
+ return n , c .closeErr
727
769
case c .setWriteTimeout <- context .Background ():
728
770
}
729
771
730
- return nil
772
+ return n , nil
731
773
}
732
774
733
775
func (c * Conn ) writePong (p []byte ) error {
@@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error {
842
884
return nil
843
885
}
844
886
}
887
+
888
+ type writerFunc func (p []byte ) (int , error )
889
+
890
+ func (f writerFunc ) Write (p []byte ) (int , error ) {
891
+ return f (p )
892
+ }
893
+
894
+ // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
895
+ // and stores it in c.writeBuf.
896
+ func (c * Conn ) extractBufioWriterBuf (w io.Writer ) {
897
+ c .bw .Reset (writerFunc (func (p2 []byte ) (int , error ) {
898
+ c .writeBuf = p2 [:cap (p2 )]
899
+ return len (p2 ), nil
900
+ }))
901
+
902
+ c .bw .WriteByte (0 )
903
+ c .bw .Flush ()
904
+
905
+ c .bw .Reset (w )
906
+ }
0 commit comments