diff --git a/.golangci.yml b/.golangci.yml index d5a5831..a68548a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -32,10 +32,8 @@ linters: - unparam linters-settings: gosimple: - go: "1.21" checks: ["all"] staticcheck: - go: "1.21" checks: ["all"] dupl: threshold: 125 \ No newline at end of file diff --git a/check.go b/check.go index a22791d..ce4e2d0 100644 --- a/check.go +++ b/check.go @@ -14,6 +14,7 @@ import ( "github.com/klauspost/compress/gzhttp" "github.com/samber/lo" "github.com/sourcegraph/conc/pool" + "go.goblog.app/app/pkgs/bodylimit" "go.goblog.app/app/pkgs/httpcachetransport" ) @@ -57,7 +58,7 @@ func (a *goBlog) checkLinks(posts ...*post) error { Timeout: 30 * time.Second, Transport: httpcachetransport.NewHttpCacheTransportNoBody(gzhttp.Transport(&http.Transport{ DisableKeepAlives: true, MaxConnsPerHost: 1, - }), cache, 60*time.Minute), + }), cache, 60*time.Minute, 5*bodylimit.MB), } // Process all links type checkresult struct { diff --git a/http.go b/http.go index c5e1914..c6b3bb3 100644 --- a/http.go +++ b/http.go @@ -15,6 +15,7 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/justinas/alice" "github.com/samber/lo" + "go.goblog.app/app/pkgs/bodylimit" "go.goblog.app/app/pkgs/httpcompress" "go.goblog.app/app/pkgs/maprouter" "go.goblog.app/app/pkgs/plugintypes" @@ -36,6 +37,7 @@ func (a *goBlog) startServer() (err error) { a.reloadRouter() // Set basic middlewares h := alice.New() + h = h.Append(bodylimit.BodyLimit(100 * bodylimit.MB)) h = h.Append(middleware.Heartbeat("/ping")) if a.cfg.Server.Logging { h = h.Append(a.logMiddleware) diff --git a/microformats.go b/microformats.go index bd23944..011346b 100644 --- a/microformats.go +++ b/microformats.go @@ -11,6 +11,7 @@ import ( "github.com/PuerkitoBio/goquery" "github.com/carlmjohnson/requests" "github.com/dgraph-io/ristretto" + "go.goblog.app/app/pkgs/bodylimit" "go.goblog.app/app/pkgs/bufferpool" "go.goblog.app/app/pkgs/contenttype" "go.goblog.app/app/pkgs/httpcachetransport" @@ -43,7 +44,7 @@ func (a *goBlog) parseMicroformats(u string, cache bool) (*microformatsResult, e ToWriter(pw) if cache { a.initMicroformatsCache() - rb.Transport(httpcachetransport.NewHttpCacheTransport(a.httpClient.Transport, a.mfCache, 10*time.Minute)) + rb.Transport(httpcachetransport.NewHttpCacheTransport(a.httpClient.Transport, a.mfCache, time.Minute, 5*bodylimit.MB)) } go func() { _ = pw.CloseWithError(rb.Fetch(context.Background())) diff --git a/pkgs/httpcachetransport/httpCacheTransport.go b/pkgs/httpcachetransport/httpCacheTransport.go index 42d5aae..ec0dcdb 100644 --- a/pkgs/httpcachetransport/httpCacheTransport.go +++ b/pkgs/httpcachetransport/httpCacheTransport.go @@ -3,6 +3,7 @@ package httpcachetransport import ( "bufio" "bytes" + "io" "net/http" "net/http/httputil" "time" @@ -15,6 +16,7 @@ type httpCacheTransport struct { ristrettoCache *ristretto.Cache ttl time.Duration body bool + maxSize int64 } func (t *httpCacheTransport) RoundTrip(r *http.Request) (*http.Response, error) { @@ -26,9 +28,22 @@ func (t *httpCacheTransport) RoundTrip(r *http.Request) (*http.Response, error) } } } + resp, err := t.parent.RoundTrip(r) if err == nil && t.ristrettoCache != nil { - respBytes, err := httputil.DumpResponse(resp, t.body) + // Limit the response size + limitedResp := &http.Response{ + Status: resp.Status, + StatusCode: resp.StatusCode, + Proto: resp.Proto, + ProtoMajor: resp.ProtoMajor, + ProtoMinor: resp.ProtoMinor, + Header: resp.Header, + Body: io.NopCloser(io.LimitReader(resp.Body, t.maxSize)), + ContentLength: -1, + } + + respBytes, err := httputil.DumpResponse(limitedResp, t.body) if err != nil { return resp, err } @@ -41,11 +56,11 @@ func (t *httpCacheTransport) RoundTrip(r *http.Request) (*http.Response, error) // Creates a new http.RoundTripper that caches all // request responses (by the request URL) in ristretto. -func NewHttpCacheTransport(parent http.RoundTripper, ristrettoCache *ristretto.Cache, ttl time.Duration) http.RoundTripper { - return &httpCacheTransport{parent, ristrettoCache, ttl, true} +func NewHttpCacheTransport(parent http.RoundTripper, ristrettoCache *ristretto.Cache, ttl time.Duration, maxSize int64) http.RoundTripper { + return &httpCacheTransport{parent, ristrettoCache, ttl, true, maxSize} } // Like NewHttpCacheTransport but doesn't cache body -func NewHttpCacheTransportNoBody(parent http.RoundTripper, ristrettoCache *ristretto.Cache, ttl time.Duration) http.RoundTripper { - return &httpCacheTransport{parent, ristrettoCache, ttl, false} +func NewHttpCacheTransportNoBody(parent http.RoundTripper, ristrettoCache *ristretto.Cache, ttl time.Duration, maxSize int64) http.RoundTripper { + return &httpCacheTransport{parent, ristrettoCache, ttl, false, maxSize} } diff --git a/pkgs/httpcachetransport/httpCacheTransport_test.go b/pkgs/httpcachetransport/httpCacheTransport_test.go index 0e9b7c6..5c3df0c 100644 --- a/pkgs/httpcachetransport/httpCacheTransport_test.go +++ b/pkgs/httpcachetransport/httpCacheTransport_test.go @@ -11,6 +11,7 @@ import ( "github.com/carlmjohnson/requests" "github.com/dgraph-io/ristretto" "github.com/stretchr/testify/assert" + "go.goblog.app/app/pkgs/bodylimit" ) const fakeResponse = `HTTP/1.1 200 OK @@ -37,7 +38,7 @@ func TestHttpCacheTransport(t *testing.T) { }) client := &http.Client{ - Transport: NewHttpCacheTransport(orig, cache, time.Minute), + Transport: NewHttpCacheTransport(orig, cache, time.Minute, bodylimit.KB), } err := requests.URL("https://example.com/").Client(client).Fetch(context.Background())