Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add PeekAll function #569

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pkg/protocol/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ func peekArgBytes(h []argsKV, k []byte) []byte {
return nil
}

func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if bytes.Equal(kv.key, k) {
dst = append(dst, kv.value)
}
}
return dst
}

func delAllArgsBytes(args []argsKV, key []byte) []argsKV {
return delAllArgs(args, bytesconv.B2s(key))
}
Expand Down
114 changes: 114 additions & 0 deletions pkg/protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type RequestHeader struct {
host []byte
contentType []byte
userAgent []byte
mulHeader [][]byte

h []argsKV
bufKV argsKV
Expand Down Expand Up @@ -121,6 +122,7 @@ type ResponseHeader struct {

contentType []byte
server []byte
mulHeader [][]byte

h []argsKV
bufKV argsKV
Expand Down Expand Up @@ -486,6 +488,7 @@ func (h *ResponseHeader) ResetSkipNormalize() {

h.h = h.h[:0]
h.cookies = h.cookies[:0]
h.mulHeader = h.mulHeader[:0]
}

// ContentLength returns Content-Length header value.
Expand Down Expand Up @@ -687,6 +690,94 @@ func (h *ResponseHeader) peek(key []byte) []byte {
}
}

// PeekAll returns all header value for the given key.
//
// The returned value is valid until the request is released,
// either though ReleaseResponse or your request handler returning.
// Any future calls to the Peek* will modify the returned value.
// Do not store references to returned value. Use ResponseHeader.GetAll(key) instead.
func (h *ResponseHeader) PeekAll(key string) [][]byte {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peekAll(k)
}

func (h *ResponseHeader) peekAll(key []byte) [][]byte {
h.mulHeader = h.mulHeader[:0]
switch string(key) {
case consts.HeaderContentType:
if contentType := h.ContentType(); len(contentType) > 0 {
h.mulHeader = append(h.mulHeader, contentType)
}
case consts.HeaderContentEncoding:
if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 {
h.mulHeader = append(h.mulHeader, contentEncoding)
}
case consts.HeaderServer:
if server := h.Server(); len(server) > 0 {
h.mulHeader = append(h.mulHeader, server)
}
case consts.HeaderConnection:
if h.ConnectionClose() {
h.mulHeader = append(h.mulHeader, bytestr.StrClose)
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
case consts.HeaderContentLength:
h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
case consts.HeaderSetCookie:
h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies))
default:
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
return h.mulHeader
}

// PeekAll returns all header value for the given key.
//
// The returned value is valid until the request is released,
// either though ReleaseRequest or your request handler returning.
// Any future calls to the Peek* will modify the returned value.
// Do not store references to returned value. Use RequestHeader.GetAll(key) instead.
func (h *RequestHeader) PeekAll(key string) [][]byte {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peekAll(k)
}

func (h *RequestHeader) peekAll(key []byte) [][]byte {
h.mulHeader = h.mulHeader[:0]
switch string(key) {
case consts.HeaderHost:
if host := h.Host(); len(host) > 0 {
h.mulHeader = append(h.mulHeader, host)
}
case consts.HeaderContentType:
if contentType := h.ContentType(); len(contentType) > 0 {
h.mulHeader = append(h.mulHeader, contentType)
}
case consts.HeaderUserAgent:
if ua := h.UserAgent(); len(ua) > 0 {
h.mulHeader = append(h.mulHeader, ua)
}
case consts.HeaderConnection:
if h.ConnectionClose() {
h.mulHeader = append(h.mulHeader, bytestr.StrClose)
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
case consts.HeaderContentLength:
h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
case consts.HeaderCookie:
if h.cookiesCollected {
h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies))
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
default:
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
return h.mulHeader
}

// SetContentTypeBytes sets Content-Type header value.
func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
Expand Down Expand Up @@ -1299,6 +1390,7 @@ func (h *RequestHeader) ResetSkipNormalize() {
h.cookiesCollected = false

h.rawHeaders = h.rawHeaders[:0]
h.mulHeader = h.mulHeader[:0]
}

func peekRawHeader(buf, key []byte) []byte {
Expand Down Expand Up @@ -1482,6 +1574,28 @@ func (h *ResponseHeader) Get(key string) string {
return string(h.Peek(key))
}

// GetAll returns all header value for the given key
// it is concurrent safety and long lifetime.
func (h *RequestHeader) GetAll(key string) []string {
res := make([]string, 0)
headers := h.PeekAll(key)
for _, header := range headers {
res = append(res, string(header))
}
return res
}

// GetAll returns all header value for the given key and is concurrent safety.
// it is concurrent safety and long lifetime.
func (h *ResponseHeader) GetAll(key string) []string {
res := make([]string, 0)
headers := h.PeekAll(key)
for _, header := range headers {
res = append(res, string(header))
}
return res
}

func appendHeaderLine(dst, key, value []byte) []byte {
dst = append(dst, key...)
dst = append(dst, bytestr.StrColonSpace...)
Expand Down
69 changes: 69 additions & 0 deletions pkg/protocol/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,3 +572,72 @@ func TestRequestHeaderSetNoDefaultContentType(t *testing.T) {
b = h.AppendBytes(nil)
assert.DeepEqual(t, b, []byte("POST / HTTP/1.1\r\n\r\n"))
}

func TestRequestHeader_PeekAll(t *testing.T) {
t.Parallel()
h := &RequestHeader{}
h.Add(consts.HeaderConnection, "keep-alive")
h.Add("Content-Type", "aaa")
h.Add(consts.HeaderHost, "aaabbb")
h.Add("User-Agent", "asdfas")
h.Add("Content-Length", "1123")
h.Add("Cookie", "foobar=baz")
h.Add("aaa", "aaa")
h.Add("aaa", "bbb")

expectRequestHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("keep-alive")})
expectRequestHeaderAll(t, h, "Content-Type", [][]byte{[]byte("aaa")})
expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{[]byte("aaabbb")})
expectRequestHeaderAll(t, h, "User-Agent", [][]byte{[]byte("asdfas")})
expectRequestHeaderAll(t, h, "Content-Length", [][]byte{[]byte("1123")})
expectRequestHeaderAll(t, h, "Cookie", [][]byte{[]byte("foobar=baz")})
expectRequestHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")})

h.DelBytes([]byte("Content-Type"))
h.DelBytes([]byte((consts.HeaderHost)))
h.DelBytes([]byte("aaa"))
expectRequestHeaderAll(t, h, "Content-Type", [][]byte{})
expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{})
expectRequestHeaderAll(t, h, "aaa", [][]byte{})
}

