From 29a06c764d8e59405fc9191993a6e2d2e6e43730 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Fri, 16 Jun 2023 13:51:21 -0400 Subject: [PATCH] Allow clients to set Host in headers (#522) Allows Host to set the request.Host on the client and promotes request.Host back to a Header in the handler. Fixes #513 --- connect_ext_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++ duplex_http_call.go | 5 ++++ handler.go | 3 ++- protocol.go | 1 + 4 files changed, 66 insertions(+), 1 deletion(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index f5e7a1a9..f523b7cb 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -521,6 +521,64 @@ func TestHeaderBasic(t *testing.T) { assert.Equal(t, response.Header().Get(key), hval) } +func TestHeaderHost(t *testing.T) { + t.Parallel() + const ( + key = "Host" + cval = "buf.build" + ) + + pingServer := &pluggablePingServer{ + ping: func(_ context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + assert.Equal(t, request.Header().Get(key), cval) + response := connect.NewResponse(&pingv1.PingResponse{}) + return response, nil + }, + } + + newHTTP2Server := func(t *testing.T) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + return server + } + + callWithHost := func(t *testing.T, client pingv1connect.PingServiceClient) { + t.Helper() + + request := connect.NewRequest(&pingv1.PingRequest{}) + request.Header().Set(key, cval) + response, err := client.Ping(context.Background(), request) + assert.Nil(t, err) + assert.Equal(t, response.Header().Get(key), "") + } + + t.Run("connect", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + callWithHost(t, client) + }) + + t.Run("grpc", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + callWithHost(t, client) + }) + + t.Run("grpc-web", func(t *testing.T) { + t.Parallel() + server := newHTTP2Server(t) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + callWithHost(t, client) + }) +} + func TestTimeoutParsing(t *testing.T) { t.Parallel() const timeout = 10 * time.Minute diff --git a/duplex_http_call.go b/duplex_http_call.go index 03d4c912..4dac0092 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -257,6 +257,11 @@ func (d *duplexHTTPCall) makeRequest() { // on d.responseReady, so we can't race with them. defer close(d.responseReady) + // Promote the header Host to the request object. + if host := d.request.Header.Get(headerHost); len(host) > 0 { + d.request.Host = host + } + if d.onRequestSend != nil { d.onRequestSend(d.request) } diff --git a/handler.go b/handler.go index af3d215b..8d63eda7 100644 --- a/handler.go +++ b/handler.go @@ -176,7 +176,7 @@ func NewBidiStreamHandler[Req, Res any]( // ServeHTTP implements [http.Handler]. func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { - // We don't need to defer functions to close the request body or read to + // We don't need to defer functions to close the request body or read to // EOF: the stream we construct later on already does that, and we only // return early when dealing with misbehaving clients. In those cases, it's // okay if we can't re-use the connection. @@ -221,6 +221,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re // Establish a stream and serve the RPC. setHeaderCanonical(request.Header, headerContentType, contentType) + setHeaderCanonical(request.Header, headerHost, request.Host) ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request) //nolint: contextcheck if timeoutErr != nil { ctx = request.Context() diff --git a/protocol.go b/protocol.go index b5954290..ee1a141d 100644 --- a/protocol.go +++ b/protocol.go @@ -36,6 +36,7 @@ const ( const ( headerContentType = "Content-Type" + headerHost = "Host" headerUserAgent = "User-Agent" headerTrailer = "Trailer"