Skip to content

Commit

Permalink
Sec web socket protocol (#13)
Browse files Browse the repository at this point in the history
* Sec web socket protocol

* 更新

* 删除golangci-lint,有时间又再加回去

* 更新
  • Loading branch information
guonaihong authored Sep 4, 2023
1 parent 8faa921 commit b032a13
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 45 deletions.
38 changes: 0 additions & 38 deletions .github/workflows/golangci-lint.yml

This file was deleted.

4 changes: 2 additions & 2 deletions client_option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func Test_ClientOption(t *testing.T) {
}
})

t.Run("6.1 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
t.Run("6.1 Dial: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
Expand All @@ -132,7 +132,7 @@ func Test_ClientOption(t *testing.T) {
}
})

t.Run("6.2 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
t.Run("6.2 DialConf: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion common_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,7 @@ func Test_CommonOption(t *testing.T) {
url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientDecompressAndCompress(),
WithClientDecompression(),
WithClientMaxDelayWriteDuration(10*time.Millisecond),
WithClientMaxDelayWriteDuration(30*time.Millisecond),
WithClientMaxDelayWriteNum(3),
WithClientWindowsParseMode(),
WithClientDelayWriteInitBufferSize(4096),
Expand Down
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Config struct {
maxDelayWriteNum int32 // 最大延迟包的个数, 默认值为10
delayWriteInitBufferSize int32 // 延迟写入的初始缓冲区大小, 默认值是8k
maxDelayWriteDuration time.Duration // 最大延迟时间, 默认值是10ms
subProtocols []string // 设置支持的子协议
}

func (c *Config) initPayloadSize() int {
Expand Down
13 changes: 13 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
// Copyright 2021-2023 antlabs. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package quickws

import "testing"
Expand Down
31 changes: 28 additions & 3 deletions server_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ import (

var (
ErrNotFoundHijacker = errors.New("not found Hijacker")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
bytesCRLF = []byte("\r\n")
strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol"
bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ")
strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol"
strWebSocketKey = "Sec-WebSocket-Key"
)

type ConnOption struct {
Expand All @@ -46,6 +47,29 @@ func writeHeaderVal(w io.Writer, val []byte) (err error) {
return
}

func subProtocol(subProtocol string, cnf *Config) string {
if subProtocol == "" {
return ""
}

subProtocols := strings.Split(subProtocol, ",")
// 如果配置了subProtocols, 则检查客户端的subProtocols是否在配置的subProtocols中
// 为什么要这么做,可以看下这个issue
// https://github.com/antlabs/quickws/issues/12
if len(cnf.subProtocols) > 0 {
for _, clientSubProtocols := range subProtocols {
clientSubProtocols = strings.TrimSpace(clientSubProtocols)
for _, serverSubProtocols := range cnf.subProtocols {
if clientSubProtocols == serverSubProtocols {
return clientSubProtocols
}
}
}
}
// echo Secf-WebSocket-Protocol 的值
return subProtocol
}

// https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.2
// 第5小点
func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error) {
Expand All @@ -55,7 +79,7 @@ func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error)
return
}

v := secWebSocketAcceptVal(r.Header.Get("Sec-WebSocket-Key"))
v := secWebSocketAcceptVal(r.Header.Get(strWebSocketKey))
// 写入Sec-WebSocket-Accept vla
if err = writeHeaderVal(w, StringToBytes(v)); err != nil {
return err
Expand All @@ -69,6 +93,7 @@ func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error)
}

v = r.Header.Get(strGetSecWebSocketProtocolKey)
v = subProtocol(v, cnf)
if len(v) > 0 {
if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil {
return
Expand Down
91 changes: 91 additions & 0 deletions server_option_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,99 @@
// Copyright 2021-2023 antlabs. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package quickws

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func Test_ServerOption(t *testing.T) {
t.Run("2.1 Subprotocol", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r, WithServerSubprotocols([]string{"crud", "im"}))
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"crud"},
})))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "crud" {
t.Error("header fail")
}
})

t.Run("2.2 Subprotocol", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r, WithServerSubprotocols([]string{"crud", "im"}))
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"crud"},
}))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "crud" {
t.Error("header fail")
}
})

t.Run("2.3 Subprotocol", func(t *testing.T) {
upgrade := NewUpgrade(WithServerSubprotocols([]string{"crud", "im"}))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"crud"},
}))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "crud" {
t.Error("header fail")
}
})
}
9 changes: 8 additions & 1 deletion server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ package quickws

type ServerOption func(*ConnOption)

// 2.配置压缩和解压缩
// 1.配置压缩和解压缩
func WithServerDecompressAndCompress() ServerOption {
return func(o *ConnOption) {
o.compression = true
o.decompression = true
}
}

// 2. 设置服务端支持的子协议
func WithServerSubprotocols(subprotocols []string) ServerOption {
return func(o *ConnOption) {
o.subProtocols = subprotocols
}
}

0 comments on commit b032a13

Please sign in to comment.