diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..f34f8db93e --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" \ No newline at end of file diff --git a/.github/workflows/cifuzz.yml b/.github/workflows/cifuzz.yml index 5f5fa747e6..1d0603d69a 100644 --- a/.github/workflows/cifuzz.yml +++ b/.github/workflows/cifuzz.yml @@ -19,7 +19,7 @@ jobs: dry-run: false language: go - name: Upload Crash - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 if: failure() && steps.build.outcome == 'success' with: name: artifacts diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cd41c03a82..c250bf70e5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,17 +4,26 @@ on: branches: - master pull_request: + +permissions: + # Required: allow read access to the content for analysis. + contents: read + # Optional: allow read access to pull request. Use with `only-new-issues` option. + pull-requests: read + # Optional: Allow write access to checks to allow the action to annotate code in the PR. 
+ checks: write + jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: 1.20.x - run: go version - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: - version: v1.55.2 + version: v1.56.2 args: --verbose diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 5c592adb6c..997992bf49 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -14,8 +14,8 @@ jobs: env: GO111MODULE: on steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Run Gosec Security Scanner - uses: securego/gosec@v2.17.0 + uses: securego/gosec@v2.19.0 with: - args: '-exclude=G104,G304,G402 ./...' + args: '-exclude=G103,G104,G304,G402 ./...' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8f5c3d8798..5cfc4e5edb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,12 +9,12 @@ jobs: strategy: fail-fast: false matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x] + go-version: [1.19.x, 1.20.x, 1.21.x, 1.22.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} diff --git a/.golangci.yml b/.golangci.yml index 6d91345124..dded8fc422 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -33,7 +33,6 @@ linters: - noctx - nonamedreturns - paralleltest - - perfsprint - testableexamples - testpackage - thelper @@ -69,8 +68,18 @@ linters-settings: "-ST1000", # at least one file in a package should have a package comment ] gocritic: - enabled-checks: - - emptyStringTest + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - deferInLoop + - importShadow + - 
sloppyReassign + - unnamedResult + - whyNoLint issues: # Show all issues from a linter. diff --git a/args_test.go b/args_test.go index 68c6b7330d..9dff7115df 100644 --- a/args_test.go +++ b/args_test.go @@ -336,8 +336,8 @@ func testCopyTo(t *testing.T, a *Args) { var b Args a.CopyTo(&b) - if !reflect.DeepEqual(*a, b) { //nolint:govet - t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint:govet + if !reflect.DeepEqual(a, &b) { + t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", a, &b) } b.VisitAll(func(k, _ []byte) { @@ -443,13 +443,13 @@ func TestArgsSetGetDel(t *testing.T) { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } a.Del(k) - if string(a.Peek(k)) != "" { + if len(a.Peek(k)) != 0 { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), "") } } a.Parse("aaa=xxx&bb=aa") - if string(a.Peek("foo0")) != "" { + if len(a.Peek("foo0")) != 0 { t.Fatalf("Unexpected value %q", a.Peek("foo0")) } if string(a.Peek("aaa")) != "xxx" { @@ -474,7 +474,7 @@ func TestArgsSetGetDel(t *testing.T) { t.Fatalf("Unexpected value: %q. Expected %q", a.Peek(k), v) } a.Del(k) - if string(a.Peek(k)) != "" { + if len(a.Peek(k)) != 0 { t.Fatalf("Unexpected value: %q. 
Expected %q", a.Peek(k), "") } } diff --git a/brotli_test.go b/brotli_test.go index 7833c41d8c..794ab45cd9 100644 --- a/brotli_test.go +++ b/brotli_test.go @@ -105,9 +105,9 @@ func testBrotliCompressSingleCase(s string) error { func TestCompressHandlerBrotliLevel(t *testing.T) { t.Parallel() - expectedBody := string(createFixedBody(2e4)) + expectedBody := createFixedBody(2e4) h := CompressHandlerBrotliLevel(func(ctx *RequestCtx) { - ctx.WriteString(expectedBody) //nolint:errcheck + ctx.Write(expectedBody) //nolint:errcheck }, CompressBrotliDefaultCompression, CompressDefaultCompression) var ctx RequestCtx @@ -121,11 +121,11 @@ func TestCompressHandlerBrotliLevel(t *testing.T) { t.Fatalf("unexpected error: %v", err) } ce := resp.Header.ContentEncoding() - if string(ce) != "" { + if len(ce) != 0 { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") } body := resp.Body() - if string(body) != expectedBody { + if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } @@ -148,7 +148,7 @@ func TestCompressHandlerBrotliLevel(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if string(body) != expectedBody { + if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } @@ -171,7 +171,7 @@ func TestCompressHandlerBrotliLevel(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if string(body) != expectedBody { + if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } } diff --git a/bytesconv.go b/bytesconv.go index b3cf29e3ce..4e36d1d7b3 100644 --- a/bytesconv.go +++ b/bytesconv.go @@ -35,7 +35,7 @@ func AppendHTMLEscape(dst []byte, s string) []byte { case '\'': sub = "'" // "'" is shorter than "'" and apos was not in HTML until HTML5. } - if len(sub) > 0 { + if sub != "" { dst = append(dst, s[prev:i]...) dst = append(dst, sub...) 
prev = i + 1 diff --git a/bytesconv_test.go b/bytesconv_test.go index 444981061a..54e5e20022 100644 --- a/bytesconv_test.go +++ b/bytesconv_test.go @@ -3,10 +3,10 @@ package fasthttp import ( "bufio" "bytes" - "fmt" "html" "net" "net/url" + "strconv" "testing" "time" @@ -118,7 +118,7 @@ func testAppendIPv4(t *testing.T, ipStr string, isValid bool) { } func testAppendUint(t *testing.T, n int) { - expectedS := fmt.Sprintf("%d", n) + expectedS := strconv.Itoa(n) s := AppendUint(nil, n) if string(s) != expectedS { t.Fatalf("unexpected uint %q. Expecting %q. n=%d", s, expectedS, n) diff --git a/client_test.go b/client_test.go index e5b589245f..48429657cc 100644 --- a/client_test.go +++ b/client_test.go @@ -456,7 +456,7 @@ func TestClientParseConn(t *testing.T) { t.Fatalf("req RemoteAddr parse addr fail: %q, hope: %q", res.RemoteAddr().String(), host) } - if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) { + if !regexp.MustCompile(`^127\.0\.0\.1:\d{4,5}$`).MatchString(res.LocalAddr().String()) { t.Fatalf("res LocalAddr addr match fail: %q, hope match: %q", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$") } } @@ -2258,7 +2258,7 @@ type writeErrorConn struct { } func (w *writeErrorConn) Write(p []byte) (int, error) { - return 1, fmt.Errorf("error") + return 1, errors.New("error") } func (w *writeErrorConn) Close() error { @@ -2286,7 +2286,7 @@ type readErrorConn struct { } func (r *readErrorConn) Read(p []byte) (int, error) { - return 0, fmt.Errorf("error") + return 0, errors.New("error") } func (r *readErrorConn) Write(p []byte) (int, error) { @@ -2323,7 +2323,7 @@ func (r *singleReadConn) Read(p []byte) (int, error) { if len(r.s) == r.n { return 0, io.EOF } - n := copy(p, []byte(r.s[r.n:])) + n := copy(p, r.s[r.n:]) r.n += n return n, nil } @@ -2849,7 +2849,7 @@ func TestClientConfigureClientFailed(t *testing.T) { c := &Client{ ConfigureClient: func(hc *HostClient) error { - return fmt.Errorf("failed to configure") + return 
errors.New("failed to configure") }, } diff --git a/client_timing_test.go b/client_timing_test.go index 416ad53811..5e9a2f48d3 100644 --- a/client_timing_test.go +++ b/client_timing_test.go @@ -165,7 +165,7 @@ func BenchmarkNetHTTPClientDoFastServer(b *testing.B) { nn := uint32(0) b.RunParallel(func(pb *testing.PB) { - req, err := http.NewRequest(MethodGet, fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil) + req, err := http.NewRequest(MethodGet, fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), http.NoBody) if err != nil { b.Fatalf("unexpected error: %v", err) } @@ -550,7 +550,7 @@ func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { - req, err := http.NewRequest(MethodGet, url, nil) + req, err := http.NewRequest(MethodGet, url, http.NoBody) if err != nil { b.Fatalf("unexpected error: %v", err) } diff --git a/compress_test.go b/compress_test.go index 329f0f5953..e597267a77 100644 --- a/compress_test.go +++ b/compress_test.go @@ -2,6 +2,7 @@ package fasthttp import ( "bytes" + "errors" "fmt" "io" "testing" @@ -225,7 +226,7 @@ func testConcurrent(concurrency int, f func() error) error { return err } case <-time.After(time.Second): - return fmt.Errorf("timeout") + return errors.New("timeout") } } return nil diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go index dcd43e4431..5e856fb5b9 100644 --- a/fasthttpadaptor/adaptor.go +++ b/fasthttpadaptor/adaptor.go @@ -3,8 +3,11 @@ package fasthttpadaptor import ( + "bufio" "io" + "net" "net/http" + "sync" "github.com/valyala/fasthttp" ) @@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - - w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()} + w := netHTTPResponseWriter{ + w: ctx.Response.BodyWriter(), + ctx: 
ctx, + } h.ServeHTTP(&w, r.WithContext(ctx)) ctx.SetStatusCode(w.StatusCode()) @@ -86,6 +91,7 @@ type netHTTPResponseWriter struct { statusCode int h http.Header w io.Writer + ctx *fasthttp.RequestCtx } func (w *netHTTPResponseWriter) StatusCode() int { @@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { } func (w *netHTTPResponseWriter) Flush() {} + +type wrappedConn struct { + net.Conn + + wg sync.WaitGroup + once sync.Once +} + +func (c *wrappedConn) Close() (err error) { + c.once.Do(func() { + err = c.Conn.Close() + c.wg.Done() + }) + return +} + +func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or + // doing anything else with it. + w.ctx.HijackSetNoResponse(true) + + conn := &wrappedConn{Conn: w.ctx.Conn()} + conn.wg.Add(1) + w.ctx.Hijack(func(net.Conn) { + conn.wg.Wait() + }) + + bufW := bufio.NewWriter(conn) + + // Write any unflushed body to the hijacked connection buffer. 
+ unflushedBody := w.ctx.Response.Body() + if len(unflushedBody) > 0 { + if _, err := bufW.Write(unflushedBody); err != nil { + conn.Close() + return nil, nil, err + } + } + + return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil +} diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go index 9f03858dc1..229e368550 100644 --- a/fasthttpadaptor/adaptor_test.go +++ b/fasthttpadaptor/adaptor_test.go @@ -7,8 +7,10 @@ import ( "net/url" "reflect" "testing" + "time" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" ) func TestNewFastHTTPHandler(t *testing.T) { @@ -143,3 +145,74 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value a next(ctx) } } + +func TestHijack(t *testing.T) { + t.Parallel() + + nethttpH := func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Hijacker); !ok { + t.Errorf("expected http.ResponseWriter to implement http.Hijacker") + } else { + if _, err := w.Write([]byte("foo")); err != nil { + t.Error(err) + } + + if c, rw, err := f.Hijack(); err != nil { + t.Error(err) + } else { + if _, err := rw.WriteString("bar"); err != nil { + t.Error(err) + } + + if err := rw.Flush(); err != nil { + t.Error(err) + } + + if err := c.Close(); err != nil { + t.Error(err) + } + } + } + } + + s := &fasthttp.Server{ + Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)), + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %v", err) + } + + buf, err := io.ReadAll(c) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if string(buf) != "foobar" { + t.Errorf("unexpected response: %q. 
Expecting %q", buf, "foobar") + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +} diff --git a/fs.go b/fs.go index f35f22c91b..6793713f76 100644 --- a/fs.go +++ b/fs.go @@ -18,6 +18,7 @@ import ( "github.com/andybalholm/brotli" "github.com/klauspost/compress/gzip" + "github.com/klauspost/compress/zstd" "github.com/valyala/bytebufferpool" ) @@ -102,7 +103,7 @@ func ServeFile(ctx *RequestCtx, path string) { if path == "" || !filepath.IsAbs(path) { // extend relative path to absolute path - hasTrailingSlash := len(path) > 0 && (path[len(path)-1] == '/' || path[len(path)-1] == '\\') + hasTrailingSlash := path != "" && (path[len(path)-1] == '/' || path[len(path)-1] == '\\') var err error path = filepath.FromSlash(path) @@ -370,6 +371,7 @@ const FSCompressedFileSuffix = ".fasthttp.gz" var FSCompressedFileSuffixes = map[string]string{ "gzip": ".fasthttp.gz", "br": ".fasthttp.br", + "zstd": ".fasthttp.zst", } // FSHandlerCacheDuration is the default expiration duration for inactive @@ -442,7 +444,7 @@ func (fs *FS) normalizeRoot(root string) string { } // strip trailing slashes from the root path - for len(root) > 0 && root[len(root)-1] == os.PathSeparator { + for root != "" && root[len(root)-1] == os.PathSeparator { root = root[:len(root)-1] } return root @@ -460,7 +462,9 @@ func (fs *FS) initRequestHandler() { compressedFileSuffixes := fs.CompressedFileSuffixes if compressedFileSuffixes["br"] == "" || compressedFileSuffixes["gzip"] == "" || - compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] { + compressedFileSuffixes["zstd"] == "" || compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] || + compressedFileSuffixes["br"] == compressedFileSuffixes["zstd"] || + compressedFileSuffixes["gzip"] == compressedFileSuffixes["zstd"] { // Copy global map compressedFileSuffixes = make(map[string]string, len(FSCompressedFileSuffixes)) for k, v := range FSCompressedFileSuffixes { @@ -468,9 
+472,10 @@ func (fs *FS) initRequestHandler() { } } - if len(fs.CompressedFileSuffix) > 0 { + if fs.CompressedFileSuffix != "" { compressedFileSuffixes["gzip"] = fs.CompressedFileSuffix compressedFileSuffixes["br"] = FSCompressedFileSuffixes["br"] + compressedFileSuffixes["zstd"] = FSCompressedFileSuffixes["zstd"] } h := &fsHandler{ @@ -794,6 +799,7 @@ const ( defaultCacheKind CacheKind = iota brotliCacheKind gzipCacheKind + zstdCacheKind ) func newCacheManager(fs *FS) cacheManager { @@ -1032,14 +1038,19 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) { fileEncoding := "" byteRange := ctx.Request.Header.peek(strRange) if len(byteRange) == 0 && h.compress { - if h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr) { + switch { + case h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr): mustCompress = true fileCacheKind = brotliCacheKind fileEncoding = "br" - } else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { + case ctx.Request.Header.HasAcceptEncodingBytes(strGzip): mustCompress = true fileCacheKind = gzipCacheKind fileEncoding = "gzip" + case ctx.Request.Header.HasAcceptEncodingBytes(strZstd): + mustCompress = true + fileCacheKind = zstdCacheKind + fileEncoding = "zstd" } } @@ -1097,10 +1108,13 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) { hdr := &ctx.Response.Header if ff.compressed { - if fileEncoding == "br" { + switch fileEncoding { + case "br": hdr.SetContentEncodingBytes(strBr) - } else if fileEncoding == "gzip" { + case "gzip": hdr.SetContentEncodingBytes(strGzip) + case "zstd": + hdr.SetContentEncodingBytes(strZstd) } } @@ -1304,10 +1318,13 @@ nestedContinue: if mustCompress { var zbuf bytebufferpool.ByteBuffer - if fileEncoding == "br" { + switch fileEncoding { + case "br": zbuf.B = AppendBrotliBytesLevel(zbuf.B, w.B, CompressDefaultCompression) - } else if fileEncoding == "gzip" { + case "gzip": zbuf.B = AppendGzipBytesLevel(zbuf.B, w.B, CompressDefaultCompression) + case "zstd": + zbuf.B = 
AppendZstdBytesLevel(zbuf.B, w.B, CompressZstdDefault) } w = &zbuf } @@ -1333,7 +1350,7 @@ const ( fsMaxCompressibleFileSize = 8 * 1024 * 1024 ) -func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) (*fsFile, error) { +func (h *fsHandler) compressAndOpenFSFile(filePath, fileEncoding string) (*fsFile, error) { f, err := h.filesystem.Open(filePath) if err != nil { return nil, err @@ -1406,20 +1423,28 @@ func (h *fsHandler) compressFileNolock( } return nil, errNoCreatePermission } - if fileEncoding == "br" { + switch fileEncoding { + case "br": zw := acquireStacklessBrotliWriter(zf, CompressDefaultCompression) _, err = copyZeroAlloc(zw, f) if err1 := zw.Flush(); err == nil { err = err1 } releaseStacklessBrotliWriter(zw, CompressDefaultCompression) - } else if fileEncoding == "gzip" { + case "gzip": zw := acquireStacklessGzipWriter(zf, CompressDefaultCompression) _, err = copyZeroAlloc(zw, f) if err1 := zw.Flush(); err == nil { err = err1 } releaseStacklessGzipWriter(zw, CompressDefaultCompression) + case "zstd": + zw := acquireStacklessZstdWriter(zf, CompressZstdDefault) + _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessZstdWriter(zw, CompressZstdDefault) } _ = zf.Close() _ = f.Close() @@ -1443,20 +1468,28 @@ func (h *fsHandler) newCompressedFSFileCache(f fs.File, fileInfo fs.FileInfo, fi err error ) - if fileEncoding == "br" { + switch fileEncoding { + case "br": zw := acquireStacklessBrotliWriter(w, CompressDefaultCompression) _, err = copyZeroAlloc(zw, f) if err1 := zw.Flush(); err == nil { err = err1 } releaseStacklessBrotliWriter(zw, CompressDefaultCompression) - } else if fileEncoding == "gzip" { + case "gzip": zw := acquireStacklessGzipWriter(w, CompressDefaultCompression) _, err = copyZeroAlloc(zw, f) if err1 := zw.Flush(); err == nil { err = err1 } releaseStacklessGzipWriter(zw, CompressDefaultCompression) + case "zstd": + zw := acquireStacklessZstdWriter(w, CompressZstdDefault) 
+ _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessZstdWriter(zw, CompressZstdDefault) } defer func() { _ = f.Close() }() @@ -1499,7 +1532,7 @@ func (h *fsHandler) newCompressedFSFileCache(f fs.File, fileInfo fs.FileInfo, fi return ff, nil } -func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (*fsFile, error) { +func (h *fsHandler) newCompressedFSFile(filePath, fileEncoding string) (*fsFile, error) { f, err := h.filesystem.Open(filePath) if err != nil { return nil, fmt.Errorf("cannot open compressed file %q: %w", filePath, err) @@ -1600,21 +1633,28 @@ func (h *fsHandler) newFSFile(f fs.File, fileInfo fs.FileInfo, compressed bool, func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte, error) { r := f var ( - br *brotli.Reader - zr *gzip.Reader + br *brotli.Reader + zr *gzip.Reader + zsr *zstd.Decoder ) if compressed { var err error - if fileEncoding == "br" { + switch fileEncoding { + case "br": if br, err = acquireBrotliReader(f); err != nil { return nil, err } r = br - } else if fileEncoding == "gzip" { + case "gzip": if zr, err = acquireGzipReader(f); err != nil { return nil, err } r = zr + case "zstd": + if zsr, err = acquireZstdReader(f); err != nil { + return nil, err + } + r = zsr } } @@ -1639,6 +1679,10 @@ func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte, releaseGzipReader(zr) } + if zsr != nil { + releaseZstdReader(zsr) + } + return data, err } diff --git a/fs_fs_test.go b/fs_fs_test.go index 3ddc01ae78..437e399902 100644 --- a/fs_fs_test.go +++ b/fs_fs_test.go @@ -254,7 +254,7 @@ func testFSFSCompress(t *testing.T, h RequestHandler, filePath string) { t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } ce := resp.Header.ContentEncoding() - if string(ce) != "" { + if len(ce) != 0 { t.Errorf("unexpected content-encoding %q. Expecting empty string. 
filePath=%q", ce, filePath) } body := string(resp.Body()) diff --git a/fs_test.go b/fs_test.go index 572fafe161..9704199b8e 100644 --- a/fs_test.go +++ b/fs_test.go @@ -652,7 +652,7 @@ func testFSCompress(t *testing.T, h RequestHandler, filePath string) { t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } ce := resp.Header.ContentEncoding() - if string(ce) != "" { + if len(ce) != 0 { t.Errorf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) } body := string(resp.Body()) diff --git a/fuzz_test.go b/fuzz_test.go index 827e776682..532c052f99 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -7,17 +7,13 @@ import ( ) func FuzzCookieParse(f *testing.F) { - inputs := []string{ - `xxx=yyy`, - `xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b`, - " \n\t\"", - } - for _, input := range inputs { - f.Add([]byte(input)) - } - c := AcquireCookie() - defer ReleaseCookie(c) + f.Add([]byte(`xxx=yyy`)) + f.Add([]byte(`xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b`)) + f.Add([]byte(" \n\t\"")) + f.Fuzz(func(t *testing.T, cookie []byte) { + var c Cookie + _ = c.ParseBytes(cookie) w := bytes.Buffer{} @@ -28,15 +24,11 @@ func FuzzCookieParse(f *testing.F) { } func FuzzVisitHeaderParams(f *testing.F) { - inputs := []string{ - `application/json; v=1; foo=bar; q=0.938; param=param; param="big fox"; q=0.43`, - `*/*`, - `\\`, - `text/plain; foo="\\\"\'\\''\'"`, - } - for _, input := range inputs { - f.Add([]byte(input)) - } + f.Add([]byte(`application/json; v=1; foo=bar; q=0.938; param=param; param="big fox"; q=0.43`)) + f.Add([]byte(`*/*`)) + f.Add([]byte(`\\`)) + f.Add([]byte(`text/plain; foo="\\\"\'\\''\'"`)) + f.Fuzz(func(t *testing.T, header []byte) { VisitHeaderParams(header, func(key, value []byte) bool { if len(key) == 0 { @@ -48,16 +40,15 @@ func FuzzVisitHeaderParams(f *testing.F) { } func FuzzResponseReadLimitBody(f *testing.F) { - 
res := AcquireResponse() - defer ReleaseResponse(res) - - f.Add([]byte("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210"), 1024*1024) + f.Add([]byte("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210"), 1024) f.Fuzz(func(t *testing.T, body []byte, max int) { - if max > 10*1024*1024 { // Skip limits higher than 10MB. + if len(body) > 10*1024 || max > 10*1024 { return } + var res Response + _ = res.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max) w := bytes.Buffer{} _, _ = res.WriteTo(&w) @@ -65,16 +56,15 @@ func FuzzResponseReadLimitBody(f *testing.F) { } func FuzzRequestReadLimitBody(f *testing.F) { - req := AcquireRequest() - defer ReleaseRequest(req) - - f.Add([]byte("POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nfoobar\r\n\r\n"), 1024*1024) + f.Add([]byte("POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nfoobar\r\n\r\n"), 1024) f.Fuzz(func(t *testing.T, body []byte, max int) { - if max > 10*1024*1024 { // Skip limits higher than 10MB. 
+ if len(body) > 10*1024 || max > 10*1024 { return } + var req Request + _ = req.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max) w := bytes.Buffer{} _, _ = req.WriteTo(&w) @@ -82,15 +72,14 @@ func FuzzRequestReadLimitBody(f *testing.F) { } func FuzzURIUpdateBytes(f *testing.F) { - u := AcquireURI() - defer ReleaseURI(u) - f.Add([]byte(`http://foobar.com/aaa/bb?cc`)) f.Add([]byte(`//foobar.com/aaa/bb?cc`)) f.Add([]byte(`/aaa/bb?cc`)) f.Add([]byte(`xx?yy=abc`)) f.Fuzz(func(t *testing.T, uri []byte) { + var u URI + u.UpdateBytes(uri) w := bytes.Buffer{} diff --git a/go.mod b/go.mod index d4c4c8c301..f39397e900 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module github.com/valyala/fasthttp go 1.20 require ( - github.com/andybalholm/brotli v1.0.5 - github.com/klauspost/compress v1.17.0 + github.com/andybalholm/brotli v1.1.0 + github.com/klauspost/compress v1.17.7 github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/tcplisten v1.0.0 - golang.org/x/crypto v0.17.0 - golang.org/x/net v0.17.0 - golang.org/x/sys v0.15.0 + golang.org/x/crypto v0.21.0 + golang.org/x/net v0.22.0 + golang.org/x/sys v0.18.0 ) require golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index 04460ea96a..a48ab9aa3c 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,16 @@ -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= -github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= +github.com/klauspost/compress v1.17.7/go.mod 
h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/header.go b/header.go index ac279d7261..a5b42e9ed0 100644 --- a/header.go +++ b/header.go @@ -1430,7 +1430,7 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool { } // setNonSpecial directly put into map i.e. not a basic header. 
-func (h *ResponseHeader) setNonSpecial(key []byte, value []byte) { +func (h *ResponseHeader) setNonSpecial(key, value []byte) { h.h = setArgBytes(h.h, key, value, argsHasValue) } @@ -1489,7 +1489,7 @@ func (h *RequestHeader) setSpecialHeader(key, value []byte) bool { } // setNonSpecial directly put into map i.e. not a basic header. -func (h *RequestHeader) setNonSpecial(key []byte, value []byte) { +func (h *RequestHeader) setNonSpecial(key, value []byte) { h.h = setArgBytes(h.h, key, value, argsHasValue) } @@ -2118,7 +2118,7 @@ func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error { if err == bufio.ErrBufferFull { if h.secureErrorLogMessage { return &ErrSmallBuffer{ - error: fmt.Errorf("error when reading response headers"), + error: errors.New("error when reading response headers"), } } return &ErrSmallBuffer{ @@ -2170,7 +2170,7 @@ func (h *ResponseHeader) tryReadTrailer(r *bufio.Reader, n int) error { if err == bufio.ErrBufferFull { if h.secureErrorLogMessage { return &ErrSmallBuffer{ - error: fmt.Errorf("error when reading response trailer"), + error: errors.New("error when reading response trailer"), } } return &ErrSmallBuffer{ @@ -2279,7 +2279,7 @@ func (h *RequestHeader) tryReadTrailer(r *bufio.Reader, n int) error { if err == bufio.ErrBufferFull { if h.secureErrorLogMessage { return &ErrSmallBuffer{ - error: fmt.Errorf("error when reading request trailer"), + error: errors.New("error when reading request trailer"), } } return &ErrSmallBuffer{ @@ -2821,7 +2821,7 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) { n := bytes.IndexByte(b, ' ') if n < 0 { if h.secureErrorLogMessage { - return 0, fmt.Errorf("cannot find whitespace in the first line of response") + return 0, errors.New("cannot find whitespace in the first line of response") } return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf) } @@ -2838,7 +2838,7 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) { } if len(b) > n && 
b[n] != ' ' { if h.secureErrorLogMessage { - return 0, fmt.Errorf("unexpected char at the end of status code") + return 0, errors.New("unexpected char at the end of status code") } return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf) } @@ -2863,31 +2863,47 @@ func (h *RequestHeader) parseFirstLine(buf []byte) (int, error) { n := bytes.IndexByte(b, ' ') if n <= 0 { if h.secureErrorLogMessage { - return 0, fmt.Errorf("cannot find http request method") + return 0, errors.New("cannot find http request method") } return 0, fmt.Errorf("cannot find http request method in %q", buf) } h.method = append(h.method[:0], b[:n]...) b = b[n+1:] - protoStr := strHTTP11 // parse requestURI n = bytes.LastIndexByte(b, ' ') - switch { - case n < 0: - h.noHTTP11 = true - n = len(b) - protoStr = strHTTP10 - case n == 0: + if n < 0 { + return 0, fmt.Errorf("cannot find whitespace in the first line of request %q", buf) + } else if n == 0 { if h.secureErrorLogMessage { - return 0, fmt.Errorf("requestURI cannot be empty") + return 0, errors.New("requestURI cannot be empty") } return 0, fmt.Errorf("requestURI cannot be empty in %q", buf) - case !bytes.Equal(b[n+1:], strHTTP11): - h.noHTTP11 = true - protoStr = b[n+1:] } + protoStr := b[n+1:] + + // Follow RFCs 7230 and 9112 and require that HTTP versions match the following pattern: HTTP/[0-9]\.[0-9] + if len(protoStr) != len(strHTTP11) { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("unsupported HTTP version %q", protoStr) + } + return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf) + } + if !bytes.HasPrefix(protoStr, strHTTP11[:5]) { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("unsupported HTTP version %q", protoStr) + } + return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf) + } + if protoStr[5] < '0' || protoStr[5] > '9' || protoStr[7] < '0' || protoStr[7] > '9' { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("unsupported HTTP version %q", 
protoStr) + } + return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf) + } + + h.noHTTP11 = !bytes.Equal(protoStr, strHTTP11) h.proto = append(h.proto[:0], protoStr...) h.requestURI = append(h.requestURI[:0], b[:n]...) @@ -3013,6 +3029,8 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { h.contentLength = -2 + contentLengthSeen := false + var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing @@ -3048,6 +3066,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { continue } if caseInsensitiveCompare(s.key, strContentLength) { + if contentLengthSeen { + return 0, errors.New("duplicate Content-Length header") + } + contentLengthSeen = true + if h.contentLength != -1 { var nerr error if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { @@ -3072,7 +3095,17 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } case 't': if caseInsensitiveCompare(s.key, strTransferEncoding) { - if !bytes.Equal(s.value, strIdentity) { + isIdentity := caseInsensitiveCompare(s.value, strIdentity) + isChunked := caseInsensitiveCompare(s.value, strChunked) + + if !isIdentity && !isChunked { + if h.secureErrorLogMessage { + return 0, errors.New("unsupported Transfer-Encoding") + } + return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value) + } + + if isChunked { h.contentLength = -1 h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) } diff --git a/header_test.go b/header_test.go index 163fdd7b75..7e30162da2 100644 --- a/header_test.go +++ b/header_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "reflect" + "strconv" "strings" "testing" ) @@ -167,7 +168,7 @@ func TestResponseHeaderEmptyValueFromHeader(t *testing.T) { if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %v", err) } - if string(h.ContentType()) != string(h1.ContentType()) { + if !bytes.Equal(h.ContentType(), h1.ContentType()) { 
t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), h1.ContentType()) } v1 := h.Peek("EmptyValue1") @@ -222,7 +223,7 @@ func TestRequestHeaderEmptyValueFromHeader(t *testing.T) { if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %v", err) } - if string(h.Host()) != string(h1.Host()) { + if !bytes.Equal(h.Host(), h1.Host()) { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), h1.Host()) } v1 := h.Peek("EmptyValue1") @@ -341,7 +342,7 @@ func TestRequestRawHeaders(t *testing.T) { if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %v", err) } - if string(h.Host()) != "" { + if len(h.Host()) != 0 { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "") } v1 := h.Peek("NoKey") @@ -454,7 +455,7 @@ func TestResponseHeaderAdd(t *testing.T) { m["bbb"] = struct{}{} m["xxx"] = struct{}{} for i := 0; i < 10; i++ { - v := fmt.Sprintf("%d", i) + v := strconv.Itoa(i) h.Add("Foo-Bar", v) m[v] = struct{}{} } @@ -507,7 +508,7 @@ func TestRequestHeaderAdd(t *testing.T) { m["bbb"] = struct{}{} m["xxx"] = struct{}{} for i := 0; i < 10; i++ { - v := fmt.Sprintf("%d", i) + v := strconv.Itoa(i) h.Add("Foo-Bar", v) m[v] = struct{}{} } @@ -687,11 +688,11 @@ func TestResponseHeaderDel(t *testing.T) { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(HeaderContentType) - if string(hv) != string(defaultContentType) { + if !bytes.Equal(hv, defaultContentType) { t.Fatalf("unexpected content-type: %q. Expecting %q", hv, defaultContentType) } hv = h.Peek(HeaderContentEncoding) - if string(hv) != ("gzip") { + if string(hv) != "gzip" { t.Fatalf("unexpected content-encoding: %q. 
Expecting %q", hv, "gzip") } hv = h.Peek(HeaderServer) @@ -1294,7 +1295,7 @@ func TestResponseHeaderFirstByteReadEOF(t *testing.T) { var h ResponseHeader - r := &errorReader{fmt.Errorf("non-eof error")} + r := &errorReader{errors.New("non-eof error")} br := bufio.NewReader(r) err := h.Read(br) if err == nil { @@ -1341,11 +1342,6 @@ func TestRequestHeaderHTTPVer(t *testing.T) { // non-http/1.1 testRequestHeaderHTTPVer(t, "GET / HTTP/1.0\r\nHost: aa.com\r\n\r\n", true) testRequestHeaderHTTPVer(t, "GET / HTTP/0.9\r\nHost: aa.com\r\n\r\n", true) - testRequestHeaderHTTPVer(t, "GET / foobar\r\nHost: aa.com\r\n\r\n", true) - - // empty http version - testRequestHeaderHTTPVer(t, "GET /\r\nHost: aaa.com\r\n\r\n", true) - testRequestHeaderHTTPVer(t, "GET / \r\nHost: aaa.com\r\n\r\n", true) // http/1.1 testRequestHeaderHTTPVer(t, "GET / HTTP/1.1\r\nHost: a.com\r\n\r\n", false) @@ -1365,6 +1361,8 @@ func testResponseHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { } func testRequestHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { + t.Helper() + var h RequestHeader r := bytes.NewBufferString(s) @@ -1410,8 +1408,8 @@ func TestResponseHeaderCopyTo(t *testing.T) { h.bufKV = argsKV{} h1.bufKV = argsKV{} - if !reflect.DeepEqual(h, h1) { //nolint:govet - t.Fatalf("ResponseHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet + if !reflect.DeepEqual(&h, &h1) { + t.Fatalf("ResponseHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", &h, &h1) } } @@ -1453,8 +1451,8 @@ func TestRequestHeaderCopyTo(t *testing.T) { h.bufKV = argsKV{} h1.bufKV = argsKV{} - if !reflect.DeepEqual(h, h1) { //nolint:govet - t.Fatalf("RequestHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet + if !reflect.DeepEqual(&h, &h1) { + t.Fatalf("RequestHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", &h, &h1) } } @@ -1469,7 +1467,7 @@ func TestResponseContentTypeNoDefaultNotEmpty(t *testing.T) { headers := h.String() if strings.Contains(headers, "Content-Type: \r\n") { - 
t.Fatalf("ResponseContentTypeNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet + t.Fatalf("ResponseContentTypeNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", &h, headers) } } @@ -1522,7 +1520,7 @@ func TestRequestContentTypeNoDefault(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - if string(h1.contentType) != "" { + if len(h1.contentType) != 0 { t.Fatalf("unexpected Content-Type %q. Expecting %q", h1.contentType, "") } } @@ -1537,7 +1535,7 @@ func TestResponseDateNoDefaultNotEmpty(t *testing.T) { headers := h.String() if strings.Contains(headers, "\r\nDate: ") { - t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet + t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", &h, headers) } } @@ -2124,7 +2122,7 @@ func testRequestHeaderMethod(t *testing.T, expectedMethod string) { t.Fatalf("unexpected error: %v", err) } m1 := h1.Method() - if string(m) != string(m1) { + if !bytes.Equal(m, m1) { t.Fatalf("unexpected method: %q. 
Expecting %q", m, m1) } } @@ -2621,10 +2619,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n", 123, "/a", "aa", "", "xx", nil) - // post with duplicate content-length - testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n", - 1, "/xx", "aa", "", "s", nil) - // non-post with content-type testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n", -2, "/aaa", "bbb.com", "", "aaab", nil) @@ -2641,10 +2635,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "GET http://gooGle.com/foO/%20bar?xxx#aaa HTTP/1.1\r\nHost: aa.cOM\r\n\r\ntrail", -2, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", nil) - // no protocol in the first line - testRequestHeaderReadSuccess(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD", - -2, "/foo/bar", "google.com", "", "", nil) - // blank lines before the first line testRequestHeaderReadSuccess(t, h, "\r\n\n\r\nGET /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nsss", -2, "/aaa", "aaa.com", "", "", nil) @@ -2713,6 +2703,9 @@ func TestResponseHeaderReadError(t *testing.T) { // forbidden trailer testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n") + + // no protocol in the first line + testResponseHeaderReadError(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD") } func TestResponseHeaderReadErrorSecureLog(t *testing.T) { @@ -2760,6 +2753,9 @@ func TestRequestHeaderReadError(t *testing.T) { // forbidden trailer testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n") + + // post with duplicate content-length + testRequestHeaderReadError(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 
13\r\nContent-Length: 1\r\n\r\n") } func TestRequestHeaderReadSecuredError(t *testing.T) { @@ -2809,6 +2805,8 @@ func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers } func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) @@ -2839,6 +2837,8 @@ func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers s func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int, expectedContentType string, ) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) @@ -2851,6 +2851,8 @@ func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers stri func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, expectedTrailer map[string]string, ) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) diff --git a/http.go b/http.go index 56723016f6..5dd4e645f3 100644 --- a/http.go +++ b/http.go @@ -528,6 +528,23 @@ func (ctx *RequestCtx) RequestBodyStream() io.Reader { return ctx.Request.bodyStream } +func (req *Request) BodyUnzstd() ([]byte, error) { + return unzstdData(req.Body()) +} + +func (resp *Response) BodyUnzstd() ([]byte, error) { + return unzstdData(resp.Body()) +} + +func unzstdData(p []byte) ([]byte, error) { + var bb bytebufferpool.ByteBuffer + _, err := WriteUnzstd(&bb, p) + if err != nil { + return nil, err + } + return bb.B, nil +} + func inflateData(p []byte) ([]byte, error) { var bb bytebufferpool.ByteBuffer _, err := WriteInflate(&bb, p) @@ -554,6 +571,8 @@ func (req *Request) BodyUncompressed() ([]byte, error) { return req.BodyGunzip() case "br": return req.BodyUnbrotli() + case "zstd": + return 
req.BodyUnzstd() default: return nil, ErrContentEncodingUnsupported } @@ -574,6 +593,8 @@ func (resp *Response) BodyUncompressed() ([]byte, error) { return resp.BodyGunzip() case "br": return resp.BodyUnbrotli() + case "zstd": + return resp.BodyUnzstd() default: return nil, ErrContentEncodingUnsupported } @@ -1180,7 +1201,7 @@ func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { return req.readLimitBody(r, maxBodySize, false, true) } -func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error { +func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error { // Do not reset the request here - the caller must reset it before // calling this method. @@ -1198,7 +1219,7 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool return req.ContinueReadBody(r, maxBodySize, preParseMultipartForm) } -func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error { +func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error { // Do not reset the request here - the caller must reset it before // calling this method. @@ -1250,7 +1271,7 @@ func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseM // This way we limit memory usage for large file uploads, since their contents // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. 
req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) - if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { + if req.multipartFormBoundary != "" && len(req.Header.peek(strContentEncoding)) == 0 { req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) if err != nil { req.Reset() @@ -1289,7 +1310,7 @@ func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseM // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. -func (req *Request) ReadBody(r *bufio.Reader, contentLength int, maxBodySize int) (err error) { +func (req *Request) ReadBody(r *bufio.Reader, contentLength, maxBodySize int) (err error) { bodyBuf := req.bodyBuffer() bodyBuf.Reset() @@ -1329,7 +1350,7 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre // This way we limit memory usage for large file uploads, since their contents // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary()) - if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { + if req.multipartFormBoundary != "" && len(req.Header.peek(strContentEncoding)) == 0 { req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) if err != nil { req.Reset() @@ -1612,7 +1633,7 @@ func (req *Request) Write(w *bufio.Writer) error { _, err = w.Write(body) } else if len(body) > 0 { if req.secureErrorLogMessage { - return fmt.Errorf("non-zero body for non-POST request") + return errors.New("non-zero body for non-POST request") } return fmt.Errorf("non-zero body for non-POST request. 
body=%q", body) } @@ -1849,6 +1870,55 @@ func (resp *Response) deflateBody(level int) error { return nil } +func (resp *Response) zstdBody(level int) error { + if len(resp.Header.ContentEncoding()) > 0 { + return nil + } + + if !resp.Header.isCompressibleContentType() { + return nil + } + + if resp.bodyStream != nil { + // Reset Content-Length to -1, since it is impossible + // to determine body size beforehand of streamed compression. + // For + resp.Header.SetContentLength(-1) + + // Do not care about memory allocations here, since flate is slow + // and allocates a lot of memory by itself. + bs := resp.bodyStream + resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { + zw := acquireStacklessZstdWriter(sw, level) + fw := &flushWriter{ + wf: zw, + bw: sw, + } + copyZeroAlloc(fw, bs) //nolint:errcheck + releaseStacklessZstdWriter(zw, level) + if bsc, ok := bs.(io.Closer); ok { + bsc.Close() + } + }) + } else { + bodyBytes := resp.bodyBytes() + if len(bodyBytes) < minCompressLen { + return nil + } + w := responseBodyPool.Get() + w.B = AppendZstdBytesLevel(w.B, bodyBytes, level) + + if resp.body != nil { + responseBodyPool.Put(resp.body) + } + resp.body = w + resp.bodyRaw = nil + } + resp.Header.SetContentEncodingBytes(strZstd) + resp.Header.addVaryBytes(strAcceptEncoding) + return nil +} + // Bodies with sizes smaller than minCompressLen aren't compressed at all. const minCompressLen = 200 @@ -2172,7 +2242,7 @@ func writeChunk(w *bufio.Writer, b []byte) error { // the given limit. 
var ErrBodyTooLarge = errors.New("body size exceeds the given limit") -func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ([]byte, error) { +func readBody(r *bufio.Reader, contentLength, maxBodySize int, dst []byte) ([]byte, error) { if maxBodySize > 0 && contentLength > maxBodySize { return dst, ErrBodyTooLarge } @@ -2181,7 +2251,7 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ( var errChunkedStream = errors.New("chunked stream") -func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) { +func readBodyWithStreaming(r *bufio.Reader, contentLength, maxBodySize int, dst []byte) (b []byte, err error) { if contentLength == -1 { // handled in requestStream.Read() return b, errChunkedStream @@ -2309,7 +2379,7 @@ func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, erro } if !bytes.Equal(dst[len(dst)-strCRLFLen:], strCRLF) { return dst, ErrBrokenChunk{ - error: fmt.Errorf("cannot find crlf at the end of chunk"), + error: errors.New("cannot find crlf at the end of chunk"), } } dst = dst[:len(dst)-strCRLFLen] diff --git a/http_test.go b/http_test.go index 098990a678..b83a487455 100644 --- a/http_test.go +++ b/http_test.go @@ -3,7 +3,6 @@ package fasthttp import ( "bufio" "bytes" - "encoding/base64" "errors" "fmt" "io" @@ -22,23 +21,15 @@ import ( func TestInvalidTrailers(t *testing.T) { t.Parallel() - if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0x20, 0x30, 0x0a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0xff, 0x0a, 0x0a, 0x30, 0x0d, 0x0a, 0x30}))); !errors.Is(err, io.EOF) { + if err := (&Response{}).Read(bufio.NewReader(strings.NewReader(" 0\nTransfer-Encoding:\xff\n\n0\r\n0"))); !errors.Is(err, io.EOF) { t.Fatalf("%#v", err) } - if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0xff, 0x20, 0x0a, 0x54, 0x52, 0x61, 0x49, 0x4c, 
0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !errors.Is(err, errEmptyInt) { + if err := (&Response{}).Read(bufio.NewReader(strings.NewReader("\xff \nTRaILeR:,\n\n"))); !errors.Is(err, errEmptyInt) { t.Fatal(err) } - if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0x54, 0x52, 0x61, 0x49, 0x4c, 0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !strings.Contains(err.Error(), "cannot find whitespace in the first line of response") { + if err := (&Response{}).Read(bufio.NewReader(strings.NewReader("TRaILeR:,\n\n"))); !strings.Contains(err.Error(), "cannot find whitespace in the first line of response") { t.Fatal(err) } - if err := (&Request{}).Read(bufio.NewReader(bytes.NewReader([]byte{0xff, 0x20, 0x0a, 0x54, 0x52, 0x61, 0x49, 0x4c, 0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !strings.Contains(err.Error(), "contain forbidden trailer") { - t.Fatal(err) - } - - b, _ := base64.StdEncoding.DecodeString("tCAKIDoKCToKICAKCToKICAKCToKIAogOgoJOgogIAoJOgovIC8vOi4KOh0KVFJhSUxlUjo9HT09HQpUUmFJTGVSOicQAApUUmFJTGVSOj0gHSAKCT09HQoKOgoKCgo=") - if err := (&Request{}).Read(bufio.NewReader(bytes.NewReader(b))); !strings.Contains(err.Error(), "error when reading request headers: invalid header key") { - t.Fatalf("%#v", err) - } } func TestResponseEmptyTransferEncoding(t *testing.T) { @@ -164,8 +155,8 @@ func testRequestCopyTo(t *testing.T, src *Request) { var dst Request src.CopyTo(&dst) - if !reflect.DeepEqual(*src, dst) { //nolint:govet - t.Fatalf("RequestCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet + if !reflect.DeepEqual(src, &dst) { + t.Fatalf("RequestCopyTo fail, src: \n%+v\ndst: \n%+v\n", src, &dst) } } @@ -173,8 +164,8 @@ func testResponseCopyTo(t *testing.T, src *Response) { var dst Response src.CopyTo(&dst) - if !reflect.DeepEqual(*src, dst) { //nolint:govet - t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet + if !reflect.DeepEqual(src, &dst) { + t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", src, &dst) } } @@ 
-1219,7 +1210,7 @@ func TestRequestReadGzippedBody(t *testing.T) { if r.Header.ContentLength() != len(body) { t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body)) } - if string(r.Body()) != string(body) { + if !bytes.Equal(r.Body(), body) { t.Fatalf("unexpected body: %q. Expecting %q", r.Body(), body) } @@ -1332,7 +1323,7 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) { t.Fatalf("The multipartForm of the Request must be nil") } - if string(formData) != string(r.Body()) { + if !bytes.Equal(formData, r.Body()) { t.Fatalf("The body given must equal the body in the Request") } } @@ -2667,8 +2658,8 @@ func TestRequestRawBodyCopyTo(t *testing.T) { } type testReader struct { - read chan (int) - cb chan (struct{}) + read chan int + cb chan struct{} onClose func() error } diff --git a/prefork/prefork.go b/prefork/prefork.go index fa45348c91..d40bded1a7 100644 --- a/prefork/prefork.go +++ b/prefork/prefork.go @@ -152,7 +152,8 @@ func (p *Prefork) doCommand() (*exec.Cmd, error) { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.ExtraFiles = p.files - return cmd, cmd.Start() + err := cmd.Start() + return cmd, err } func (p *Prefork) prefork(addr string) (err error) { @@ -209,7 +210,8 @@ func (p *Prefork) prefork(addr string) (err error) { p.logger().Printf("one of the child prefork processes exited with "+ "error: %v", sig.err) - if exitedProcs++; exitedProcs > p.RecoverThreshold { + exitedProcs++ + if exitedProcs > p.RecoverThreshold { p.logger().Printf("child prefork processes exit too many times, "+ "which exceeds the value of RecoverThreshold(%d), "+ "exiting the master process.\n", exitedProcs) diff --git a/request_body.zst b/request_body.zst new file mode 100644 index 0000000000..ea95e7301d Binary files /dev/null and b/request_body.zst differ diff --git a/server.go b/server.go index dbee2a4f0d..03b49f7efe 100644 --- a/server.go +++ b/server.go @@ -431,6 +431,8 @@ type Server struct { open int32 stop 
int32 done chan struct{} + + rejectedRequestsCount uint32 } // TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout @@ -521,10 +523,13 @@ func CompressHandler(h RequestHandler) RequestHandler { func CompressHandlerLevel(h RequestHandler, level int) RequestHandler { return func(ctx *RequestCtx) { h(ctx) - if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { + switch { + case ctx.Request.Header.HasAcceptEncodingBytes(strGzip): ctx.Response.gzipBody(level) //nolint:errcheck - } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) { + case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate): ctx.Response.deflateBody(level) //nolint:errcheck + case ctx.Request.Header.HasAcceptEncodingBytes(strZstd): + ctx.Response.zstdBody(level) //nolint:errcheck } } } @@ -557,6 +562,8 @@ func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) R ctx.Response.gzipBody(otherLevel) //nolint:errcheck case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate): ctx.Response.deflateBody(otherLevel) //nolint:errcheck + case ctx.Request.Header.HasAcceptEncodingBytes(strZstd): + ctx.Response.zstdBody(otherLevel) //nolint:errcheck } } } @@ -675,7 +682,7 @@ func (ctx *RequestCtx) Hijacked() bool { // All the values are removed from ctx after returning from the top // RequestHandler. Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. 
-func (ctx *RequestCtx) SetUserValue(key any, value any) { +func (ctx *RequestCtx) SetUserValue(key, value any) { ctx.userValues.Set(key, value) } @@ -1828,6 +1835,7 @@ func (s *Server) Serve(ln net.Listener) error { atomic.AddInt32(&s.open, 1) if !wp.Serve(c) { atomic.AddInt32(&s.open, -1) + atomic.AddUint32(&s.rejectedRequestsCount, 1) s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded") c.Close() @@ -2073,6 +2081,13 @@ func (s *Server) GetOpenConnectionsCount() int32 { return atomic.LoadInt32(&s.open) } +// GetRejectedConnectionsCount returns a number of rejected connections. +// +// This function is intended be used by monitoring systems. +func (s *Server) GetRejectedConnectionsCount() uint32 { + return atomic.LoadUint32(&s.rejectedRequestsCount) +} + func (s *Server) getConcurrency() int { n := s.Concurrency if n <= 0 { diff --git a/server_test.go b/server_test.go index efc126f653..2153c8987f 100644 --- a/server_test.go +++ b/server_test.go @@ -1016,6 +1016,62 @@ func TestServerConcurrencyLimit(t *testing.T) { } } +func TestRejectedRequestsCount(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("OK") //nolint:errcheck + }, + Concurrency: 1, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + + serverCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + close(serverCh) + }() + + clientCh := make(chan struct{}) + expectedCount := 5 + go func() { + for i := 0; i < expectedCount+1; i++ { + _, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + + if cnt := s.GetRejectedConnectionsCount(); cnt != uint32(expectedCount) { + t.Errorf("unexpected rejected connections count: %d. 
Expecting %d", + cnt, expectedCount) + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case <-serverCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + func TestServerWriteFastError(t *testing.T) { t.Parallel() @@ -1956,7 +2012,7 @@ func TestCompressHandler(t *testing.T) { t.Fatalf("unexpected error: %v", err) } ce := resp.Header.ContentEncoding() - if string(ce) != "" { + if len(ce) != 0 { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") } body := resp.Body() @@ -2054,11 +2110,11 @@ func TestCompressHandlerVary(t *testing.T) { t.Fatalf("unexpected error: %v", err) } ce := resp.Header.ContentEncoding() - if string(ce) != "" { + if len(ce) != 0 { t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") } vary := resp.Header.Peek("Vary") - if string(vary) != "" { + if len(vary) != 0 { t.Fatalf("unexpected Vary: %q. 
Expecting %q", vary, "") } body := resp.Body() diff --git a/stackless/func_test.go b/stackless/func_test.go index 719d10be2f..0d8f1d5c5e 100644 --- a/stackless/func_test.go +++ b/stackless/func_test.go @@ -1,7 +1,7 @@ package stackless import ( - "fmt" + "errors" "sync/atomic" "testing" "time" @@ -44,7 +44,7 @@ func TestNewFuncMulti(t *testing.T) { var err error for i := 0; i < iterations; i++ { if !f1(3) { - err = fmt.Errorf("f1 mustn't return false") + err = errors.New("f1 mustn't return false") break } } @@ -56,7 +56,7 @@ func TestNewFuncMulti(t *testing.T) { var err error for i := 0; i < iterations; i++ { if !f2(5) { - err = fmt.Errorf("f2 mustn't return false") + err = errors.New("f2 mustn't return false") break } } diff --git a/status.go b/status.go index c88ba11ee7..f92727c628 100644 --- a/status.go +++ b/status.go @@ -163,7 +163,7 @@ func StatusMessage(statusCode int) string { return unknownStatusCode } -func formatStatusLine(dst []byte, protocol []byte, statusCode int, statusText []byte) []byte { +func formatStatusLine(dst, protocol []byte, statusCode int, statusText []byte) []byte { dst = append(dst, protocol...) 
dst = append(dst, ' ') dst = strconv.AppendInt(dst, int64(statusCode), 10) diff --git a/stream_test.go b/stream_test.go index f4dcb5a62d..fe35f3216f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,7 @@ package fasthttp import ( "bufio" + "errors" "fmt" "io" "testing" @@ -55,7 +56,7 @@ func TestStreamReaderClose(t *testing.T) { w.Write(data) //nolint:errcheck } if err := w.Flush(); err == nil { - ch <- fmt.Errorf("expecting error on the second flush") + ch <- errors.New("expecting error on the second flush") } ch <- nil }) diff --git a/strings.go b/strings.go index 3cec8ed0e1..a9e4072d4d 100644 --- a/strings.go +++ b/strings.go @@ -19,7 +19,6 @@ var ( strCRLF = []byte("\r\n") strHTTP = []byte("http") strHTTPS = []byte("https") - strHTTP10 = []byte("HTTP/1.0") strHTTP11 = []byte("HTTP/1.1") strColon = []byte(":") strColonSlashSlash = []byte("://") @@ -73,6 +72,7 @@ var ( strClose = []byte("close") strGzip = []byte("gzip") strBr = []byte("br") + strZstd = []byte("zstd") strDeflate = []byte("deflate") strKeepAlive = []byte("keep-alive") strUpgrade = []byte("Upgrade") diff --git a/tcpdialer.go b/tcpdialer.go index 7edd79fdb1..e8430cb9c8 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -151,6 +151,8 @@ type TCPDialer struct { // } Resolver Resolver + // DisableDNSResolution may be used to disable DNS resolution + DisableDNSResolution bool // DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration) DNSCacheDuration time.Duration @@ -277,23 +279,26 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne d.DNSCacheDuration = DefaultDNSCacheDuration } - go d.tcpAddrsClean() + if !d.DisableDNSResolution { + go d.tcpAddrsClean() + } }) - deadline := time.Now().Add(timeout) - addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline) - if err != nil { - return nil, err - } network := "tcp4" if dualStack { network = "tcp" } - + if d.DisableDNSResolution { + return d.tryDial(network, addr, deadline, 
d.concurrencyCh) + } + addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline) + if err != nil { + return nil, err + } var conn net.Conn n := uint32(len(addrs)) for n > 0 { - conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) + conn, err = d.tryDial(network, addrs[idx%n].String(), deadline, d.concurrencyCh) if err == nil { return conn, nil } @@ -307,7 +312,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne } func (d *TCPDialer) tryDial( - network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}, + network string, addr string, deadline time.Time, concurrencyCh chan struct{}, ) (net.Conn, error) { timeout := time.Until(deadline) if timeout <= 0 { @@ -340,7 +345,7 @@ func (d *TCPDialer) tryDial( ctx, cancelCtx := context.WithDeadline(context.Background(), deadline) defer cancelCtx() - conn, err := dialer.DialContext(ctx, network, addr.String()) + conn, err := dialer.DialContext(ctx, network, addr) if err != nil && ctx.Err() == context.DeadlineExceeded { return nil, ErrDialTimeout } diff --git a/uri_test.go b/uri_test.go index 5996bcb19c..6ffea4a9bb 100644 --- a/uri_test.go +++ b/uri_test.go @@ -230,14 +230,14 @@ func TestURICopyTo(t *testing.T) { var u URI var copyU URI u.CopyTo(©U) - if !reflect.DeepEqual(u, copyU) { //nolint:govet - t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet + if !reflect.DeepEqual(&u, ©U) { + t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", &u, ©U) } u.UpdateBytes([]byte("https://example.com/foo?bar=baz&baraz#qqqq")) u.CopyTo(©U) - if !reflect.DeepEqual(u, copyU) { //nolint:govet - t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet + if !reflect.DeepEqual(&u, ©U) { + t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", &u, ©U) } } diff --git a/userdata.go b/userdata.go index a9afbf564c..38cca864f8 100644 --- a/userdata.go +++ b/userdata.go @@ -11,7 +11,7 @@ type userDataKV struct { type userData 
[]userDataKV -func (d *userData) Set(key any, value any) { +func (d *userData) Set(key, value any) { if b, ok := key.([]byte); ok { key = string(b) } diff --git a/zstd.go b/zstd.go new file mode 100644 index 0000000000..226a126326 --- /dev/null +++ b/zstd.go @@ -0,0 +1,186 @@ +package fasthttp + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/stackless" +) + +const ( + CompressZstdSpeedNotSet = iota + CompressZstdBestSpeed + CompressZstdDefault + CompressZstdSpeedBetter + CompressZstdBestCompression +) + +var ( + zstdDecoderPool sync.Pool + zstdEncoderPool sync.Pool + realZstdWriterPoolMap = newCompressWriterPoolMap() + stacklessZstdWriterPoolMap = newCompressWriterPoolMap() +) + +func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) { + v := zstdDecoderPool.Get() + if v == nil { + return zstd.NewReader(r) + } + zr := v.(*zstd.Decoder) + if err := zr.Reset(r); err != nil { + return nil, err + } + return zr, nil +} + +func releaseZstdReader(zr *zstd.Decoder) { + zstdDecoderPool.Put(zr) +} + +func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) { + v := zstdEncoderPool.Get() + if v == nil { + return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw, nil +} + +func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused + zw.Close() + zstdEncoderPool.Put(zw) +} + +func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer { + nLevel := normalizeZstdCompressLevel(compressLevel) + p := stacklessZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealZstdWriter(w, compressLevel) + }) + } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} + +func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) { + zf.Close() + nLevel := 
normalizeZstdCompressLevel(zstdDefault) + p := stacklessZstdWriterPoolMap[nLevel] + p.Put(zf) +} + +func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder { + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + zw, err := acquireZstdWriter(w, level) + if err != nil { + panic(err) + } + return zw + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw +} + +func releaseRealZstdWriter(zw *zstd.Encoder, level int) { + zw.Close() + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + p.Put(zw) +} + +func AppendZstdBytesLevel(dst, src []byte, level int) []byte { + w := &byteSliceWriter{dst} + WriteZstdLevel(w, src, level) //nolint:errcheck + return w.b +} + +func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) { + level = normalizeZstdCompressLevel(level) + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteZstd(ctx) + return len(p), nil + default: + zw := acquireStacklessZstdWriter(w, level) + n, err := zw.Write(p) + releaseStacklessZstdWriter(zw, level) + return n, err + } +} + +var ( + stacklessWriteZstdOnce sync.Once + stacklessWriteZstdFunc func(ctx any) bool +) + +func stacklessWriteZstd(ctx any) { + stacklessWriteZstdOnce.Do(func() { + stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd) + }) + stacklessWriteZstdFunc(ctx) +} + +func nonblockingWriteZstd(ctxv any) { + ctx := ctxv.(*compressCtx) + zw := acquireRealZstdWriter(ctx.w, ctx.level) + zw.Write(ctx.p) //nolint:errcheck + releaseRealZstdWriter(zw, ctx.level) +} + +// AppendZstdBytes appends zstd src to dst and returns the resulting dst. +func AppendZstdBytes(dst, src []byte) []byte { + return AppendZstdBytesLevel(dst, src, CompressZstdDefault) +} + +// WriteUnzstd writes unzstd p to w and returns the number of uncompressed +// bytes written to w. 
+func WriteUnzstd(w io.Writer, p []byte) (int, error) { + r := &byteSliceReader{p} + zr, err := acquireZstdReader(r) + if err != nil { + return 0, err + } + n, err := copyZeroAlloc(w, zr) + releaseZstdReader(zr) + nn := int(n) + if int64(nn) != n { + return 0, fmt.Errorf("too much data unzstd: %d", n) + } + return nn, err +} + +// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst. +func AppendUnzstdBytes(dst, src []byte) ([]byte, error) { + w := &byteSliceWriter{dst} + _, err := WriteUnzstd(w, src) + return w.b, err +} + +// normalizes compression level into [0..4], so it could be used as an index +// in *PoolMap. +func normalizeZstdCompressLevel(level int) int { + if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression { + level = CompressZstdDefault + } + return level +} diff --git a/zstd_test.go b/zstd_test.go new file mode 100644 index 0000000000..dc0c45f339 --- /dev/null +++ b/zstd_test.go @@ -0,0 +1,102 @@ +package fasthttp + +import ( + "bytes" + "fmt" + "io" + "testing" +) + +func TestZstdBytesSerial(t *testing.T) { + t.Parallel() + + if err := testZstdBytes(); err != nil { + t.Fatal(err) + } +} + +func TestZstdBytesConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testZstdBytes); err != nil { + t.Fatal(err) + } +} + +func testZstdBytes() error { + for _, s := range compressTestcases { + if err := testZstdBytesSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testZstdBytesSingleCase(s string) error { + prefix := []byte("foobar") + ZstdpedS := AppendZstdBytes(prefix, []byte(s)) + if !bytes.Equal(ZstdpedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when compressing %q: %q. 
Expecting %q", s, ZstdpedS[:len(prefix)], prefix) + } + + unZstdedS, err := AppendUnzstdBytes(prefix, ZstdpedS[len(prefix):]) + if err != nil { + return fmt.Errorf("unexpected error when uncompressing %q: %w", s, err) + } + if !bytes.Equal(unZstdedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, unZstdedS[:len(prefix)], prefix) + } + unZstdedS = unZstdedS[len(prefix):] + if string(unZstdedS) != s { + return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", unZstdedS, s) + } + return nil +} + +func TestZstdCompressSerial(t *testing.T) { + t.Parallel() + + if err := testZstdCompress(); err != nil { + t.Fatal(err) + } +} + +func TestZstdCompressConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testZstdCompress); err != nil { + t.Fatal(err) + } +} + +func testZstdCompress() error { + for _, s := range compressTestcases { + if err := testZstdCompressSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testZstdCompressSingleCase(s string) error { + var buf bytes.Buffer + zw := acquireStacklessZstdWriter(&buf, CompressZstdDefault) + if _, err := zw.Write([]byte(s)); err != nil { + return fmt.Errorf("unexpected error: %w. s=%q", err, s) + } + releaseStacklessZstdWriter(zw, CompressZstdDefault) + + zr, err := acquireZstdReader(&buf) + if err != nil { + return fmt.Errorf("unexpected error: %w. s=%q", err, s) + } + body, err := io.ReadAll(zr) + if err != nil { + return fmt.Errorf("unexpected error: %w. s=%q", err, s) + } + if string(body) != s { + return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s) + } + releaseZstdReader(zr) + return nil +}