1
1
//go:build !js
2
2
// +build !js
3
3
4
- package websocket
4
+ package websocket_test
5
5
6
6
import (
7
7
"bytes"
@@ -10,12 +10,15 @@ import (
10
10
"io"
11
11
"net/http"
12
12
"net/http/httptest"
13
+ "net/url"
13
14
"strings"
14
15
"testing"
15
16
"time"
16
17
18
+ "nhooyr.io/websocket"
17
19
"nhooyr.io/websocket/internal/test/assert"
18
20
"nhooyr.io/websocket/internal/util"
21
+ "nhooyr.io/websocket/internal/xsync"
19
22
)
20
23
21
24
func TestBadDials (t * testing.T ) {
@@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) {
27
30
testCases := []struct {
28
31
name string
29
32
url string
30
- opts * DialOptions
33
+ opts * websocket. DialOptions
31
34
rand util.ReaderFunc
32
35
nilCtx bool
33
36
}{
@@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) {
72
75
tc .rand = rand .Reader .Read
73
76
}
74
77
75
- _ , _ , err := dial (ctx , tc .url , tc .opts , tc .rand )
78
+ _ , _ , err := websocket . ExportedDial (ctx , tc .url , tc .opts , tc .rand )
76
79
assert .Error (t , err )
77
80
})
78
81
}
@@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) {
84
87
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
85
88
defer cancel ()
86
89
87
- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
90
+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
88
91
HTTPClient : mockHTTPClient (func (* http.Request ) (* http.Response , error ) {
89
92
return & http.Response {
90
93
Body : io .NopCloser (strings .NewReader ("hi" )),
@@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) {
104
107
h := http.Header {}
105
108
h .Set ("Connection" , "Upgrade" )
106
109
h .Set ("Upgrade" , "websocket" )
107
- h .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
110
+ h .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
108
111
109
112
return & http.Response {
110
113
StatusCode : http .StatusSwitchingProtocols ,
@@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) {
113
116
}, nil
114
117
}
115
118
116
- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
119
+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
117
120
HTTPClient : mockHTTPClient (rt ),
118
121
})
119
122
assert .Contains (t , err , "response body is not a io.ReadWriteCloser" )
@@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) {
152
155
h := http.Header {}
153
156
h .Set ("Connection" , "Upgrade" )
154
157
h .Set ("Upgrade" , "websocket" )
155
- h .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
158
+ h .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
156
159
157
160
return & http.Response {
158
161
StatusCode : http .StatusSwitchingProtocols ,
@@ -161,7 +164,7 @@ func Test_verifyHostOverride(t *testing.T) {
161
164
}, nil
162
165
}
163
166
164
- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
167
+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
165
168
HTTPClient : mockHTTPClient (rt ),
166
169
Host : tc .host ,
167
170
})
@@ -272,18 +275,18 @@ func Test_verifyServerHandshake(t *testing.T) {
272
275
resp := w .Result ()
273
276
274
277
r := httptest .NewRequest ("GET" , "/" , nil )
275
- key , err := secWebSocketKey (rand .Reader )
278
+ key , err := websocket . SecWebSocketKey (rand .Reader )
276
279
assert .Success (t , err )
277
280
r .Header .Set ("Sec-WebSocket-Key" , key )
278
281
279
282
if resp .Header .Get ("Sec-WebSocket-Accept" ) == "" {
280
- resp .Header .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (key ))
283
+ resp .Header .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (key ))
281
284
}
282
285
283
- opts := & DialOptions {
286
+ opts := & websocket. DialOptions {
284
287
Subprotocols : strings .Split (r .Header .Get ("Sec-WebSocket-Protocol" ), "," ),
285
288
}
286
- _ , err = verifyServerResponse (opts , opts .CompressionMode . opts ( ), key , resp )
289
+ _ , err = websocket . VerifyServerResponse (opts , websocket . CompressionModeOpts ( opts .CompressionMode ), key , resp )
287
290
if tc .success {
288
291
assert .Success (t , err )
289
292
} else {
@@ -311,7 +314,7 @@ func TestDialRedirect(t *testing.T) {
311
314
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
312
315
defer cancel ()
313
316
314
- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
317
+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
315
318
HTTPClient : mockHTTPClient (func (r * http.Request ) (* http.Response , error ) {
316
319
resp := & http.Response {
317
320
Header : http.Header {},
@@ -321,11 +324,88 @@ func TestDialRedirect(t *testing.T) {
321
324
resp .StatusCode = http .StatusFound
322
325
return resp , nil
323
326
}
324
- resp .Header .Set ("Connection" , "Upgrade" )
325
- resp .Header .Set ("Upgrade" , "meow" )
327
+ resp .Header .Set ("Connection" , "Upgrade" )
328
+ resp .Header .Set ("Upgrade" , "meow" )
326
329
resp .StatusCode = http .StatusSwitchingProtocols
327
330
return resp , nil
328
331
}),
329
332
})
330
333
assert .Contains (t , err , "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \" meow\" does not contain websocket" )
331
334
}
335
+
336
+ type forwardProxy struct {
337
+ hc * http.Client
338
+ }
339
+
340
+ func newForwardProxy () * forwardProxy {
341
+ return & forwardProxy {
342
+ hc : & http.Client {},
343
+ }
344
+ }
345
+
346
+ func (fc * forwardProxy ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
347
+ ctx , cancel := context .WithTimeout (r .Context (), time .Second * 10 )
348
+ defer cancel ()
349
+
350
+ r = r .WithContext (ctx )
351
+ r .RequestURI = ""
352
+ resp , err := fc .hc .Do (r )
353
+ if err != nil {
354
+ http .Error (w , err .Error (), http .StatusBadRequest )
355
+ return
356
+ }
357
+ defer resp .Body .Close ()
358
+
359
+ for k , v := range resp .Header {
360
+ w .Header ()[k ] = v
361
+ }
362
+ w .Header ().Set ("PROXIED" , "true" )
363
+ w .WriteHeader (resp .StatusCode )
364
+ errc1 := xsync .Go (func () error {
365
+ _ , err := io .Copy (w , resp .Body )
366
+ return err
367
+ })
368
+ var errc2 <- chan error
369
+ if bodyw , ok := resp .Body .(io.Writer ); ok {
370
+ errc2 = xsync .Go (func () error {
371
+ _ , err := io .Copy (bodyw , r .Body )
372
+ return err
373
+ })
374
+ }
375
+ select {
376
+ case <- errc1 :
377
+ case <- errc2 :
378
+ case <- r .Context ().Done ():
379
+ }
380
+ }
381
+
382
+ func TestDialViaProxy (t * testing.T ) {
383
+ t .Parallel ()
384
+
385
+ ps := httptest .NewServer (newForwardProxy ())
386
+ defer ps .Close ()
387
+
388
+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
389
+ err := echoServer (w , r , nil )
390
+ assert .Success (t , err )
391
+ }))
392
+ defer s .Close ()
393
+
394
+ psu , err := url .Parse (ps .URL )
395
+ assert .Success (t , err )
396
+ proxyTransport := http .DefaultTransport .(* http.Transport ).Clone ()
397
+ proxyTransport .Proxy = http .ProxyURL (psu )
398
+
399
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
400
+ defer cancel ()
401
+ c , resp , err := websocket .Dial (ctx , s .URL , & websocket.DialOptions {
402
+ HTTPClient : & http.Client {
403
+ Transport : proxyTransport ,
404
+ },
405
+ })
406
+ assert .Success (t , err )
407
+ assert .Equal (t , "" , "true" , resp .Header .Get ("PROXIED" ))
408
+
409
+ assertEcho (t , ctx , c )
410
+ assertClose (t , c )
411
+ }
0 commit comments