diff --git a/header.go b/header.go index b3434c3556..99e1dcc42d 100644 --- a/header.go +++ b/header.go @@ -681,8 +681,6 @@ func (h *RequestHeader) RequestURI() []byte { requestURI := h.requestURI if len(requestURI) == 0 { requestURI = strSlash - } else if requestURI[0] == '?' { - requestURI = append(strSlash, requestURI...) } return requestURI } @@ -691,6 +689,11 @@ func (h *RequestHeader) RequestURI() []byte { // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURI(requestURI string) { + if len(requestURI) > 0 && requestURI[0] == '?' { + h.requestURI = append(h.requestURI[:0], strSlash...) + h.requestURI = append(h.requestURI, requestURI...) + return + } h.requestURI = append(h.requestURI[:0], requestURI...) } @@ -698,6 +701,11 @@ func (h *RequestHeader) SetRequestURI(requestURI string) { // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) { + if len(requestURI) > 0 && requestURI[0] == '?' { + h.requestURI = append(h.requestURI[:0], strSlash...) + h.requestURI = append(h.requestURI, requestURI...) + return + } h.requestURI = append(h.requestURI[:0], requestURI...) } diff --git a/header_test.go b/header_test.go index 8d4f7cdb70..0c7fdc1cab 100644 --- a/header_test.go +++ b/header_test.go @@ -1261,6 +1261,26 @@ func TestRequestHeaderWithQueryParamsAndNoPath(t *testing.T) { if w.String() != expectedRequestHeader { t.Fatalf("unexpected request header: %q. Expecting %q", w, expectedRequestHeader) } + + h1.Reset() + h1.SetRequestURIBytes([]byte("?foo=bar")) + h1.SetHost("example.com") + h1.SetMethod("GET") + + w.Reset() + bw = bufio.NewWriter(w) + if err := h1.Write(bw); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedRequestHeader = "GET /?foo=bar HTTP/1.1\r\nHost: example.com\r\n\r\n" + if w.String() != expectedRequestHeader { + t.Fatalf("unexpected request header: %q. Expecting %q", w, expectedRequestHeader) + } + } func TestResponseHeaderFirstByteReadEOF(t *testing.T) {