Skip to content

Commit 06beb14

Browse files
committed
Merge branch 'master' into merge-master
There were a few PRs merged into the master branch that were then not merged into the dev branch. This branch merges those changes in cleanly. - #261 - #266 - #273
2 parents fdc4079 + e4c3b0f commit 06beb14

File tree

5 files changed

+28
-27
lines changed

5 files changed

+28
-27
lines changed

accept.go

+4-7
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
163163
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
164164
}
165165

166-
if !headerContainsToken(r.Header, "Connection", "Upgrade") {
166+
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
167167
w.Header().Set("Connection", "Upgrade")
168168
w.Header().Set("Upgrade", "websocket")
169169
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
170170
}
171171

172-
if !headerContainsToken(r.Header, "Upgrade", "websocket") {
172+
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
173173
w.Header().Set("Connection", "Upgrade")
174174
w.Header().Set("Upgrade", "websocket")
175175
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
@@ -313,11 +313,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
313313
return copts, nil
314314
}
315315

316-
func headerContainsToken(h http.Header, key, token string) bool {
317-
token = strings.ToLower(token)
318-
316+
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
319317
for _, t := range headerTokens(h, key) {
320-
if t == token {
318+
if strings.EqualFold(t, token) {
321319
return true
322320
}
323321
}
@@ -358,7 +356,6 @@ func headerTokens(h http.Header, key string) []string {
358356
for _, v := range h[key] {
359357
v = strings.TrimSpace(v)
360358
for _, t := range strings.Split(v, ",") {
361-
t = strings.ToLower(t)
362359
t = strings.TrimSpace(t)
363360
tokens = append(tokens, t)
364361
}

accept_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ func Test_selectSubprotocol(t *testing.T) {
226226
serverProtocols: []string{"echo2", "echo3"},
227227
negotiated: "echo3",
228228
},
229+
{
230+
name: "clientCasePresered",
231+
clientProtocols: []string{"Echo1"},
232+
serverProtocols: []string{"echo1"},
233+
negotiated: "Echo1",
234+
},
229235
}
230236

231237
for _, tc := range testCases {

dial.go

+18-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"context"
99
"crypto/rand"
1010
"encoding/base64"
11-
"errors"
1211
"fmt"
1312
"io"
1413
"io/ioutil"
@@ -47,18 +46,27 @@ type DialOptions struct {
4746
CompressionThreshold int
4847
}
4948

50-
func (opts *DialOptions) cloneWithDefaults() *DialOptions {
49+
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
50+
var cancel context.CancelFunc
51+
5152
var o DialOptions
5253
if opts != nil {
5354
o = *opts
5455
}
5556
if o.HTTPClient == nil {
5657
o.HTTPClient = http.DefaultClient
58+
} else if opts.HTTPClient.Timeout > 0 {
59+
ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
60+
61+
newClient := *opts.HTTPClient
62+
newClient.Timeout = 0
63+
opts.HTTPClient = &newClient
5764
}
5865
if o.HTTPHeader == nil {
5966
o.HTTPHeader = http.Header{}
6067
}
61-
return &o
68+
69+
return ctx, cancel, &o
6270
}
6371

6472
// Dial performs a WebSocket handshake on url.
@@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
8189
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
8290
defer errd.Wrap(&err, "failed to WebSocket dial")
8391

84-
opts = opts.cloneWithDefaults()
92+
var cancel context.CancelFunc
93+
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
94+
if cancel != nil {
95+
defer cancel()
96+
}
8597

8698
secWebSocketKey, err := secWebSocketKey(rand)
8799
if err != nil {
@@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
137149
}
138150

139151
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
140-
if opts.HTTPClient.Timeout > 0 {
141-
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
142-
}
143-
144152
u, err := url.Parse(urls)
145153
if err != nil {
146154
return nil, fmt.Errorf("failed to parse url: %w", err)
@@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
193201
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
194202
}
195203

196-
if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
204+
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
197205
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
198206
}
199207

200-
if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
208+
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
201209
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
202210
}
203211

dial_test.go

-9
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) {
3636
name: "badURLScheme",
3737
url: "ftp://nhooyr.io",
3838
},
39-
{
40-
name: "badHTTPClient",
41-
url: "ws://nhooyr.io",
42-
opts: &DialOptions{
43-
HTTPClient: &http.Client{
44-
Timeout: time.Minute,
45-
},
46-
},
47-
},
4839
{
4940
name: "badTLS",
5041
url: "wss://totallyfake.nhooyr.io",

examples/echo/server.go

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
// It ensures the client speaks the echo subprotocol and
1717
// only allows one message every 100ms with a 10 message burst.
1818
type echoServer struct {
19-
2019
// logf controls where logs are sent.
2120
logf func(f string, v ...interface{})
2221
}

0 commit comments

Comments
 (0)