Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue #875 #909

Merged
merged 9 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
"time"
)

const (
rChar = byte('\r')
nChar = byte('\n')
)

// ResponseHeader represents HTTP response header.
//
// It is forbidden copying ResponseHeader instances.
Expand Down Expand Up @@ -1419,7 +1424,7 @@ func bufferSnippet(b []byte) string {

func isOnlyCRLF(b []byte) bool {
for _, ch := range b {
if ch != '\r' && ch != '\n' {
if ch != rChar && ch != nChar {
return false
}
}
Expand Down Expand Up @@ -1731,7 +1736,7 @@ func peekRawHeader(buf, key []byte) []byte {
if n < 0 {
return nil
}
if n > 0 && buf[n-1] != '\n' {
if n > 0 && buf[n-1] != nChar {
return nil
}
n += len(key)
Expand All @@ -1747,22 +1752,22 @@ func peekRawHeader(buf, key []byte) []byte {
}
n++
buf = buf[n:]
n = bytes.IndexByte(buf, '\n')
n = bytes.IndexByte(buf, nChar)
if n < 0 {
return nil
}
if n > 0 && buf[n-1] == '\r' {
if n > 0 && buf[n-1] == rChar {
n--
}
return buf[:n]
}

func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
n := bytes.IndexByte(buf, '\n')
n := bytes.IndexByte(buf, nChar)
if n < 0 {
return dst[:0], 0, errNeedMore
}
if (n == 1 && buf[0] == '\r') || n == 0 {
if (n == 1 && buf[0] == rChar) || n == 0 {
// empty headers
return dst, n + 1, nil
}
Expand All @@ -1772,13 +1777,13 @@ func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
m := n
for {
b = b[m:]
m = bytes.IndexByte(b, '\n')
m = bytes.IndexByte(b, nChar)
if m < 0 {
return dst, 0, errNeedMore
}
m++
n += m
if (m == 2 && b[0] == '\r') || m == 1 {
if (m == 2 && b[0] == rChar) || m == 1 {
dst = append(dst, buf[:n]...)
return dst, n, nil
}
Expand Down Expand Up @@ -2011,12 +2016,12 @@ func (s *headerScanner) next() bool {
s.initialized = true
}
bLen := len(s.b)
if bLen >= 2 && s.b[0] == '\r' && s.b[1] == '\n' {
if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar {
s.b = s.b[2:]
s.hLen += 2
return false
}
if bLen >= 1 && s.b[0] == '\n' {
if bLen >= 1 && s.b[0] == nChar {
s.b = s.b[1:]
s.hLen++
return false
Expand All @@ -2029,7 +2034,7 @@ func (s *headerScanner) next() bool {
n = bytes.IndexByte(s.b, ':')

// There can't be a \n inside the header name, check for this.
x := bytes.IndexByte(s.b, '\n')
x := bytes.IndexByte(s.b, nChar)
if x < 0 {
// A header name should always at some point be followed by a \n
// even if it's the one that terminates the header block.
Expand Down Expand Up @@ -2062,7 +2067,7 @@ func (s *headerScanner) next() bool {
n = s.nextNewLine
s.nextNewLine = -1
} else {
n = bytes.IndexByte(s.b, '\n')
n = bytes.IndexByte(s.b, nChar)
}
if n < 0 {
s.err = errNeedMore
Expand All @@ -2076,10 +2081,10 @@ func (s *headerScanner) next() bool {
if s.b[n+1] != ' ' && s.b[n+1] != '\t' {
break
}
d := bytes.IndexByte(s.b[n+1:], '\n')
d := bytes.IndexByte(s.b[n+1:], nChar)
if d <= 0 {
break
} else if d == 1 && s.b[n+1] == '\r' {
} else if d == 1 && s.b[n+1] == rChar {
break
}
e := n + d + 1
Expand All @@ -2100,7 +2105,7 @@ func (s *headerScanner) next() bool {
s.hLen += n + 1
s.b = s.b[n+1:]

if n > 0 && s.value[n-1] == '\r' {
if n > 0 && s.value[n-1] == rChar {
n--
}
for n > 0 && s.value[n-1] == ' ' {
Expand Down Expand Up @@ -2156,20 +2161,22 @@ func hasHeaderValue(s, value []byte) bool {
}

func nextLine(b []byte) ([]byte, []byte, error) {
nNext := bytes.IndexByte(b, '\n')
nNext := bytes.IndexByte(b, nChar)
if nNext < 0 {
return nil, nil, errNeedMore
}
n := nNext
if n > 0 && b[n-1] == '\r' {
if n > 0 && b[n-1] == rChar {
n--
}
return b[:n], b[nNext+1:], nil
}

func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) {
kv.key = getHeaderKeyBytes(kv, key, disableNormalizing)
// https://tools.ietf.org/html/rfc7230#section-3.2.4
kv.value = append(kv.value[:0], value...)
kv.value = removeNewLines(kv.value)
}

func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte {
Expand All @@ -2189,9 +2196,9 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i
lineStart := false
for read := 0; read < length; read++ {
c := ov[read]
if c == '\r' || c == '\n' {
if c == rChar || c == nChar {
shrunk++
if c == '\n' {
if c == nChar {
lineStart = true
}
continue
Expand All @@ -2209,13 +2216,13 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i

// Check if we need to skip \r\n or just \n
skip := 0
if ob[write] == '\r' {
if ob[write+1] == '\n' {
if ob[write] == rChar {
if ob[write+1] == nChar {
skip += 2
} else {
skip++
}
} else if ob[write] == '\n' {
} else if ob[write] == nChar {
skip++
}

Expand Down Expand Up @@ -2248,6 +2255,37 @@ func normalizeHeaderKey(b []byte, disableNormalizing bool) {
}
}

// removeNewLines will replace `\r` and `\n` with an empty space
func removeNewLines(raw []byte) []byte {
// check if a `\r` is present and save the position.
// if no `\r` is found, check if a `\n` is present.
foundR := bytes.IndexByte(raw, rChar)
foundN := bytes.IndexByte(raw, nChar)
start := 0

if foundN != -1 {
if foundR > foundN {
start = foundN
} else if foundR != -1 {
start = foundR
}
} else if foundR != -1 {
start = foundR
} else {
return raw
}

for i := start; i < len(raw); i++ {
switch raw[i] {
case rChar, nChar:
raw[i] = ' '
default:
continue
}
}
return raw
}

// AppendNormalizedHeaderKey appends normalized header key (name) to dst
// and returns the resulting dst.
//
Expand Down
32 changes: 32 additions & 0 deletions header_timing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"io"
"strconv"
"testing"

"github.com/valyala/bytebufferpool"
Expand Down Expand Up @@ -146,3 +147,34 @@ func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) {
}
})
}

func BenchmarkRemoveNewLines(b *testing.B) {
type testcase struct {
value string
expectedValue string
}

var testcases = []testcase{
{value: "MaliciousValue", expectedValue: "MaliciousValue"},
{value: "MaliciousValue\r\n", expectedValue: "MaliciousValue "},
{value: "Malicious\nValue", expectedValue: "Malicious Value"},
{value: "Malicious\rValue", expectedValue: "Malicious Value"},
}

for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
b.Run(caseName, func(subB *testing.B) {
subB.ReportAllocs()
var h RequestHeader
for i := 0; i < subB.N; i++ {
h.Set("Test", tcase.value)
}
subB.StopTimer()
actualValue := string(h.Peek("Test"))

if actualValue != tcase.expectedValue {
subB.Errorf("unexpected value, got: %+v", actualValue)
}
})
}
}
48 changes: 48 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/ioutil"
"mime/multipart"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -30,6 +31,53 @@ func TestFragmentInURIRequest(t *testing.T) {
}
}

func TestIssue875(t *testing.T) {
type testcase struct {
uri string
expectedRedirect string
expectedLocation string
}

var testcases = []testcase{
{
uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0dSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\rSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
}

for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
t.Run(caseName, func(subT *testing.T) {
ctx := &RequestCtx{
Request: Request{},
Response: Response{},
}
ctx.Request.SetRequestURI(tcase.uri)

q := string(ctx.QueryArgs().Peek("redirect"))
if q != tcase.expectedRedirect {
subT.Errorf("unexpected redirect query value, got: %+v", q)
}
ctx.Response.Header.Set("Location", q)

if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) {
subT.Errorf("invalid escaping, got\n%s", ctx.Response.String())
}
})
}
}

func TestRequestCopyTo(t *testing.T) {
t.Parallel()

Expand Down