Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherry pick dev branch and open PRs from upstream #5

Merged
merged 8 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ jobs:

- name: install golint
run: go install golang.org/x/lint/golint@latest

- name: install goasmfmt
run: go install github.com/klauspost/asmfmt/cmd/asmfmt@latest

- run: goimports -local github.com/fortytw2/websocket -w .
- run: golint -set_exit_status ./...
Expand Down
7 changes: 6 additions & 1 deletion accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
return nil
}
}
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)

if u.Host == "" {
return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
}

return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}

func match(pattern, s string) (bool, error) {
Expand Down
18 changes: 17 additions & 1 deletion accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,23 @@ func TestAccept(t *testing.T) {
r.Header.Set("Origin", "harhar.com")

_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`)
assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
})

// nhooyr.io/websocket#247
t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
t.Parallel()

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Origin", "https://harhar.com")

_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
})

t.Run("badCompression", func(t *testing.T) {
Expand Down
5 changes: 2 additions & 3 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log"
"net"
"time"

"github.com/fortytw2/websocket/internal/errd"
Expand Down Expand Up @@ -116,15 +117,13 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
return nil
}

var errAlreadyWroteClose = errors.New("already wrote close")

func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
wroteClose := c.wroteClose
c.wroteClose = true
c.closeMu.Unlock()
if wroteClose {
return errAlreadyWroteClose
return net.ErrClosed
}

ce := CloseError{
Expand Down
31 changes: 31 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,37 @@ func TestConn(t *testing.T) {
}
})

t.Run("netConn/readLimit", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)

n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

s := strings.Repeat("papa", 1<<20)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte(s))
if err != nil {
return err
}
return n2.Close()
})

b, err := ioutil.ReadAll(n1)
assert.Success(t, err)

_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)

select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

assert.Equal(t, "read msg", s, string(b))
})

t.Run("wsjson", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.cleanup()
Expand Down
123 changes: 0 additions & 123 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"io"
"math"
"math/bits"

"github.com/fortytw2/websocket/internal/errd"
)
Expand Down Expand Up @@ -170,125 +169,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {

return nil
}

// mask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}

return key
}
82 changes: 0 additions & 82 deletions frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ package websocket
import (
"bufio"
"bytes"
"encoding/binary"
"math/bits"
"math/rand"
"strconv"
"testing"
Expand Down Expand Up @@ -102,83 +100,3 @@ func testHeader(t *testing.T, h header) {

assert.Equal(t, "read header", h.payloadLength, h2.payloadLength)
}

func Test_mask(t *testing.T) {
t.Parallel()

key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(key32, p)

expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)

expKey32 := bits.RotateLeft32(key32, -8)
assert.Equal(t, "key32", expKey32, gotKey32)
}

func basicMask(maskKey [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= maskKey[pos&3]
pos++
}
return pos & 3
}

func Benchmark_mask(b *testing.B) {
sizes := []int{
2,
3,
4,
8,
16,
32,
128,
512,
4096,
16384,
}

fns := []struct {
name string
fn func(b *testing.B, key [4]byte, p []byte)
}{
{
name: "basic",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
basicMask(key, 0, p)
}
},
},

{
name: "nhooyr",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()

for i := 0; i < b.N; i++ {
mask(key32, p)
}
},
},
}

key := [4]byte{1, 2, 3, 4}

for _, size := range sizes {
p := make([]byte, size)

b.Run(strconv.Itoa(size), func(b *testing.B) {
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
b.SetBytes(int64(size))

fn.fn(b, key, p)
})
}
})
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.18
require (
github.com/google/go-cmp v0.4.0
github.com/klauspost/compress v1.10.3
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
Expand Down
Loading