Skip to content

Commit

Permalink
net/http/httputil: use response controller in reverse proxy
Browse files Browse the repository at this point in the history
Previously, the reverse proxy is unable to detect
the support for hijack or flush if those things
are residing in the response writer in a wrapped
manner.

The reverse proxy now makes use of the new http
response controller as the means to discover
the underlying flusher and hijacker associated
with the response writer, allowing wrapped flusher
and hijacker become discoverable.

Change-Id: I53acbb12315c3897be068e8c00598ef42fc74649
Reviewed-on: https://go-review.googlesource.com/c/go/+/468755
Run-TryBot: Damien Neil <dneil@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
  • Loading branch information
Shang Ding authored and gopherbot committed Mar 17, 2023
1 parent 602e6aa commit 2449bbb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 33 deletions.
61 changes: 28 additions & 33 deletions src/net/http/httputil/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
http.NewResponseController(rw).Flush()
}

if len(res.Trailer) == announcedTrailers {
Expand Down Expand Up @@ -601,29 +599,30 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
return p.FlushInterval
}

func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
var w io.Writer = dst

if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: flushInterval,
}
defer mlw.stop()
mlw := &maxLatencyWriter{
dst: dst,
flush: http.NewResponseController(dst).Flush,
latency: flushInterval,
}
defer mlw.stop()

// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)

dst = mlw
}
w = mlw
}

var buf []byte
if p.BufferPool != nil {
buf = p.BufferPool.Get()
defer p.BufferPool.Put(buf)
}
_, err := p.copyBuffer(dst, src, buf)
_, err := p.copyBuffer(w, src, buf)
return err
}

Expand Down Expand Up @@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) {
}
}

type writeFlusher interface {
io.Writer
http.Flusher
}

type maxLatencyWriter struct {
dst writeFlusher
dst io.Writer
flush func() error
latency time.Duration // non-zero; negative means to flush immediately

mu sync.Mutex // protects t, flushPending, and dst.Flush
Expand All @@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
defer m.mu.Unlock()
n, err = m.dst.Write(p)
if m.latency < 0 {
m.dst.Flush()
m.flush()
return
}
if m.flushPending {
Expand All @@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() {
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
return
}
m.dst.Flush()
m.flush()
m.flushPending = false
}

Expand Down Expand Up @@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
return
}

hj, ok := rw.(http.Hijacker)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
return
}

rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}

backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
Expand All @@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
}
backConn.Close()
}()

defer close(backConnCloseCh)

conn, brw, err := hj.Hijack()
if err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
if hijackErr != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
return
}
defer conn.Close()
Expand Down
56 changes: 56 additions & 0 deletions src/net/http/httputil/reverseproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,62 @@ func TestReverseProxyFlushInterval(t *testing.T) {
}
}

type mockFlusher struct {
http.ResponseWriter
flushed bool
}

func (m *mockFlusher) Flush() {
m.flushed = true
}

type wrappedRW struct {
http.ResponseWriter
}

func (w *wrappedRW) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
const expected = "hi"
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(expected))
}))
defer backend.Close()

backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}

mf := &mockFlusher{}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.FlushInterval = -1 // flush immediately
proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mf.ResponseWriter = w
w = &wrappedRW{mf}
proxyHandler.ServeHTTP(w, r)
})

frontend := httptest.NewServer(proxyWithMiddleware)
defer frontend.Close()

req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true
res, err := frontend.Client().Do(req)
if err != nil {
t.Fatalf("Get: %v", err)
}
defer res.Body.Close()
if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
t.Errorf("got body %q; expected %q", bodyBytes, expected)
}
if !mf.flushed {
t.Errorf("response writer was not flushed")
}
}

func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
const expected = "hi"
stopCh := make(chan struct{})
Expand Down

0 comments on commit 2449bbb

Please sign in to comment.