Skip to content

Commit 249edb2

Browse files
committed
dial_test: Add TestDialViaProxy
For #395 Somehow currently reproduces #391... Debugging still.
1 parent e314da6 commit 249edb2

File tree

3 files changed

+128
-15
lines changed

3 files changed

+128
-15
lines changed

Diff for: conn_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,29 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp
526526
err = wstest.EchoLoop(r.Context(), c)
527527
return assertCloseStatus(websocket.StatusNormalClosure, err)
528528
}
529+
530+
func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
531+
exp := xrand.String(xrand.Int(131072))
532+
533+
werr := xsync.Go(func() error {
534+
return wsjson.Write(ctx, c, exp)
535+
})
536+
537+
var act interface{}
538+
err := wsjson.Read(ctx, c, &act)
539+
assert.Success(tb, err)
540+
assert.Equal(tb, "read msg", exp, act)
541+
542+
select {
543+
case err := <-werr:
544+
assert.Success(tb, err)
545+
case <-ctx.Done():
546+
tb.Fatal(ctx.Err())
547+
}
548+
}
549+
550+
func assertClose(tb testing.TB, c *websocket.Conn) {
551+
tb.Helper()
552+
err := c.Close(websocket.StatusNormalClosure, "")
553+
assert.Success(tb, err)
554+
}

Diff for: dial_test.go

+95-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//go:build !js
22
// +build !js
33

4-
package websocket
4+
package websocket_test
55

66
import (
77
"bytes"
@@ -10,12 +10,15 @@ import (
1010
"io"
1111
"net/http"
1212
"net/http/httptest"
13+
"net/url"
1314
"strings"
1415
"testing"
1516
"time"
1617

18+
"nhooyr.io/websocket"
1719
"nhooyr.io/websocket/internal/test/assert"
1820
"nhooyr.io/websocket/internal/util"
21+
"nhooyr.io/websocket/internal/xsync"
1922
)
2023

2124
func TestBadDials(t *testing.T) {
@@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) {
2730
testCases := []struct {
2831
name string
2932
url string
30-
opts *DialOptions
33+
opts *websocket.DialOptions
3134
rand util.ReaderFunc
3235
nilCtx bool
3336
}{
@@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) {
7275
tc.rand = rand.Reader.Read
7376
}
7477

75-
_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
78+
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
7679
assert.Error(t, err)
7780
})
7881
}
@@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) {
8487
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
8588
defer cancel()
8689

87-
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
90+
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
8891
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
8992
return &http.Response{
9093
Body: io.NopCloser(strings.NewReader("hi")),
@@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) {
104107
h := http.Header{}
105108
h.Set("Connection", "Upgrade")
106109
h.Set("Upgrade", "websocket")
107-
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
110+
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
108111

109112
return &http.Response{
110113
StatusCode: http.StatusSwitchingProtocols,
@@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) {
113116
}, nil
114117
}
115118

116-
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
119+
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
117120
HTTPClient: mockHTTPClient(rt),
118121
})
119122
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
@@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) {
152155
h := http.Header{}
153156
h.Set("Connection", "Upgrade")
154157
h.Set("Upgrade", "websocket")
155-
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
158+
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
156159

157160
return &http.Response{
158161
StatusCode: http.StatusSwitchingProtocols,
@@ -161,7 +164,7 @@ func Test_verifyHostOverride(t *testing.T) {
161164
}, nil
162165
}
163166

164-
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
167+
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
165168
HTTPClient: mockHTTPClient(rt),
166169
Host: tc.host,
167170
})
@@ -272,18 +275,18 @@ func Test_verifyServerHandshake(t *testing.T) {
272275
resp := w.Result()
273276

274277
r := httptest.NewRequest("GET", "/", nil)
275-
key, err := secWebSocketKey(rand.Reader)
278+
key, err := websocket.SecWebSocketKey(rand.Reader)
276279
assert.Success(t, err)
277280
r.Header.Set("Sec-WebSocket-Key", key)
278281

279282
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
280-
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
283+
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
281284
}
282285

283-
opts := &DialOptions{
286+
opts := &websocket.DialOptions{
284287
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
285288
}
286-
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
289+
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
287290
if tc.success {
288291
assert.Success(t, err)
289292
} else {
@@ -311,7 +314,7 @@ func TestDialRedirect(t *testing.T) {
311314
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
312315
defer cancel()
313316

314-
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
317+
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
315318
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
316319
resp := &http.Response{
317320
Header: http.Header{},
@@ -321,11 +324,88 @@ func TestDialRedirect(t *testing.T) {
321324
resp.StatusCode = http.StatusFound
322325
return resp, nil
323326
}
324-
resp.Header.Set("Connection", "Upgrade")
325-
resp.Header.Set("Upgrade", "meow")
327+
resp.Header.Set("Connection", "Upgrade")
328+
resp.Header.Set("Upgrade", "meow")
326329
resp.StatusCode = http.StatusSwitchingProtocols
327330
return resp, nil
328331
}),
329332
})
330333
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
331334
}
335+
336+
type forwardProxy struct {
337+
hc *http.Client
338+
}
339+
340+
func newForwardProxy() *forwardProxy {
341+
return &forwardProxy{
342+
hc: &http.Client{},
343+
}
344+
}
345+
346+
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
347+
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
348+
defer cancel()
349+
350+
r = r.WithContext(ctx)
351+
r.RequestURI = ""
352+
resp, err := fc.hc.Do(r)
353+
if err != nil {
354+
http.Error(w, err.Error(), http.StatusBadRequest)
355+
return
356+
}
357+
defer resp.Body.Close()
358+
359+
for k, v := range resp.Header {
360+
w.Header()[k] = v
361+
}
362+
w.Header().Set("PROXIED", "true")
363+
w.WriteHeader(resp.StatusCode)
364+
errc1 := xsync.Go(func() error {
365+
_, err := io.Copy(w, resp.Body)
366+
return err
367+
})
368+
var errc2 <-chan error
369+
if bodyw, ok := resp.Body.(io.Writer); ok {
370+
errc2 = xsync.Go(func() error {
371+
_, err := io.Copy(bodyw, r.Body)
372+
return err
373+
})
374+
}
375+
select {
376+
case <-errc1:
377+
case <-errc2:
378+
case <-r.Context().Done():
379+
}
380+
}
381+
382+
func TestDialViaProxy(t *testing.T) {
383+
t.Parallel()
384+
385+
ps := httptest.NewServer(newForwardProxy())
386+
defer ps.Close()
387+
388+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
389+
err := echoServer(w, r, nil)
390+
assert.Success(t, err)
391+
}))
392+
defer s.Close()
393+
394+
psu, err := url.Parse(ps.URL)
395+
assert.Success(t, err)
396+
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
397+
proxyTransport.Proxy = http.ProxyURL(psu)
398+
399+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
400+
defer cancel()
401+
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
402+
HTTPClient: &http.Client{
403+
Transport: proxyTransport,
404+
},
405+
})
406+
assert.Success(t, err)
407+
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
408+
409+
assertEcho(t, ctx, c)
410+
assertClose(t, c)
411+
}

Diff for: export_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,10 @@ func (c *Conn) RecordBytesRead() *int {
2525
}
2626

2727
var ErrClosed = errClosed
28+
29+
var ExportedDial = dial
30+
var SecWebSocketAccept = secWebSocketAccept
31+
var SecWebSocketKey = secWebSocketKey
32+
var VerifyServerResponse = verifyServerResponse
33+
34+
var CompressionModeOpts = CompressionMode.opts

0 commit comments

Comments
 (0)