Skip to content

Commit

Permalink
Allow ResponseWriters to unwrap writers when flushing/hijacking (#2595)
Browse files Browse the repository at this point in the history
* Allow ResponseWriters to unwrap writers when flushing/hijacking
  • Loading branch information
aldas authored Mar 9, 2024
1 parent 3e04e3e commit bc1e190
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 8 deletions.
12 changes: 10 additions & 2 deletions middleware/body_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}

func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
err := responseControllerFlush(w.ResponseWriter)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}

func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
50 changes: 50 additions & 0 deletions middleware/body_dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
}
})
}

func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}

assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
}

func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}

bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}

func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
10 changes: 6 additions & 4 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
}

w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
_ = responseControllerFlush(w.ResponseWriter)
}

func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}

func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
Expand Down
30 changes: 30 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
}
}

func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestGzipResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}

func BenchmarkGzip(b *testing.B) {
e := echo.New()

Expand Down
46 changes: 46 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package middleware

import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
Expand Down Expand Up @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
})
}
}

type testResponseWriterNoFlushHijack struct {
}

func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}

func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}

type testResponseWriterUnwrapper struct {
unwrapCalled int
rw http.ResponseWriter
}

func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}

func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
}

type testResponseWriterUnwrapperHijack struct {
testResponseWriterUnwrapper
}

func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("can hijack")
}
41 changes: 41 additions & 0 deletions middleware/responsecontroller_1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build !go1.20

package middleware

import (
"bufio"
"fmt"
"net"
"net/http"
)

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
17 changes: 17 additions & 0 deletions middleware/responsecontroller_1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build go1.20

package middleware

import (
"bufio"
"net"
"net/http"
)

func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}

func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}
8 changes: 6 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package echo

import (
"bufio"
"errors"
"net"
"net/http"
)
Expand Down Expand Up @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush()
err := responseControllerFlush(r.Writer)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
return responseControllerHijack(r.Writer)
}

// Unwrap returns the original http.ResponseWriter.
Expand Down
25 changes: 25 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
assert.True(t, rec.Flushed)
}

type testResponseWriter struct {
}

func (w *testResponseWriter) WriteHeader(statusCode int) {
}

func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriter) Header() http.Header {
return nil
}

func TestResponse_FlushPanics(t *testing.T) {
e := New()
rw := new(testResponseWriter)
res := &Response{echo: e, Writer: rw}

// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
res.Flush()
})
}

func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
Expand Down
41 changes: 41 additions & 0 deletions responsecontroller_1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build !go1.20

package echo

import (
"bufio"
"fmt"
"net"
"net/http"
)

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
17 changes: 17 additions & 0 deletions responsecontroller_1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build go1.20

package echo

import (
"bufio"
"net"
"net/http"
)

func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}

func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}

0 comments on commit bc1e190

Please sign in to comment.