From f725ded92bac13e773f92ff478e1a461c160abd3 Mon Sep 17 00:00:00 2001 From: JIeJaitt <498938874@qq.com> Date: Sat, 16 Nov 2024 00:34:20 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20feat:=20Add=20Context=20Support?= =?UTF-8?q?=20to=20RequestID=20Middleware=20(#3200)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Rename UserContext() to Context(). Rename Context() to RequestCtx() * feat: add requestID in UserContext * Update Ctxt docs and What's new * Remove extra blank lines * ♻️ Refactor: merge issue #3186 * 🔥 Feature: improve FromContext func and test * 📚 Doc: improve requestid middleware * ♻️ Refactor: Rename interface to any * fix: Modify structure sorting to reduce memory usage --------- Co-authored-by: Juan Calderon-Perez Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- docs/middleware/requestid.md | 10 ++++ middleware/requestid/requestid.go | 25 ++++++++-- middleware/requestid/requestid_test.go | 67 +++++++++++++++++++------- 3 files changed, 82 insertions(+), 20 deletions(-) diff --git a/docs/middleware/requestid.md b/docs/middleware/requestid.md index 739a4a6190..01ec569e3c 100644 --- a/docs/middleware/requestid.md +++ b/docs/middleware/requestid.md @@ -49,6 +49,16 @@ func handler(c fiber.Ctx) error { } ``` +In version v3, Fiber will inject `requestID` into the built-in `Context` of Go. + +```go +func handler(c fiber.Ctx) error { + id := requestid.FromContext(c.Context()) + log.Printf("Request ID: %s", id) + return c.SendString("Hello, World!") +} +``` + ## Config | Property | Type | Description | Default | diff --git a/middleware/requestid/requestid.go b/middleware/requestid/requestid.go index 8e521dc650..ef67e6f21c 100644 --- a/middleware/requestid/requestid.go +++ b/middleware/requestid/requestid.go @@ -1,7 +1,10 @@ package requestid import ( + "context" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" ) // The contextKey type is unexported to prevent collisions with context keys defined in @@ -36,6 +39,10 @@ func New(config ...Config) fiber.Handler { // Add the request ID to locals c.Locals(requestIDKey, rid) + // Add the request ID to UserContext + ctx := context.WithValue(c.Context(), requestIDKey, rid) + c.SetContext(ctx) + // Continue stack return c.Next() } @@ -43,9 +50,21 @@ func New(config ...Config) fiber.Handler { // FromContext returns the request ID from context. // If there is no request ID, an empty string is returned. -func FromContext(c fiber.Ctx) string { - if rid, ok := c.Locals(requestIDKey).(string); ok { - return rid +// Supported context types: +// - fiber.Ctx: Retrieves request ID from Locals +// - context.Context: Retrieves request ID from context values +func FromContext(c any) string { + switch ctx := c.(type) { + case fiber.Ctx: + if rid, ok := ctx.Locals(requestIDKey).(string); ok { + return rid + } + case context.Context: + if rid, ok := ctx.Value(requestIDKey).(string); ok { + return rid + } + default: + log.Errorf("Unsupported context type: %T. Expected fiber.Ctx or context.Context", c) } return "" } diff --git a/middleware/requestid/requestid_test.go b/middleware/requestid/requestid_test.go index c739407be0..ad36884aca 100644 --- a/middleware/requestid/requestid_test.go +++ b/middleware/requestid/requestid_test.go @@ -51,26 +51,59 @@ func Test_RequestID_Next(t *testing.T) { require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } -// go test -run Test_RequestID_Locals +// go test -run Test_RequestID_FromContext func Test_RequestID_FromContext(t *testing.T) { t.Parallel() + reqID := "ThisIsARequestId" - app := fiber.New() - app.Use(New(Config{ - Generator: func() string { - return reqID + type args struct { + inputFunc func(c fiber.Ctx) any + } + + tests := []struct { + args args + name string + }{ + { + name: "From fiber.Ctx", + args: args{ + inputFunc: func(c fiber.Ctx) any { + return c + }, + }, }, - })) - - var ctxVal string - - app.Use(func(c fiber.Ctx) error { - ctxVal = FromContext(c) - return c.Next() - }) - - _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) - require.NoError(t, err) - require.Equal(t, reqID, ctxVal) + { + name: "From context.Context", + args: args{ + inputFunc: func(c fiber.Ctx) any { + return c.Context() + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{ + Generator: func() string { + return reqID + }, + })) + + var ctxVal string + + app.Use(func(c fiber.Ctx) error { + ctxVal = FromContext(tt.args.inputFunc(c)) + return c.Next() + }) + + _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, reqID, ctxVal) + }) + } }