Skip to content

Commit

Permalink
routing/http: feat: limit the resp body payload
Browse files Browse the repository at this point in the history
  • Loading branch information
guseggert committed Dec 6, 2022
1 parent 10595c9 commit f7313cf
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
8 changes: 7 additions & 1 deletion routing/http/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option {
// New creates a content routing API client.
// The Provider and identity parameters are option. If they are nil, the `Provide` method will not function.
func New(baseURL string, opts ...option) (*client, error) {
defaultHTTPClient := &http.Client{
Transport: &ResponseBodyLimitedTransport{
RoundTripper: http.DefaultTransport,
LimitBytes: 1 << 20,
},
}
client := &client{
baseURL: baseURL,
httpClient: http.DefaultClient,
httpClient: defaultHTTPClient,
validator: ipns.Validator{},
clock: clock.New(),
}
Expand Down
38 changes: 38 additions & 0 deletions routing/http/client/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package client

import (
"fmt"
"io"
"net/http"
)

type ResponseBodyLimitedTransport struct {
http.RoundTripper
LimitBytes int64
}

func (r *ResponseBodyLimitedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := r.RoundTripper.RoundTrip(req)
if resp != nil && resp.Body != nil {
resp.Body = &limitReadCloser{
limit: r.LimitBytes,
ReadCloser: resp.Body,
}
}
return resp, err
}

type limitReadCloser struct {
limit int64
bytesRead int64
io.ReadCloser
}

func (l *limitReadCloser) Read(p []byte) (int, error) {
n, err := l.ReadCloser.Read(p)
l.bytesRead += int64(n)
if l.bytesRead > l.limit {
return 0, fmt.Errorf("reached read limit of %d bytes after reading %d bytes", l.limit, l.bytesRead)
}
return n, err
}
82 changes: 82 additions & 0 deletions routing/http/client/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package client

import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testServer struct {
bytesToWrite int
}

func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bytes := make([]byte, s.bytesToWrite)
for i := 0; i < s.bytesToWrite; i++ {
bytes[i] = 'a'
}
_, err := w.Write(bytes)
if err != nil {
panic(err)
}
}

func TestResponseBodyLimitedTransport(t *testing.T) {
for _, c := range []struct {
name string
limit int64
serverSend int

expErr string
}{
{
name: "under the limit should succeed",
limit: 1 << 20,
serverSend: 1 << 19,
},
{
name: "over the limit should fail",
limit: 1 << 20,
serverSend: 1 << 21,
expErr: "reached read limit of 1048576 bytes after reading",
},
{
name: "exactly on the limit should succeed",
limit: 1 << 20,
serverSend: 1 << 20,
},
} {
t.Run(c.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

server := httptest.NewServer(&testServer{bytesToWrite: c.serverSend})
t.Cleanup(server.Close)

client := server.Client()
client.Transport = &ResponseBodyLimitedTransport{
LimitBytes: c.limit,
RoundTripper: client.Transport,
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)

_, err = io.ReadAll(resp.Body)

if c.expErr == "" {
assert.NoError(t, err)
} else {
assert.Contains(t, err.Error(), c.expErr)
}

})
}
}

0 comments on commit f7313cf

Please sign in to comment.