diff --git a/dial.go b/dial.go index 4b2b7b62..510b94b1 100644 --- a/dial.go +++ b/dial.go @@ -30,6 +30,10 @@ type DialOptions struct { // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header + // Host optionally overrides the Host HTTP header to send. If empty, the value + // of URL.Host will be used. + Host string + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string @@ -168,6 +172,9 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts if err != nil { return nil, fmt.Errorf("failed to create new http request: %w", err) } + if len(opts.Host) > 0 { + req.Host = opts.Host + } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") diff --git a/dial_test.go b/dial_test.go index 75d59540..8680147e 100644 --- a/dial_test.go +++ b/dial_test.go @@ -4,6 +4,7 @@ package websocket import ( + "bytes" "context" "crypto/rand" "io" @@ -118,6 +119,65 @@ func TestBadDials(t *testing.T) { }) } +func Test_verifyHostOverride(t *testing.T) { + testCases := []struct { + name string + host string + exp string + }{ + { + name: "noOverride", + host: "", + exp: "example.com", + }, + { + name: "hostOverride", + host: "example.net", + exp: "example.net", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Host", tc.exp, r.Host) + + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: mockBody{bytes.NewBufferString("hi")}, + }, nil + } + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(rt), + Host: tc.host, + }) + assert.Success(t, err) + }) + } + +} + +type mockBody struct { + *bytes.Buffer +} + +func (mb mockBody) Close() error { + return nil +} + func Test_verifyServerHandshake(t *testing.T) { t.Parallel()