From 745241168fa0995b14c7df0c73e1b0cd7029d7e5 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sat, 25 Nov 2023 10:06:21 +0800 Subject: [PATCH] feat: add Status method to Writer struct and update test (#63) - Add a `Status` method to the `Writer` struct - Modify the `TestWriter_Status` test function to include the new `Status` method fixed by @zhyee fixed https://github.com/gin-contrib/timeout/pull/52 fixed https://github.com/gin-contrib/timeout/pull/51 Signed-off-by: Bo-Yi Wu --- writer.go | 10 ++++++++++ writer_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/writer.go b/writer.go index d0cb79b..eba95ea 100644 --- a/writer.go +++ b/writer.go @@ -72,6 +72,16 @@ func (w *Writer) FreeBuffer() { w.body = nil } +// Status we must override Status func here, +// or the http status code returned by gin.Context.Writer.Status() +// will always be 200 in other custom gin middlewares. +func (w *Writer) Status() int { + if w.code == 0 || w.timeout { + return w.ResponseWriter.Status() + } + return w.code +} + func checkWriteHeaderCode(code int) { if code < 100 || code > 999 { panic(fmt.Sprintf("invalid http status code: %d", code)) diff --git a/writer_test.go b/writer_test.go index 4b61738..614c151 100644 --- a/writer_test.go +++ b/writer_test.go @@ -2,8 +2,13 @@ package timeout import ( "fmt" + "net/http" + "net/http/httptest" + "strconv" "testing" + "time" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" ) @@ -21,3 +26,34 @@ func TestWriteHeader(t *testing.T) { writer.WriteHeader(code2) }) } + +func TestWriter_Status(t *testing.T) { + r := gin.New() + + r.Use(New( + WithTimeout(1*time.Second), + WithHandler(func(c *gin.Context) { + c.Next() + }), + WithResponse(testResponse), + )) + + r.Use(func(c *gin.Context) { + c.Next() + statusInMW := c.Writer.Status() + c.Request.Header.Set("X-Status-Code-MW-Set", strconv.Itoa(statusInMW)) + t.Logf("[%s] %s %s %d\n", time.Now().Format(time.RFC3339), c.Request.Method, c.Request.URL, statusInMW) + }) + + r.GET("/test", func(c *gin.Context) { + c.Writer.WriteHeader(http.StatusInternalServerError) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set")) +}