func expectRequestHeaderAll(t *testing.T, h *RequestHeader, key string, expectedValue [][]byte) {
if len(h.PeekAll(key)) != len(expectedValue) {
t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue))
}
assert.DeepEqual(t, h.PeekAll(key), expectedValue)
}

func TestResponseHeader_PeekAll(t *testing.T) {
t.Parallel()

h := &ResponseHeader{}
h.Add(consts.HeaderContentType, "aaa/bbb")
h.Add(consts.HeaderContentEncoding, "gzip")
h.Add(consts.HeaderConnection, "close")
h.Add(consts.HeaderContentLength, "1234")
h.Add(consts.HeaderServer, "aaaa")
h.Add(consts.HeaderSetCookie, "cccc")
h.Add("aaa", "aaa")
h.Add("aaa", "bbb")

expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{[]byte("aaa/bbb")})
expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{[]byte("gzip")})
expectResponseHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("close")})
expectResponseHeaderAll(t, h, consts.HeaderContentLength, [][]byte{[]byte("1234")})
expectResponseHeaderAll(t, h, consts.HeaderServer, [][]byte{[]byte("aaaa")})
expectResponseHeaderAll(t, h, consts.HeaderSetCookie, [][]byte{[]byte("cccc")})
expectResponseHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")})

h.Del(consts.HeaderContentType)
h.Del(consts.HeaderContentEncoding)
expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{bytestr.DefaultContentType})
expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{})
}

func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expectedValue [][]byte) {
if len(h.PeekAll(key)) != len(expectedValue) {
t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue))
}
assert.DeepEqual(t, h.PeekAll(key), expectedValue)
}