8
8
"context"
9
9
"crypto/rand"
10
10
"encoding/base64"
11
- "errors"
12
11
"fmt"
13
12
"io"
14
13
"io/ioutil"
@@ -47,18 +46,27 @@ type DialOptions struct {
47
46
CompressionThreshold int
48
47
}
49
48
50
- func (opts * DialOptions ) cloneWithDefaults () * DialOptions {
49
+ func (opts * DialOptions ) cloneWithDefaults (ctx context.Context ) (context.Context , context.CancelFunc , * DialOptions ) {
50
+ var cancel context.CancelFunc
51
+
51
52
var o DialOptions
52
53
if opts != nil {
53
54
o = * opts
54
55
}
55
56
if o .HTTPClient == nil {
56
57
o .HTTPClient = http .DefaultClient
58
+ } else if opts .HTTPClient .Timeout > 0 {
59
+ ctx , cancel = context .WithTimeout (ctx , opts .HTTPClient .Timeout )
60
+
61
+ newClient := * opts .HTTPClient
62
+ newClient .Timeout = 0
63
+ opts .HTTPClient = & newClient
57
64
}
58
65
if o .HTTPHeader == nil {
59
66
o .HTTPHeader = http.Header {}
60
67
}
61
- return & o
68
+
69
+ return ctx , cancel , & o
62
70
}
63
71
64
72
// Dial performs a WebSocket handshake on url.
@@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
81
89
func dial (ctx context.Context , urls string , opts * DialOptions , rand io.Reader ) (_ * Conn , _ * http.Response , err error ) {
82
90
defer errd .Wrap (& err , "failed to WebSocket dial" )
83
91
84
- opts = opts .cloneWithDefaults ()
92
+ var cancel context.CancelFunc
93
+ ctx , cancel , opts = opts .cloneWithDefaults (ctx )
94
+ if cancel != nil {
95
+ defer cancel ()
96
+ }
85
97
86
98
secWebSocketKey , err := secWebSocketKey (rand )
87
99
if err != nil {
@@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
137
149
}
138
150
139
151
func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
140
- if opts .HTTPClient .Timeout > 0 {
141
- return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
142
- }
143
-
144
152
u , err := url .Parse (urls )
145
153
if err != nil {
146
154
return nil , fmt .Errorf ("failed to parse url: %w" , err )
@@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
193
201
return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
194
202
}
195
203
196
- if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
204
+ if ! headerContainsTokenIgnoreCase (resp .Header , "Connection" , "Upgrade" ) {
197
205
return nil , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
198
206
}
199
207
200
- if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
208
+ if ! headerContainsTokenIgnoreCase (resp .Header , "Upgrade" , "WebSocket" ) {
201
209
return nil , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
202
210
}
203
211
0 commit comments