diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d9a3ef3..558dbd5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -28,6 +28,9 @@ jobs: - name: install golint run: go install golang.org/x/lint/golint@latest + + - name: install goasmfmt + run: go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - run: goimports -local github.com/fortytw2/websocket -w . - run: golint -set_exit_status ./... diff --git a/accept.go b/accept.go index 1a48c20..ec878e4 100644 --- a/accept.go +++ b/accept.go @@ -208,7 +208,12 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { return nil } } - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + + if u.Host == "" { + return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) + } + + return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { diff --git a/accept_test.go b/accept_test.go index a5e35fc..7acfb9d 100644 --- a/accept_test.go +++ b/accept_test.go @@ -40,7 +40,23 @@ func TestAccept(t *testing.T) { r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) + assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`) + }) + + // nhooyr.io/websocket#247 + t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "https://harhar.com") + + _, err := Accept(w, r, nil) + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`) }) t.Run("badCompression", func(t *testing.T) { diff --git a/close.go 
b/close.go index 46c4279..e6796cd 100644 --- a/close.go +++ b/close.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "net" "time" "github.com/fortytw2/websocket/internal/errd" @@ -116,15 +117,13 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { return nil } -var errAlreadyWroteClose = errors.New("already wrote close") - func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if wroteClose { - return errAlreadyWroteClose + return net.ErrClosed } ce := CloseError{ diff --git a/conn_test.go b/conn_test.go index e2f403d..5b69537 100644 --- a/conn_test.go +++ b/conn_test.go @@ -207,6 +207,38 @@ func TestConn(t *testing.T) { } }) + t.Run("netConn/readLimit", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + defer tt.cleanup() + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + s := strings.Repeat("papa", 1<<20) + errs := xsync.Go(func() error { + _, err := n2.Write([]byte(s)) + if err != nil { + return err + } + return n2.Close() + }) + + b, err := ioutil.ReadAll(n1) + assert.Success(t, err) + + _, err = n1.Read(nil) + assert.Equal(t, "read error", err, io.EOF) + + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + assert.Equal(t, "read msg", s, string(b)) + }) + t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) defer tt.cleanup() diff --git a/frame.go b/frame.go index c1528e9..a6e5546 100644 --- a/frame.go +++ b/frame.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "math" - "math/bits" "github.com/fortytw2/websocket/internal/errd" ) @@ -170,125 +169,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { return nil } - -// mask applies the WebSocket masking algorithm to p -// with the given key.
-// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - 
binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. 
- for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go index 7b2524b..a815f09 100644 --- a/frame_test.go +++ b/frame_test.go @@ -6,8 +6,6 @@ package websocket import ( "bufio" "bytes" - "encoding/binary" - "math/bits" "math/rand" "strconv" "testing" @@ -102,83 +100,3 @@ func testHeader(t *testing.T, h header) { assert.Equal(t, "read header", h.payloadLength, h2.payloadLength) } - -func Test_mask(t *testing.T) { - t.Parallel() - - key := []byte{0xa, 0xb, 0xc, 0xff} - key32 := binary.LittleEndian.Uint32(key) - p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) - - expP := []byte{0, 0, 0, 0x0d, 0x6} - assert.Equal(t, "p", expP, p) - - expKey32 := bits.RotateLeft32(key32, -8) - assert.Equal(t, "key32", expKey32, gotKey32) -} - -func basicMask(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -func Benchmark_mask(b *testing.B) { - sizes := []int{ - 2, - 3, - 4, - 8, - 16, - 32, - 128, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func(b *testing.B, key [4]byte, p []byte) - }{ - { - name: "basic", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - basicMask(key, 0, p) - } - }, - }, - - { - name: "nhooyr", - fn: func(b *testing.B, key [4]byte, p []byte) { - key32 := binary.LittleEndian.Uint32(key[:]) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mask(key32, p) - } - }, - }, - } - - key := [4]byte{1, 2, 3, 4} - - for _, size := range sizes { - p := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.SetBytes(int64(size)) - - fn.fn(b, key, p) - }) - } - }) - } -} diff --git a/go.mod b/go.mod index 7be9d82..19a4f3e 100644 
--- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/google/go-cmp v0.4.0 github.com/klauspost/compress v1.10.3 + golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index cb290e8..d6270cc 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/wsmask/mask_amd64.s b/internal/wsmask/mask_amd64.s new file mode 100644 index 0000000..caca53e --- /dev/null +++ b/internal/wsmask/mask_amd64.s @@ -0,0 +1,153 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) uint32 +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // AX = b + // CX = len (left length) + // SI = key (uint32) + // DI = uint64(SI) | uint64(SI)<<32 + MOVQ b+0(FP), AX + MOVQ len+8(FP), CX + MOVL key+16(FP), SI + + // calculate the DI + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + CMPQ CX, $15 + JLE less_than_16 + CMPQ CX, $63 + JLE less_than_64 + CMPQ CX, $128 + JLE sse + TESTQ $31, AX + JNZ unaligned + +aligned: + CMPB ·useAVX2(SB), $1 + JE avx2 + JMP sse + +unaligned_loop_1byte: + XORB SI, (AX) + INCQ AX
+ DECQ CX + ROLL $24, SI + TESTQ $7, AX + JNZ unaligned_loop_1byte + + // calculate DI again since SI was modified + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + TESTQ $31, AX + JZ aligned + +unaligned: + TESTQ $7, AX // AX & 7 != 0: pointer is not 8-byte aligned, go to unaligned_loop_1byte. + JNZ unaligned_loop_1byte + +unaligned_loop: + // we don't need to check the CX since we know it's above 128 + XORQ DI, (AX) + ADDQ $8, AX + SUBQ $8, CX + TESTQ $31, AX + JNZ unaligned_loop + JMP aligned + +avx2: + CMPQ CX, $0x80 + JL sse + VMOVQ DI, X0 + VPBROADCASTQ X0, Y0 + +avx2_loop: + VPXOR (AX), Y0, Y1 + VPXOR 32(AX), Y0, Y2 + VPXOR 64(AX), Y0, Y3 + VPXOR 96(AX), Y0, Y4 + VMOVDQU Y1, (AX) + VMOVDQU Y2, 32(AX) + VMOVDQU Y3, 64(AX) + VMOVDQU Y4, 96(AX) + ADDQ $0x80, AX + SUBQ $0x80, CX + CMPQ CX, $0x80 + JAE avx2_loop // loop if CX >= 0x80 + VZEROUPPER // clear upper YMM halves before the legacy-SSE code below and the return to Go + +sse: + CMPQ CX, $0x40 + JL less_than_64 + MOVQ DI, X0 + PUNPCKLQDQ X0, X0 + +sse_loop: + MOVOU 0*16(AX), X1 + MOVOU 1*16(AX), X2 + MOVOU 2*16(AX), X3 + MOVOU 3*16(AX), X4 + PXOR X0, X1 + PXOR X0, X2 + PXOR X0, X3 + PXOR X0, X4 + MOVOU X1, 0*16(AX) + MOVOU X2, 1*16(AX) + MOVOU X3, 2*16(AX) + MOVOU X4, 3*16(AX) + ADDQ $0x40, AX + SUBQ $0x40, CX + CMPQ CX, $0x40 + JAE sse_loop + +less_than_64: + TESTQ $32, CX + JZ less_than_32 + XORQ DI, (AX) + XORQ DI, 8(AX) + XORQ DI, 16(AX) + XORQ DI, 24(AX) + ADDQ $32, AX + +less_than_32: + TESTQ $16, CX + JZ less_than_16 + XORQ DI, (AX) + XORQ DI, 8(AX) + ADDQ $16, AX + +less_than_16: + TESTQ $8, CX + JZ less_than_8 + XORQ DI, (AX) + ADDQ $8, AX + +less_than_8: + TESTQ $4, CX + JZ less_than_4 + XORL SI, (AX) + ADDQ $4, AX + +less_than_4: + TESTQ $2, CX + JZ less_than_2 + XORW SI, (AX) + ROLL $16, SI + ADDQ $2, AX + +less_than_2: + TESTQ $1, CX + JZ done + XORB SI, (AX) + ROLL $24, SI + +done: + MOVL SI, ret+24(FP) + RET diff --git a/internal/wsmask/mask_arm64.s b/internal/wsmask/mask_arm64.s new file mode 100644 index 0000000..624cb72 --- /dev/null +++ b/internal/wsmask/mask_arm64.s @@ -0,0
+1,74 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) uint32 +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // R0 = b + // R1 = len + // R2 = uint64(key)<<32 | uint64(key) + // R3 = key (uint32) + MOVD b+0(FP), R0 + MOVD len+8(FP), R1 + MOVWU key+16(FP), R3 + MOVD R3, R2 + ORR R2<<32, R2, R2 + VDUP R2, V0.D2 + CMP $64, R1 + BLT less_than_64 + + // todo: optimize unaligned case +loop_64: + VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VEOR V3.B16, V0.B16, V3.B16 + VEOR V4.B16, V0.B16, V4.B16 + VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) + SUBS $64, R1 + CMP $64, R1 + BGE loop_64 + +less_than_64: + // quick end + CBZ R1, end + TBZ $5, R1, less_than32 + VLD1 (R0), [V1.B16, V2.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VST1.P [V1.B16, V2.B16], 32(R0) + +less_than32: + TBZ $4, R1, less_than16 + LDP (R0), (R11, R12) + EOR R11, R2, R11 + EOR R12, R2, R12 + STP.P (R11, R12), 16(R0) + +less_than16: + TBZ $3, R1, less_than8 + MOVD (R0), R11 + EOR R2, R11, R11 + MOVD.P R11, 8(R0) + +less_than8: + TBZ $2, R1, less_than4 + MOVWU (R0), R11 + EORW R2, R11, R11 + MOVWU.P R11, 4(R0) + +less_than4: + TBZ $1, R1, less_than2 + MOVHU (R0), R11 + EORW R3, R11, R11 + MOVHU.P R11, 2(R0) + RORW $16, R3 + +less_than2: + TBZ $0, R1, end + MOVBU (R0), R11 + EORW R3, R11, R11 + MOVBU.P R11, 1(R0) + RORW $8, R3 + +end: + MOVWU R3, ret+24(FP) + RET diff --git a/internal/wsmask/mask_asm.go b/internal/wsmask/mask_asm.go new file mode 100644 index 0000000..c2bc0f1 --- /dev/null +++ b/internal/wsmask/mask_asm.go @@ -0,0 +1,21 @@ +//go:build !appengine && (amd64 || arm64) + +package wsmask + +import "golang.org/x/sys/cpu" + +// Mask applies the WebSocket masking algorithm to b +// with the given key.
+// See https://tools.ietf.org/html/rfc6455#section-5.3 +func Mask(key uint32, b []byte) uint32 { + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key +} + +//lint:ignore U1000 used in asm +var useAVX2 = cpu.X86.HasAVX2 + +//go:noescape +func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/internal/wsmask/mask_generic.go b/internal/wsmask/mask_generic.go new file mode 100644 index 0000000..491c7ed --- /dev/null +++ b/internal/wsmask/mask_generic.go @@ -0,0 +1,133 @@ +//go:build appengine || (!amd64 && !arm64) + +package wsmask + +import ( + "encoding/binary" + "math/bits" +) + +// Mask applies the WebSocket masking algorithm to b +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +func Mask(key uint32, b []byte) uint32 { + return maskGo(key, b) +} + +// maskGo returns the correctly rotated key needed +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func maskGo(key uint32, b []byte) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes.
+ for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. 
+ for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. 
+ for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/internal/wsmask/mask_test.go b/internal/wsmask/mask_test.go new file mode 100644 index 0000000..641d51e --- /dev/null +++ b/internal/wsmask/mask_test.go @@ -0,0 +1,90 @@ +package wsmask + +import ( + "encoding/binary" + "math/bits" + "strconv" + "testing" + + "github.com/fortytw2/websocket/internal/test/assert" +) + +func TestMask(t *testing.T) { + t.Parallel() + + key := []byte{0xa, 0xb, 0xc, 0xff} + key32 := binary.LittleEndian.Uint32(key) + p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} + gotKey32 := Mask(key32, p) + + expP := []byte{0, 0, 0, 0x0d, 0x6} + assert.Equal(t, "p", expP, p) + + expKey32 := bits.RotateLeft32(key32, -8) + assert.Equal(t, "key32", expKey32, gotKey32) +} + +func BenchmarkMask(b *testing.B) { + sizes := []int{ + 2, + 3, + 4, + 8, + 16, + 32, + 128, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func(b *testing.B, key [4]byte, p []byte) + }{ + { + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basicMask(key, 0, p) + } + }, + }, + + { + name: "nhooyr", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + Mask(key32, p) + } + }, + }, + } + + key := [4]byte{1, 2, 3, 4} + + for _, size := range sizes { + p := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.SetBytes(int64(size)) + + fn.fn(b, key, p) + }) + } + }) + } +} + +func basicMask(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} diff --git a/netconn.go b/netconn.go index 76631ba..146aa26 100644 --- a/netconn.go +++ b/netconn.go @@ -37,7 +37,11 @@ import ( // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. 
+// +// Furthermore, the ReadLimit is set to -1 to disable it. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + c.SetReadLimit(-1) + nc := &netConn{ c: c, msgType: msgType, @@ -77,6 +81,8 @@ type netConn struct { var _ net.Conn = &netConn{} func (c *netConn) Close() error { + c.writeTimer.Stop() + c.readTimer.Stop() return c.c.Close(StatusNormalClosure, "") } diff --git a/read.go b/read.go index d93ca4b..474f093 100644 --- a/read.go +++ b/read.go @@ -14,6 +14,7 @@ import ( "time" "github.com/fortytw2/websocket/internal/errd" + "github.com/fortytw2/websocket/internal/wsmask" "github.com/fortytw2/websocket/internal/xsync" ) @@ -70,10 +71,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. +// +// Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { - // We add read one more byte than the limit in case - // there is a fin frame that needs to be read. - c.msgReader.limitReader.limit.Store(n + 1) + if n >= 0 { + // We read one more byte than the limit in case + // there is a fin frame that needs to be read. 
+ n++ + } + + c.msgReader.limitReader.limit.Store(n) } const defaultReadLimit = 32768 @@ -261,7 +268,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } if h.masked { - mask(h.maskKey, b) + wsmask.Mask(h.maskKey, b) } switch h.opcode { @@ -425,7 +432,7 @@ func (mr *msgReader) read(p []byte) (int, error) { mr.payloadLength -= int64(n) if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + mr.maskKey = wsmask.Mask(mr.maskKey, p) } return n, nil @@ -454,7 +461,11 @@ func (lr *limitReader) reset(r io.Reader) { } func (lr *limitReader) Read(p []byte) (int, error) { - if lr.n <= 0 { + if lr.n < 0 { + return lr.r.Read(p) + } + + if lr.n == 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err @@ -465,6 +476,9 @@ func (lr *limitReader) Read(p []byte) (int, error) { } n, err := lr.r.Read(p) lr.n -= int64(n) + if lr.n < 0 { + lr.n = 0 + } return n, err } diff --git a/write.go b/write.go index 8f7e240..f864263 100644 --- a/write.go +++ b/write.go @@ -8,14 +8,15 @@ import ( "context" "crypto/rand" "encoding/binary" - "errors" "fmt" "io" + "net" "time" "github.com/klauspost/compress/flate" "github.com/fortytw2/websocket/internal/errd" + "github.com/fortytw2/websocket/internal/wsmask" ) // Writer returns a writer bounded by the context that will write @@ -54,14 +55,14 @@ type msgWriter struct { func (mw *msgWriter) Write(p []byte) (int, error) { if mw.closed { - return 0, errors.New("cannot use closed writer") + return 0, net.ErrClosed } return mw.mw.Write(p) } func (mw *msgWriter) Close() error { if mw.closed { - return errors.New("cannot use closed writer") + return net.ErrClosed } mw.closed = true return mw.mw.Close() @@ -247,7 +248,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { return 0, err } - defer c.writeFrameMu.unlock() // If the state says a close has already been written, we wait until // the connection is closed and 
return that error. @@ -258,6 +258,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco wroteClose := c.wroteClose c.closeMu.Unlock() if wroteClose && opcode != opClose { + c.writeFrameMu.unlock() select { case <-ctx.Done(): return 0, ctx.Err() @@ -265,6 +266,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return 0, c.closeErr } } + defer c.writeFrameMu.unlock() select { case <-c.closed: @@ -359,7 +361,7 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, err } - maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + maskKey = wsmask.Mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) p = p[j:] n += j