diff --git a/ctx.go b/ctx.go index 8e568e2576..880113d5f2 100644 --- a/ctx.go +++ b/ctx.go @@ -651,10 +651,11 @@ func (c *Ctx) GetRespHeader(key string, defaultValue ...string) string { // GetReqHeaders returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. -func (c *Ctx) GetReqHeaders() map[string]string { - headers := make(map[string]string) +func (c *Ctx) GetReqHeaders() map[string][]string { + headers := make(map[string][]string) c.Request().Header.VisitAll(func(k, v []byte) { - headers[c.app.getString(k)] = c.app.getString(v) + key := c.app.getString(k) + headers[key] = append(headers[key], c.app.getString(v)) }) return headers @@ -663,10 +664,11 @@ func (c *Ctx) GetReqHeaders() map[string]string { // GetRespHeaders returns the HTTP response headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. -func (c *Ctx) GetRespHeaders() map[string]string { - headers := make(map[string]string) +func (c *Ctx) GetRespHeaders() map[string][]string { + headers := make(map[string][]string) c.Response().Header.VisitAll(func(k, v []byte) { - headers[c.app.getString(k)] = c.app.getString(v) + key := c.app.getString(k) + headers[key] = append(headers[key], c.app.getString(v)) }) return headers diff --git a/ctx_test.go b/ctx_test.go index 7e89327ff8..ae15febef9 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -5117,12 +5117,39 @@ func Test_Ctx_GetRespHeaders(t *testing.T) { c.Set("test", "Hello, World 👋!") c.Set("foo", "bar") + c.Response().Header.Set("multi", "one") + c.Response().Header.Add("multi", "two") c.Response().Header.Set(HeaderContentType, "application/json") - utils.AssertEqual(t, c.GetRespHeaders(), map[string]string{ - "Content-Type": "application/json", - "Foo": "bar", - "Test": "Hello, World 👋!", + utils.AssertEqual(t, c.GetRespHeaders(), map[string][]string{ + "Content-Type": {"application/json"}, + "Foo": {"bar"}, + "Multi": {"one", "two"}, + "Test": {"Hello, World 👋!"}, + }) +} + +func Benchmark_Ctx_GetRespHeaders(b *testing.B) { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.Response().Header.Set("test", "Hello, World 👋!") + c.Response().Header.Set("foo", "bar") + c.Response().Header.Set(HeaderContentType, "application/json") + + b.ReportAllocs() + b.ResetTimer() + + var headers map[string][]string + for n := 0; n < b.N; n++ { + headers = c.GetRespHeaders() + } + + utils.AssertEqual(b, headers, map[string][]string{ + "Content-Type": {"application/json"}, + "Foo": {"bar"}, + "Test": {"Hello, World 👋!"}, }) } @@ -5133,14 +5160,41 @@ func Test_Ctx_GetReqHeaders(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) + c.Request().Header.Set("test", "Hello, World 👋!") + c.Request().Header.Set("foo", "bar") + c.Request().Header.Set("multi", "one") + c.Request().Header.Add("multi", "two") + c.Request().Header.Set(HeaderContentType, "application/json") + + utils.AssertEqual(t, c.GetReqHeaders(), map[string][]string{ + "Content-Type": {"application/json"}, + "Foo": {"bar"}, + "Test": {"Hello, World 👋!"}, + "Multi": {"one", "two"}, + }) +} + +func Benchmark_Ctx_GetReqHeaders(b *testing.B) { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.Set("test", "Hello, World 👋!") c.Request().Header.Set("foo", "bar") c.Request().Header.Set(HeaderContentType, "application/json") - utils.AssertEqual(t, c.GetReqHeaders(), map[string]string{ - "Content-Type": "application/json", - "Foo": "bar", - "Test": "Hello, World 👋!", + b.ReportAllocs() + b.ResetTimer() + + var headers map[string][]string + for n := 0; n < b.N; n++ { + headers = c.GetReqHeaders() + } + + utils.AssertEqual(b, headers, map[string][]string{ + "Content-Type": {"application/json"}, + "Foo": {"bar"}, + "Test": {"Hello, World 👋!"}, }) } diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 1da3db9795..fdd473aacc 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -554,7 +554,7 @@ app.Get("/", func(c *fiber.Ctx) error { Returns the HTTP request headers. ```go title="Signature" -func (c *Ctx) GetReqHeaders() map[string]string +func (c *Ctx) GetReqHeaders() map[string][]string ``` > _Returned value is only valid within the handler. Do not store any references. @@ -589,7 +589,7 @@ app.Get("/", func(c *fiber.Ctx) error { Returns the HTTP response headers. ```go title="Signature" -func (c *Ctx) GetRespHeaders() map[string]string +func (c *Ctx) GetRespHeaders() map[string][]string ``` > _Returned value is only valid within the handler. Do not store any references. diff --git a/middleware/idempotency/idempotency.go b/middleware/idempotency/idempotency.go index ae4097ae9d..604f867c76 100644 --- a/middleware/idempotency/idempotency.go +++ b/middleware/idempotency/idempotency.go @@ -45,8 +45,10 @@ func New(config ...Config) fiber.Handler { _ = c.Status(res.StatusCode) - for header, val := range res.Headers { - c.Set(header, val) + for header, vals := range res.Headers { + for _, val := range vals { + c.Context().Response.Header.Add(header, val) + } } if len(res.Body) != 0 { @@ -122,7 +124,7 @@ func New(config ...Config) fiber.Handler { res.Headers = headers } else { // Filter - res.Headers = make(map[string]string) + res.Headers = make(map[string][]string) for h := range headers { if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok { res.Headers[h] = headers[h] diff --git a/middleware/idempotency/response.go b/middleware/idempotency/response.go index ca06bcb452..f42d1a3311 100644 --- a/middleware/idempotency/response.go +++ b/middleware/idempotency/response.go @@ -4,7 +4,7 @@ package idempotency type response struct { StatusCode int `msg:"sc"` - Headers map[string]string `msg:"hs"` + Headers map[string][]string `msg:"hs"` Body []byte `msg:"b"` } diff --git a/middleware/idempotency/response_msgp.go b/middleware/idempotency/response_msgp.go index 4eb4d7fcb0..410d118ca0 100644 --- a/middleware/idempotency/response_msgp.go +++ b/middleware/idempotency/response_msgp.go @@ -18,7 +18,10 @@ func (z *response) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.AppendMapHeader(o, uint32(len(z.Headers))) for za0001, za0002 := range z.Headers { o = msgp.AppendString(o, za0001) - o = msgp.AppendString(o, za0002) + o = msgp.AppendArrayHeader(o, uint32(len(za0002))) + for za0003 := range za0002 { + o = msgp.AppendString(o, za0002[za0003]) + } } // string "b" o = append(o, 0xa1, 0x62) @@ -58,7 +61,7 @@ func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) { return } if z.Headers == nil { - z.Headers = make(map[string]string, zb0002) + z.Headers = make(map[string][]string, zb0002) } else if len(z.Headers) > 0 { for key := range z.Headers { delete(z.Headers, key) @@ -66,18 +69,31 @@ func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) { } for zb0002 > 0 { var za0001 string - var za0002 string + var za0002 []string zb0002-- za0001, bts, err = msgp.ReadStringBytes(bts) if err != nil { err = msgp.WrapError(err, "Headers") return } - za0002, bts, err = msgp.ReadStringBytes(bts) + var zb0003 uint32 + zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts) if err != nil { err = msgp.WrapError(err, "Headers", za0001) return } + if cap(za0002) >= int(zb0003) { + za0002 = (za0002)[:zb0003] + } else { + za0002 = make([]string, zb0003) + } + for za0003 := range za0002 { + za0002[za0003], bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Headers", za0001, za0003) + return + } + } z.Headers[za0001] = za0002 } case "b": @@ -104,7 +120,10 @@ func (z *response) Msgsize() (s int) { if z.Headers != nil { for za0001, za0002 := range z.Headers { _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.StringPrefixSize + len(za0002) + s += msgp.StringPrefixSize + len(za0001) + msgp.ArrayHeaderSize + for za0003 := range za0002 { + s += msgp.StringPrefixSize + len(za0002[za0003]) + } } } s += 2 + msgp.BytesPrefixSize + len(z.Body) diff --git a/middleware/logger/tags.go b/middleware/logger/tags.go index 87b9a9b228..67ccbb83a2 100644 --- a/middleware/logger/tags.go +++ b/middleware/logger/tags.go @@ -102,7 +102,7 @@ func createTagMap(cfg *Config) map[string]LogFunc { TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) { reqHeaders := make([]string, 0) for k, v := range c.GetReqHeaders() { - reqHeaders = append(reqHeaders, k+"="+v) + reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ",")) } return output.Write([]byte(strings.Join(reqHeaders, "&"))) },