Skip to content

Commit f7bed7c

Browse files
bendiscznhooyr
authored andcommitted
Extend DialOptions to allow Host header override
1 parent 3f26c9f commit f7bed7c

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

dial.go

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ type DialOptions struct {
3030
// HTTPHeader specifies the HTTP headers included in the handshake request.
3131
HTTPHeader http.Header
3232

33+
// Host optionally overrides the Host HTTP header to send. If empty, the value
34+
// of URL.Host will be used.
35+
Host string
36+
3337
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
3438
Subprotocols []string
3539

@@ -168,6 +172,9 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
168172
if err != nil {
169173
return nil, fmt.Errorf("failed to create new http request: %w", err)
170174
}
175+
if len(opts.Host) > 0 {
176+
req.Host = opts.Host
177+
}
171178
req.Header = opts.HTTPHeader.Clone()
172179
req.Header.Set("Connection", "Upgrade")
173180
req.Header.Set("Upgrade", "websocket")

dial_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package websocket
55

66
import (
7+
"bytes"
78
"context"
89
"crypto/rand"
910
"io"
@@ -118,6 +119,65 @@ func TestBadDials(t *testing.T) {
118119
})
119120
}
120121

122+
func Test_verifyHostOverride(t *testing.T) {
123+
testCases := []struct {
124+
name string
125+
host string
126+
exp string
127+
}{
128+
{
129+
name: "noOverride",
130+
host: "",
131+
exp: "example.com",
132+
},
133+
{
134+
name: "hostOverride",
135+
host: "example.net",
136+
exp: "example.net",
137+
},
138+
}
139+
140+
for _, tc := range testCases {
141+
tc := tc
142+
t.Run(tc.name, func(t *testing.T) {
143+
t.Parallel()
144+
145+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
146+
defer cancel()
147+
148+
rt := func(r *http.Request) (*http.Response, error) {
149+
assert.Equal(t, "Host", tc.exp, r.Host)
150+
151+
h := http.Header{}
152+
h.Set("Connection", "Upgrade")
153+
h.Set("Upgrade", "websocket")
154+
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
155+
156+
return &http.Response{
157+
StatusCode: http.StatusSwitchingProtocols,
158+
Header: h,
159+
Body: mockBody{bytes.NewBufferString("hi")},
160+
}, nil
161+
}
162+
163+
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
164+
HTTPClient: mockHTTPClient(rt),
165+
Host: tc.host,
166+
})
167+
assert.Success(t, err)
168+
})
169+
}
170+
171+
}
172+
173+
type mockBody struct {
174+
*bytes.Buffer
175+
}
176+
177+
func (mb mockBody) Close() error {
178+
return nil
179+
}
180+
121181
func Test_verifyServerHandshake(t *testing.T) {
122182
t.Parallel()
123183

0 commit comments

Comments
 (0)