diff --git a/routing/http/client/client.go b/routing/http/client/client.go index 173c20909b..9ad3e69870 100644 --- a/routing/http/client/client.go +++ b/routing/http/client/client.go @@ -69,9 +69,15 @@ func WithProviderInfo(peerID peer.ID, addrs []multiaddr.Multiaddr) option { // New creates a content routing API client. // The Provider and identity parameters are option. If they are nil, the `Provide` method will not function. func New(baseURL string, opts ...option) (*client, error) { + defaultHTTPClient := &http.Client{ + Transport: &ResponseBodyLimitedTransport{ + RoundTripper: http.DefaultTransport, + LimitBytes: 1 << 20, + }, + } client := &client{ baseURL: baseURL, - httpClient: http.DefaultClient, + httpClient: defaultHTTPClient, validator: ipns.Validator{}, clock: clock.New(), } diff --git a/routing/http/client/transport.go b/routing/http/client/transport.go new file mode 100644 index 0000000000..ea99204632 --- /dev/null +++ b/routing/http/client/transport.go @@ -0,0 +1,38 @@ +package client + +import ( + "fmt" + "io" + "net/http" +) + +type ResponseBodyLimitedTransport struct { + http.RoundTripper + LimitBytes int64 +} + +func (r *ResponseBodyLimitedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := r.RoundTripper.RoundTrip(req) + if resp != nil && resp.Body != nil { + resp.Body = &limitReadCloser{ + limit: r.LimitBytes, + ReadCloser: resp.Body, + } + } + return resp, err +} + +type limitReadCloser struct { + limit int64 + bytesRead int64 + io.ReadCloser +} + +func (l *limitReadCloser) Read(p []byte) (int, error) { + n, err := l.ReadCloser.Read(p) + l.bytesRead += int64(n) + if l.bytesRead > l.limit { + return 0, fmt.Errorf("reached read limit of %d bytes after reading %d bytes", l.limit, l.bytesRead) + } + return n, err +} diff --git a/routing/http/client/transport_test.go b/routing/http/client/transport_test.go new file mode 100644 index 0000000000..4b66e8bbaa --- /dev/null +++ b/routing/http/client/transport_test.go @@ -0,0 +1,82 @@ +package client + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testServer struct { + bytesToWrite int +} + +func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bytes := make([]byte, s.bytesToWrite) + for i := 0; i < s.bytesToWrite; i++ { + bytes[i] = 'a' + } + _, err := w.Write(bytes) + if err != nil { + panic(err) + } +} + +func TestResponseBodyLimitedTransport(t *testing.T) { + for _, c := range []struct { + name string + limit int64 + serverSend int + + expErr string + }{ + { + name: "under the limit should succeed", + limit: 1 << 20, + serverSend: 1 << 19, + }, + { + name: "over the limit should fail", + limit: 1 << 20, + serverSend: 1 << 21, + expErr: "reached read limit of 1048576 bytes after reading", + }, + { + name: "exactly on the limit should succeed", + limit: 1 << 20, + serverSend: 1 << 20, + }, + } { + t.Run(c.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + server := httptest.NewServer(&testServer{bytesToWrite: c.serverSend}) + t.Cleanup(server.Close) + + client := server.Client() + client.Transport = &ResponseBodyLimitedTransport{ + LimitBytes: c.limit, + RoundTripper: client.Transport, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + + _, err = io.ReadAll(resp.Body) + + if c.expErr == "" { + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), c.expErr) + } + + }) + } +}