diff --git a/header.go b/header.go index e13f6852db..8e8f7d009f 100644 --- a/header.go +++ b/header.go @@ -11,6 +11,11 @@ import ( "time" ) +const ( + rChar = byte('\r') + nChar = byte('\n') +) + // ResponseHeader represents HTTP response header. // // It is forbidden copying ResponseHeader instances. @@ -1419,7 +1424,7 @@ func bufferSnippet(b []byte) string { func isOnlyCRLF(b []byte) bool { for _, ch := range b { - if ch != '\r' && ch != '\n' { + if ch != rChar && ch != nChar { return false } } @@ -1731,7 +1736,7 @@ func peekRawHeader(buf, key []byte) []byte { if n < 0 { return nil } - if n > 0 && buf[n-1] != '\n' { + if n > 0 && buf[n-1] != nChar { return nil } n += len(key) @@ -1747,22 +1752,22 @@ func peekRawHeader(buf, key []byte) []byte { } n++ buf = buf[n:] - n = bytes.IndexByte(buf, '\n') + n = bytes.IndexByte(buf, nChar) if n < 0 { return nil } - if n > 0 && buf[n-1] == '\r' { + if n > 0 && buf[n-1] == rChar { n-- } return buf[:n] } func readRawHeaders(dst, buf []byte) ([]byte, int, error) { - n := bytes.IndexByte(buf, '\n') + n := bytes.IndexByte(buf, nChar) if n < 0 { return dst[:0], 0, errNeedMore } - if (n == 1 && buf[0] == '\r') || n == 0 { + if (n == 1 && buf[0] == rChar) || n == 0 { // empty headers return dst, n + 1, nil } @@ -1772,13 +1777,13 @@ func readRawHeaders(dst, buf []byte) ([]byte, int, error) { m := n for { b = b[m:] - m = bytes.IndexByte(b, '\n') + m = bytes.IndexByte(b, nChar) if m < 0 { return dst, 0, errNeedMore } m++ n += m - if (m == 2 && b[0] == '\r') || m == 1 { + if (m == 2 && b[0] == rChar) || m == 1 { dst = append(dst, buf[:n]...) return dst, n, nil } @@ -2011,12 +2016,12 @@ func (s *headerScanner) next() bool { s.initialized = true } bLen := len(s.b) - if bLen >= 2 && s.b[0] == '\r' && s.b[1] == '\n' { + if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar { s.b = s.b[2:] s.hLen += 2 return false } - if bLen >= 1 && s.b[0] == '\n' { + if bLen >= 1 && s.b[0] == nChar { s.b = s.b[1:] s.hLen++ return false @@ -2029,7 +2034,7 @@ func (s *headerScanner) next() bool { n = bytes.IndexByte(s.b, ':') // There can't be a \n inside the header name, check for this. - x := bytes.IndexByte(s.b, '\n') + x := bytes.IndexByte(s.b, nChar) if x < 0 { // A header name should always at some point be followed by a \n // even if it's the one that terminates the header block. @@ -2062,7 +2067,7 @@ func (s *headerScanner) next() bool { n = s.nextNewLine s.nextNewLine = -1 } else { - n = bytes.IndexByte(s.b, '\n') + n = bytes.IndexByte(s.b, nChar) } if n < 0 { s.err = errNeedMore @@ -2076,10 +2081,10 @@ func (s *headerScanner) next() bool { if s.b[n+1] != ' ' && s.b[n+1] != '\t' { break } - d := bytes.IndexByte(s.b[n+1:], '\n') + d := bytes.IndexByte(s.b[n+1:], nChar) if d <= 0 { break - } else if d == 1 && s.b[n+1] == '\r' { + } else if d == 1 && s.b[n+1] == rChar { break } e := n + d + 1 @@ -2100,7 +2105,7 @@ func (s *headerScanner) next() bool { s.hLen += n + 1 s.b = s.b[n+1:] - if n > 0 && s.value[n-1] == '\r' { + if n > 0 && s.value[n-1] == rChar { n-- } for n > 0 && s.value[n-1] == ' ' { @@ -2156,12 +2161,12 @@ func hasHeaderValue(s, value []byte) bool { } func nextLine(b []byte) ([]byte, []byte, error) { - nNext := bytes.IndexByte(b, '\n') + nNext := bytes.IndexByte(b, nChar) if nNext < 0 { return nil, nil, errNeedMore } n := nNext - if n > 0 && b[n-1] == '\r' { + if n > 0 && b[n-1] == rChar { n-- } return b[:n], b[nNext+1:], nil @@ -2169,7 +2174,9 @@ func nextLine(b []byte) ([]byte, []byte, error) { func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) { kv.key = getHeaderKeyBytes(kv, key, disableNormalizing) + // https://tools.ietf.org/html/rfc7230#section-3.2.4 kv.value = append(kv.value[:0], value...) + kv.value = removeNewLines(kv.value) } func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte { @@ -2189,9 +2196,9 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i lineStart := false for read := 0; read < length; read++ { c := ov[read] - if c == '\r' || c == '\n' { + if c == rChar || c == nChar { shrunk++ - if c == '\n' { + if c == nChar { lineStart = true } continue @@ -2209,13 +2216,13 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i // Check if we need to skip \r\n or just \n skip := 0 - if ob[write] == '\r' { - if ob[write+1] == '\n' { + if ob[write] == rChar { + if ob[write+1] == nChar { skip += 2 } else { skip++ } - } else if ob[write] == '\n' { + } else if ob[write] == nChar { skip++ } @@ -2248,6 +2255,37 @@ func normalizeHeaderKey(b []byte, disableNormalizing bool) { } } +// removeNewLines will replace `\r` and `\n` with an empty space +func removeNewLines(raw []byte) []byte { + // check if a `\r` is present and save the position. + // if no `\r` is found, check if a `\n` is present. + foundR := bytes.IndexByte(raw, rChar) + foundN := bytes.IndexByte(raw, nChar) + start := 0 + + if foundN != -1 { + if foundR > foundN { + start = foundN + } else if foundR != -1 { + start = foundR + } + } else if foundR != -1 { + start = foundR + } else { + return raw + } + + for i := start; i < len(raw); i++ { + switch raw[i] { + case rChar, nChar: + raw[i] = ' ' + default: + continue + } + } + return raw +} + // AppendNormalizedHeaderKey appends normalized header key (name) to dst // and returns the resulting dst. // diff --git a/header_timing_test.go b/header_timing_test.go index 55f5df9587..8c2f102aff 100644 --- a/header_timing_test.go +++ b/header_timing_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "io" + "strconv" "testing" "github.com/valyala/bytebufferpool" @@ -146,3 +147,34 @@ func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) { } }) } + +func BenchmarkRemoveNewLines(b *testing.B) { + type testcase struct { + value string + expectedValue string + } + + var testcases = []testcase{ + {value: "MaliciousValue", expectedValue: "MaliciousValue"}, + {value: "MaliciousValue\r\n", expectedValue: "MaliciousValue "}, + {value: "Malicious\nValue", expectedValue: "Malicious Value"}, + {value: "Malicious\rValue", expectedValue: "Malicious Value"}, + } + + for i, tcase := range testcases { + caseName := strconv.FormatInt(int64(i), 10) + b.Run(caseName, func(subB *testing.B) { + subB.ReportAllocs() + var h RequestHeader + for i := 0; i < subB.N; i++ { + h.Set("Test", tcase.value) + } + subB.StopTimer() + actualValue := string(h.Peek("Test")) + + if actualValue != tcase.expectedValue { + subB.Errorf("unexpected value, got: %+v", actualValue) + } + }) + } +} diff --git a/http_test.go b/http_test.go index c67f52092b..2f91e345ce 100644 --- a/http_test.go +++ b/http_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "mime/multipart" "reflect" + "strconv" "strings" "testing" "time" @@ -30,6 +31,53 @@ func TestFragmentInURIRequest(t *testing.T) { } } +func TestIssue875(t *testing.T) { + type testcase struct { + uri string + expectedRedirect string + expectedLocation string + } + + var testcases = []testcase{ + { + uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + { + uri: `http://localhost:3000/?redirect=foo%0dSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\rSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + { + uri: `http://localhost:3000/?redirect=foo%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\nSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + } + + for i, tcase := range testcases { + caseName := strconv.FormatInt(int64(i), 10) + t.Run(caseName, func(subT *testing.T) { + ctx := &RequestCtx{ + Request: Request{}, + Response: Response{}, + } + ctx.Request.SetRequestURI(tcase.uri) + + q := string(ctx.QueryArgs().Peek("redirect")) + if q != tcase.expectedRedirect { + subT.Errorf("unexpected redirect query value, got: %+v", q) + } + ctx.Response.Header.Set("Location", q) + + if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) { + subT.Errorf("invalid escaping, got\n%s", ctx.Response.String()) + } + }) + } +} + func TestRequestCopyTo(t *testing.T) { t.Parallel()