Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Oct 26, 2023
1 parent f4e61e5 commit 53c53c2
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 135 deletions.
125 changes: 0 additions & 125 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"math"
"math/bits"

"nhooyr.io/websocket/internal/errd"
)
Expand Down Expand Up @@ -172,127 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {

return nil
}

// maskGo applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
//
//lint:ignore U1000 mask.go
func maskGo(b []byte, key uint32) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}

return key
}
131 changes: 127 additions & 4 deletions mask.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,130 @@
//go:build !amd64 && !arm64 && !js

package websocket

func mask(b []byte, key uint32) uint32 {
return maskGo(b, key)
import (
"encoding/binary"
"math/bits"
)

// maskGo applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
//
//lint:ignore U1000 mask.go
func maskGo(b []byte, key uint32) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}

return key
}
36 changes: 30 additions & 6 deletions mask_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@ TEXT ·maskAsm(SB), NOSPLIT, $0-28

CMPQ CX, $8
JL less_than_8
CMPQ CX, $512
CMPQ CX, $128
JLE sse
TESTQ $31, AX
JNZ unaligned

aligned:
CMPB ·useAVX2(SB), $1
JE avx2
JMP sse

unaligned_loop_1byte:
XORB SI, (AX)
INCQ AX
Expand All @@ -40,7 +45,7 @@ unaligned_loop_1byte:
ORQ DX, DI

TESTQ $31, AX
JZ sse
JZ aligned

unaligned:
// $7 & len, if not zero jump to loop_1b.
Expand All @@ -54,17 +59,36 @@ unaligned_loop:
SUBQ $8, CX
TESTQ $31, AX
JNZ unaligned_loop
JMP sse

JMP aligned

avx2:
CMPQ CX, $128
JL sse
VMOVQ DI, X0
VPBROADCASTQ X0, Y0

// TODO: shouldn't these be aligned movs now?
// TODO: should be 256?
avx2_loop:
VMOVDQU (AX), Y1
VPXOR Y0, Y1, Y2
VMOVDQU Y2, (AX)
ADDQ $128, AX
SUBQ $128, CX
CMPQ CX, $128
// Loop if CX >= 128.
JAE avx2_loop

// TODO: should be 128?
sse:
CMPQ CX, $64
JL less_than_64
MOVQ DI, X0
PUNPCKLQDQ X0, X0

sse_loop:
MOVOU 0*16(AX), X1
MOVOU 1*16(AX), X2
MOVOU (AX), X1
MOVOU 16(AX), X2
MOVOU 2*16(AX), X3
MOVOU 3*16(AX), X4
PXOR X0, X1
Expand Down
2 changes: 2 additions & 0 deletions mask_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ func mask(b []byte, key uint32) uint32 {
return key
}

var useAVX2 = true

Check failure on line 12 in mask_asm.go

View workflow job for this annotation

GitHub Actions / lint

var useAVX2 is unused (U1000)

//go:noescape
func maskAsm(b *byte, len int, key uint32) uint32
7 changes: 7 additions & 0 deletions mask_go.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !amd64 && !arm64 && !js

package websocket

func mask(b []byte, key uint32) uint32 {
return maskGo(b, key)
}

0 comments on commit 53c53c2

Please sign in to comment.