Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge master into dev #277

Merged
merged 9 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}

if !headerContainsToken(r.Header, "Connection", "Upgrade") {
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}

if !headerContainsToken(r.Header, "Upgrade", "websocket") {
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
Expand Down Expand Up @@ -313,11 +313,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
return copts, nil
}

func headerContainsToken(h http.Header, key, token string) bool {
token = strings.ToLower(token)

func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
for _, t := range headerTokens(h, key) {
if t == token {
if strings.EqualFold(t, token) {
return true
}
}
Expand Down Expand Up @@ -358,7 +356,6 @@ func headerTokens(h http.Header, key string) []string {
for _, v := range h[key] {
v = strings.TrimSpace(v)
for _, t := range strings.Split(v, ",") {
t = strings.ToLower(t)
t = strings.TrimSpace(t)
tokens = append(tokens, t)
}
Expand Down
6 changes: 6 additions & 0 deletions accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ func Test_selectSubprotocol(t *testing.T) {
serverProtocols: []string{"echo2", "echo3"},
negotiated: "echo3",
},
{
name: "clientCasePresered",
clientProtocols: []string{"Echo1"},
serverProtocols: []string{"echo1"},
negotiated: "Echo1",
},
}

for _, tc := range testCases {
Expand Down
6 changes: 3 additions & 3 deletions ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ main() {
cd "$(dirname "$0")/.."

go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./...
sed -i '/stringer\.go/d' ci/out/coverage.prof
sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i '/examples/d' ci/out/coverage.prof
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof

# Last line is the total coverage.
go tool cover -func ci/out/coverage.prof | tail -n1
Expand Down
31 changes: 20 additions & 11 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -47,18 +46,27 @@ type DialOptions struct {
CompressionThreshold int
}

func (opts *DialOptions) cloneWithDefaults() *DialOptions {
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
var cancel context.CancelFunc

var o DialOptions
if opts != nil {
o = *opts
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
} else if opts.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)

newClient := *opts.HTTPClient
newClient.Timeout = 0
opts.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
return &o

return ctx, cancel, &o
}

// Dial performs a WebSocket handshake on url.
Expand All @@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
defer errd.Wrap(&err, "failed to WebSocket dial")

opts = opts.cloneWithDefaults()
var cancel context.CancelFunc
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
if cancel != nil {
defer cancel()
}

secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
Expand Down Expand Up @@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}

func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
}

u, err := url.Parse(urls)
if err != nil {
return nil, fmt.Errorf("failed to parse url: %w", err)
Expand Down Expand Up @@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}

if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}

if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}

Expand Down Expand Up @@ -242,7 +250,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}

copts = &*copts
_copts := *copts
copts = &_copts

for _, p := range ext.params {
switch p {
Expand Down
9 changes: 0 additions & 9 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) {
name: "badURLScheme",
url: "ftp://nhooyr.io",
},
{
name: "badHTTPClient",
url: "ws://nhooyr.io",
opts: &DialOptions{
HTTPClient: &http.Client{
Timeout: time.Minute,
},
},
},
{
name: "badTLS",
url: "wss://totallyfake.nhooyr.io",
Expand Down
1 change: 0 additions & 1 deletion examples/echo/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
// It ensures the client speaks the echo subprotocol and
// only allows one message every 100ms with a 10 message burst.
type echoServer struct {

// logf controls where logs are sent.
logf func(f string, v ...interface{})
}
Expand Down
3 changes: 2 additions & 1 deletion internal/test/wstest/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions)
if dialOpts == nil {
dialOpts = &websocket.DialOptions{}
}
dialOpts = &*dialOpts
_dialOpts := *dialOpts
dialOpts = &_dialOpts
dialOpts.HTTPClient = &http.Client{
Transport: tt,
}
Expand Down