diff --git a/services/httpd/config.go b/services/httpd/config.go index 3f42fbeffdb..cb5e0c97ea9 100644 --- a/services/httpd/config.go +++ b/services/httpd/config.go @@ -11,6 +11,9 @@ const ( // DefaultBindSocket is the default unix socket to bind to. DefaultBindSocket = "/var/run/influxdb.sock" + + // DefaultMaxBodySize is the default maximum size of a client request body, in bytes. Specify 0 for no limit. + DefaultMaxBodySize = 25e6 ) // Config represents a configuration for a HTTP service. @@ -30,6 +33,7 @@ type Config struct { Realm string `toml:"realm"` UnixSocketEnabled bool `toml:"unix-socket-enabled"` BindSocket string `toml:"bind-socket"` + MaxBodySize int `toml:"max-body-size"` } // NewConfig returns a new Config with default settings. @@ -45,6 +49,7 @@ func NewConfig() Config { Realm: DefaultRealm, UnixSocketEnabled: false, BindSocket: DefaultBindSocket, + MaxBodySize: DefaultMaxBodySize, } } diff --git a/services/httpd/handler.go b/services/httpd/handler.go index 2eb04fbe6bd..c686ea64eb4 100644 --- a/services/httpd/handler.go +++ b/services/httpd/handler.go @@ -609,8 +609,12 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user meta.U } } - // Handle gzip decoding of the body body := r.Body + if h.Config.MaxBodySize > 0 { + body = truncateReader(body, int64(h.Config.MaxBodySize)) + } + + // Handle gzip decoding of the body if r.Header.Get("Content-Encoding") == "gzip" { b, err := gzip.NewReader(r.Body) if err != nil { @@ -622,17 +626,25 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user meta.U } var bs []byte - if clStr := r.Header.Get("Content-Length"); clStr != "" { - if length, err := strconv.Atoi(clStr); err == nil { - // This will just be an initial hint for the gzip reader, as the - // bytes.Buffer will grow as needed when ReadFrom is called - bs = make([]byte, 0, length) + if r.ContentLength > 0 { + if h.Config.MaxBodySize > 0 && r.ContentLength > int64(h.Config.MaxBodySize) { + h.httpError(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) + return } + + // This will just be an initial hint for the gzip reader, as the + // bytes.Buffer will grow as needed when ReadFrom is called + bs = make([]byte, 0, r.ContentLength) } buf := bytes.NewBuffer(bs) _, err := buf.ReadFrom(body) if err != nil { + if err == errTruncated { + h.httpError(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) + return + } + if h.Config.WriteTracing { h.Logger.Info("Write handler unable to read bytes from request body") } diff --git a/services/httpd/handler_test.go b/services/httpd/handler_test.go index af411d422fc..5faa6de26ff 100644 --- a/services/httpd/handler_test.go +++ b/services/httpd/handler_test.go @@ -610,6 +610,45 @@ func TestHandler_HandleBadRequestBody(t *testing.T) { } } +func TestHandler_Write_EntityTooLarge_ContentLength(t *testing.T) { + b := bytes.NewReader(make([]byte, 100)) + h := NewHandler(false) + h.Config.MaxBodySize = 5 + h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo { + return &meta.DatabaseInfo{} + } + + w := httptest.NewRecorder() + h.ServeHTTP(w, MustNewRequest("POST", "/write?db=foo", b)) + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("unexpected status: %d", w.Code) + } +} + +// onlyReader implements io.Reader only to ensure Request.ContentLength is not set +type onlyReader struct { + r io.Reader +} + +func (o onlyReader) Read(p []byte) (n int, err error) { + return o.r.Read(p) +} + +func TestHandler_Write_EntityTooLarge_NoContentLength(t *testing.T) { + b := onlyReader{bytes.NewReader(make([]byte, 100))} + h := NewHandler(false) + h.Config.MaxBodySize = 5 + h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo { + return &meta.DatabaseInfo{} + } + + w := httptest.NewRecorder() + h.ServeHTTP(w, MustNewRequest("POST", "/write?db=foo", b)) + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("unexpected status: %d", w.Code) + } +} + // Ensure X-Forwarded-For header writes the correct log message. func TestHandler_XForwardedFor(t *testing.T) { var buf bytes.Buffer diff --git a/services/httpd/io.go b/services/httpd/io.go new file mode 100644 index 00000000000..cc48444b84f --- /dev/null +++ b/services/httpd/io.go @@ -0,0 +1,45 @@ +package httpd + +import ( + "errors" + "io" +) + +var ( + errTruncated = errors.New("Read: truncated") +) + +// truncateReader returns a Reader that reads from r +// but stops with ErrTruncated after n bytes. +func truncateReader(r io.Reader, n int64) io.ReadCloser { + tr := &truncatedReader{r: &io.LimitedReader{R: r, N: n + 1}} + + if rc, ok := r.(io.Closer); ok { + tr.Closer = rc + } + + return tr +} + +// A truncatedReader reads limits the amount of +// data returned to a maximum of r.N bytes. +type truncatedReader struct { + r *io.LimitedReader + io.Closer +} + +func (r *truncatedReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if r.r.N <= 0 { + return n, errTruncated + } + + return n, err +} + +func (r *truncatedReader) Close() error { + if r.Closer != nil { + return r.Closer.Close() + } + return nil +} diff --git a/services/httpd/io_test.go b/services/httpd/io_test.go new file mode 100644 index 00000000000..10876e6df65 --- /dev/null +++ b/services/httpd/io_test.go @@ -0,0 +1,31 @@ +package httpd + +import ( + "bytes" + "io/ioutil" + "testing" +) + +func TestTruncatedReader_Read(t *testing.T) { + tests := []struct { + name string + in []byte + n int64 + err error + }{ + {"in(1000)-max(1000)", make([]byte, 1000), 1000, nil}, + {"in(1000)-max(1001)", make([]byte, 1000), 1001, nil}, + {"in(1001)-max(1000)", make([]byte, 1001), 1000, errTruncated}, + {"in(10000)-max(1000)", make([]byte, 1e5), 1e3, errTruncated}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + b := truncateReader(bytes.NewReader(tc.in), tc.n) + _, err := ioutil.ReadAll(b) + if err != tc.err { + t.Errorf("unexpected error; got=%v, exp=%v", err, tc.err) + } + }) + } +}