Skip to content

Commit

Permalink
middleware: add Discard method to WrapResponseWriter (#926)
Browse files Browse the repository at this point in the history
* middleware: add Discard method to WrapResponseWriter

* resolve review comments

* use ioutil.Discard and deprecate the public interface

* move the Discard method back to the public interface

* discard calls to WriteHeader too
  • Loading branch information
patrislav authored Jun 28, 2024
1 parent 7957c0d commit 67be7d9
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
35 changes: 27 additions & 8 deletions middleware/wrap_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package middleware
import (
"bufio"
"io"
"io/ioutil"
"net"
"net/http"
)
Expand Down Expand Up @@ -61,6 +62,11 @@ type WrapResponseWriter interface {
Tee(io.Writer)
// Unwrap returns the original proxied target.
Unwrap() http.ResponseWriter
// Discard causes all writes to the original ResponseWriter be discarded,
// instead writing only to the tee'd writer if it's set.
// The caller is responsible for calling WriteHeader and Write on the
// original ResponseWriter once the processing is done.
Discard()
}

// basicWriter wraps a http.ResponseWriter that implements the minimal
Expand All @@ -71,25 +77,34 @@ type basicWriter struct {
code int
bytes int
tee io.Writer
discard bool
}

func (b *basicWriter) WriteHeader(code int) {
if !b.wroteHeader {
b.code = code
b.wroteHeader = true
b.ResponseWriter.WriteHeader(code)
if !b.discard {
b.ResponseWriter.WriteHeader(code)
}
}
}

func (b *basicWriter) Write(buf []byte) (int, error) {
func (b *basicWriter) Write(buf []byte) (n int, err error) {
b.maybeWriteHeader()
n, err := b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
if !b.discard {
n, err = b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
}
}
} else if b.tee != nil {
n, err = b.tee.Write(buf)
} else {
n, err = ioutil.Discard.Write(buf)
}
b.bytes += n
return n, err
Expand Down Expand Up @@ -117,6 +132,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}

func (b *basicWriter) Discard() {
b.discard = true
}

// flushWriter ...
type flushWriter struct {
basicWriter
Expand Down
62 changes: 62 additions & 0 deletions middleware/wrap_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package middleware

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
Expand All @@ -22,3 +24,63 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
t.Fatal("want Flush to have set wroteHeader=true")
}
}

func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 200, original.Code)
assertEqual(t, []byte("hello world"), original.Body.Bytes())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
}

func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
t.Run("With Tee", func(t *testing.T) {
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
assertEqual(t, 0, original.Body.Len())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
})

t.Run("Without Tee", func(t *testing.T) {
// explicitly create the struct instead of NewRecorder to control the value of Code
original := &httptest.ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
wrap := &basicWriter{ResponseWriter: original}
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly
assertEqual(t, 0, original.Body.Len())
assertEqual(t, 11, wrap.BytesWritten())
})
}

0 comments on commit 67be7d9

Please sign in to comment.