From 73388f411878770b8abe7ab3a53c14d9b248d3ce Mon Sep 17 00:00:00 2001 From: kinggo Date: Tue, 31 Jan 2023 12:29:47 +0800 Subject: [PATCH] feat: add peekAll --- pkg/protocol/args.go | 10 ++++ pkg/protocol/header.go | 114 ++++++++++++++++++++++++++++++++++++ pkg/protocol/header_test.go | 69 ++++++++++++++++++++++ 3 files changed, 193 insertions(+) diff --git a/pkg/protocol/args.go b/pkg/protocol/args.go index 5731db570..ebf9ae56c 100644 --- a/pkg/protocol/args.go +++ b/pkg/protocol/args.go @@ -224,6 +224,16 @@ func peekArgBytes(h []argsKV, k []byte) []byte { return nil } +func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte { + for i, n := 0, len(h); i < n; i++ { + kv := &h[i] + if bytes.Equal(kv.key, k) { + dst = append(dst, kv.value) + } + } + return dst +} + func delAllArgsBytes(args []argsKV, key []byte) []argsKV { return delAllArgs(args, bytesconv.B2s(key)) } diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index b447ccef1..dfc5c8d81 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -83,6 +83,7 @@ type RequestHeader struct { host []byte contentType []byte userAgent []byte + mulHeader [][]byte h []argsKV bufKV argsKV @@ -121,6 +122,7 @@ type ResponseHeader struct { contentType []byte server []byte + mulHeader [][]byte h []argsKV bufKV argsKV @@ -486,6 +488,7 @@ func (h *ResponseHeader) ResetSkipNormalize() { h.h = h.h[:0] h.cookies = h.cookies[:0] + h.mulHeader = h.mulHeader[:0] } // ContentLength returns Content-Length header value. @@ -687,6 +690,94 @@ func (h *ResponseHeader) peek(key []byte) []byte { } } +// PeekAll returns all header value for the given key. +// +// The returned value is valid until the request is released, +// either though ReleaseResponse or your request handler returning. +// Any future calls to the Peek* will modify the returned value. +// Do not store references to returned value. Use ResponseHeader.GetAll(key) instead. +func (h *ResponseHeader) PeekAll(key string) [][]byte { + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) + return h.peekAll(k) +} + +func (h *ResponseHeader) peekAll(key []byte) [][]byte { + h.mulHeader = h.mulHeader[:0] + switch string(key) { + case consts.HeaderContentType: + if contentType := h.ContentType(); len(contentType) > 0 { + h.mulHeader = append(h.mulHeader, contentType) + } + case consts.HeaderContentEncoding: + if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 { + h.mulHeader = append(h.mulHeader, contentEncoding) + } + case consts.HeaderServer: + if server := h.Server(); len(server) > 0 { + h.mulHeader = append(h.mulHeader, server) + } + case consts.HeaderConnection: + if h.ConnectionClose() { + h.mulHeader = append(h.mulHeader, bytestr.StrClose) + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + case consts.HeaderContentLength: + h.mulHeader = append(h.mulHeader, h.contentLengthBytes) + case consts.HeaderSetCookie: + h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies)) + default: + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + return h.mulHeader +} + +// PeekAll returns all header value for the given key. +// +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Any future calls to the Peek* will modify the returned value. +// Do not store references to returned value. Use RequestHeader.GetAll(key) instead. +func (h *RequestHeader) PeekAll(key string) [][]byte { + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) + return h.peekAll(k) +} + +func (h *RequestHeader) peekAll(key []byte) [][]byte { + h.mulHeader = h.mulHeader[:0] + switch string(key) { + case consts.HeaderHost: + if host := h.Host(); len(host) > 0 { + h.mulHeader = append(h.mulHeader, host) + } + case consts.HeaderContentType: + if contentType := h.ContentType(); len(contentType) > 0 { + h.mulHeader = append(h.mulHeader, contentType) + } + case consts.HeaderUserAgent: + if ua := h.UserAgent(); len(ua) > 0 { + h.mulHeader = append(h.mulHeader, ua) + } + case consts.HeaderConnection: + if h.ConnectionClose() { + h.mulHeader = append(h.mulHeader, bytestr.StrClose) + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + case consts.HeaderContentLength: + h.mulHeader = append(h.mulHeader, h.contentLengthBytes) + case consts.HeaderCookie: + if h.cookiesCollected { + h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies)) + } else { + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + default: + h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) + } + return h.mulHeader +} + // SetContentTypeBytes sets Content-Type header value. func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) { h.contentType = append(h.contentType[:0], contentType...) @@ -1299,6 +1390,7 @@ func (h *RequestHeader) ResetSkipNormalize() { h.cookiesCollected = false h.rawHeaders = h.rawHeaders[:0] + h.mulHeader = h.mulHeader[:0] } func peekRawHeader(buf, key []byte) []byte { @@ -1482,6 +1574,28 @@ func (h *ResponseHeader) Get(key string) string { return string(h.Peek(key)) } +// GetAll returns all header value for the given key +// it is concurrent safety and long lifetime. +func (h *RequestHeader) GetAll(key string) []string { + res := make([]string, 0) + headers := h.PeekAll(key) + for _, header := range headers { + res = append(res, string(header)) + } + return res +} + +// GetAll returns all header value for the given key and is concurrent safety. +// it is concurrent safety and long lifetime. +func (h *ResponseHeader) GetAll(key string) []string { + res := make([]string, 0) + headers := h.PeekAll(key) + for _, header := range headers { + res = append(res, string(header)) + } + return res +} + func appendHeaderLine(dst, key, value []byte) []byte { dst = append(dst, key...) dst = append(dst, bytestr.StrColonSpace...) diff --git a/pkg/protocol/header_test.go b/pkg/protocol/header_test.go index 0a7baf2fd..fc58a29c0 100644 --- a/pkg/protocol/header_test.go +++ b/pkg/protocol/header_test.go @@ -572,3 +572,72 @@ func TestRequestHeaderSetNoDefaultContentType(t *testing.T) { b = h.AppendBytes(nil) assert.DeepEqual(t, b, []byte("POST / HTTP/1.1\r\n\r\n")) } + +func TestRequestHeader_PeekAll(t *testing.T) { + t.Parallel() + h := &RequestHeader{} + h.Add(consts.HeaderConnection, "keep-alive") + h.Add("Content-Type", "aaa") + h.Add(consts.HeaderHost, "aaabbb") + h.Add("User-Agent", "asdfas") + h.Add("Content-Length", "1123") + h.Add("Cookie", "foobar=baz") + h.Add("aaa", "aaa") + h.Add("aaa", "bbb") + + expectRequestHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("keep-alive")}) + expectRequestHeaderAll(t, h, "Content-Type", [][]byte{[]byte("aaa")}) + expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{[]byte("aaabbb")}) + expectRequestHeaderAll(t, h, "User-Agent", [][]byte{[]byte("asdfas")}) + expectRequestHeaderAll(t, h, "Content-Length", [][]byte{[]byte("1123")}) + expectRequestHeaderAll(t, h, "Cookie", [][]byte{[]byte("foobar=baz")}) + expectRequestHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) + + h.DelBytes([]byte("Content-Type")) + h.DelBytes([]byte((consts.HeaderHost))) + h.DelBytes([]byte("aaa")) + expectRequestHeaderAll(t, h, "Content-Type", [][]byte{}) + expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{}) + expectRequestHeaderAll(t, h, "aaa", [][]byte{}) +} + +func expectRequestHeaderAll(t *testing.T, h *RequestHeader, key string, expectedValue [][]byte) { + if len(h.PeekAll(key)) != len(expectedValue) { + t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) + } + assert.DeepEqual(t, h.PeekAll(key), expectedValue) +} + +func TestResponseHeader_PeekAll(t *testing.T) { + t.Parallel() + + h := &ResponseHeader{} + h.Add(consts.HeaderContentType, "aaa/bbb") + h.Add(consts.HeaderContentEncoding, "gzip") + h.Add(consts.HeaderConnection, "close") + h.Add(consts.HeaderContentLength, "1234") + h.Add(consts.HeaderServer, "aaaa") + h.Add(consts.HeaderSetCookie, "cccc") + h.Add("aaa", "aaa") + h.Add("aaa", "bbb") + + expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{[]byte("aaa/bbb")}) + expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{[]byte("gzip")}) + expectResponseHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("close")}) + expectResponseHeaderAll(t, h, consts.HeaderContentLength, [][]byte{[]byte("1234")}) + expectResponseHeaderAll(t, h, consts.HeaderServer, [][]byte{[]byte("aaaa")}) + expectResponseHeaderAll(t, h, consts.HeaderSetCookie, [][]byte{[]byte("cccc")}) + expectResponseHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) + + h.Del(consts.HeaderContentType) + h.Del(consts.HeaderContentEncoding) + expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{bytestr.DefaultContentType}) + expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{}) +} + +func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expectedValue [][]byte) { + if len(h.PeekAll(key)) != len(expectedValue) { + t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) + } + assert.DeepEqual(t, h.PeekAll(key), expectedValue) +}