Skip to content

Commit

Permalink
caddyhttp: Refactor header matching
Browse files Browse the repository at this point in the history
This allows response matchers to benefit from the same matching logic
as the request header matchers (mainly prefix/suffix wildcards).
  • Loading branch information
mholt committed May 26, 2020
1 parent 294910c commit e5bbed1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 45 deletions.
73 changes: 28 additions & 45 deletions modules/caddyhttp/matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package caddyhttp
import (
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"net/textproto"
Expand All @@ -28,6 +27,7 @@ import (

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"go.uber.org/zap"
)

type (
Expand Down Expand Up @@ -105,7 +105,8 @@ type (
MatchRemoteIP struct {
Ranges []string `json:"ranges,omitempty"`

cidrs []*net.IPNet
cidrs []*net.IPNet
logger *zap.Logger
}

// MatchNot matches requests by negating the results of its matcher
Expand Down Expand Up @@ -410,23 +411,28 @@ func (m *MatchHeader) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}

// Like req.Header.Get(), but that works with Host header.
// go's http module swallows "Host" header.
func getHeader(r *http.Request, field string) []string {
field = textproto.CanonicalMIMEHeaderKey(field)
// Match returns true if r matches m.
func (m MatchHeader) Match(r *http.Request) bool {
return matchHeaders(r.Header, http.Header(m), r.Host)
}

if field == "Host" {
return []string{r.Host}
// getHeaderFieldVals returns the field values for the given fieldName from input.
// The host parameter should be obtained from the http.Request.Host field since
// net/http removes it from the header map.
func getHeaderFieldVals(input http.Header, fieldName, host string) []string {
fieldName = textproto.CanonicalMIMEHeaderKey(fieldName)
if fieldName == "Host" && host != "" {
return []string{host}
}

return r.Header[field]
return input[fieldName]
}

// Match returns true if r matches m.
func (m MatchHeader) Match(r *http.Request) bool {
for field, allowedFieldVals := range m {
actualFieldVals := getHeader(r, field)

// matchHeaders returns true if input matches the criteria in against without regex.
// The host parameter should be obtained from the http.Request.Host field since
// net/http removes it from the header map.
func matchHeaders(input, against http.Header, host string) bool {
for field, allowedFieldVals := range against {
actualFieldVals := getHeaderFieldVals(input, field, host)
if allowedFieldVals != nil && len(allowedFieldVals) == 0 && actualFieldVals != nil {
// a non-nil but empty list of allowed values means
// match if the header field exists at all
Expand Down Expand Up @@ -501,8 +507,7 @@ func (m *MatchHeaderRE) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// Match returns true if r matches m.
func (m MatchHeaderRE) Match(r *http.Request) bool {
for field, rm := range m {
actualFieldVals := getHeader(r, field)

actualFieldVals := getHeaderFieldVals(r.Header, field, r.Host)
match := false
fieldVal:
for _, actualFieldVal := range actualFieldVals {
Expand Down Expand Up @@ -700,6 +705,7 @@ func (m *MatchRemoteIP) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {

// Provision parses m's IP ranges, either from IP or CIDR expressions.
func (m *MatchRemoteIP) Provision(ctx caddy.Context) error {
m.logger = ctx.Logger(m)
for _, str := range m.Ranges {
if strings.Contains(str, "/") {
_, ipNet, err := net.ParseCIDR(str)
Expand Down Expand Up @@ -748,7 +754,7 @@ func (m MatchRemoteIP) getClientIP(r *http.Request) (net.IP, error) {
func (m MatchRemoteIP) Match(r *http.Request) bool {
clientIP, err := m.getClientIP(r)
if err != nil {
log.Printf("[ERROR] remote_ip matcher: %v", err)
m.logger.Error("getting client IP", zap.Error(err))
return false
}
for _, ipRange := range m.cidrs {
Expand Down Expand Up @@ -859,7 +865,9 @@ type ResponseMatcher struct {
// in that class (e.g. 3 for all 3xx codes).
StatusCode []int `json:"status_code,omitempty"`

// If set, each header specified must be one of the specified values.
// If set, each header specified must be one of the
// specified values, with the same logic used by the
// request header matcher.
Headers http.Header `json:"headers,omitempty"`
}

Expand All @@ -868,7 +876,7 @@ func (rm ResponseMatcher) Match(statusCode int, hdr http.Header) bool {
if !rm.matchStatusCode(statusCode) {
return false
}
return rm.matchHeaders(hdr)
return matchHeaders(hdr, rm.Headers, "")
}

func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
Expand All @@ -883,31 +891,6 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
return false
}

func (rm ResponseMatcher) matchHeaders(hdr http.Header) bool {
for field, allowedFieldVals := range rm.Headers {
actualFieldVals, fieldExists := hdr[textproto.CanonicalMIMEHeaderKey(field)]
if allowedFieldVals != nil && len(allowedFieldVals) == 0 && fieldExists {
// a non-nil but empty list of allowed values means
// match if the header field exists at all
continue
}
var match bool
fieldVals:
for _, actualFieldVal := range actualFieldVals {
for _, allowedFieldVal := range allowedFieldVals {
if actualFieldVal == allowedFieldVal {
match = true
break fieldVals
}
}
}
if !match {
return false
}
}
return true
}

var wordRE = regexp.MustCompile(`\w+`)

const regexpPlaceholderPrefix = "http.regexp"
Expand Down
33 changes: 33 additions & 0 deletions modules/caddyhttp/matchers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,21 @@ func TestHeaderMatcher(t *testing.T) {
input: http.Header{"Field2": []string{"foo"}},
expect: false,
},
{
match: MatchHeader{"Field1": []string{"foo*"}},
input: http.Header{"Field1": []string{"foo"}},
expect: true,
},
{
match: MatchHeader{"Field1": []string{"foo*"}},
input: http.Header{"Field1": []string{"asdf", "foobar"}},
expect: true,
},
{
match: MatchHeader{"Field1": []string{"*bar"}},
input: http.Header{"Field1": []string{"asdf", "foobar"}},
expect: true,
},
{
match: MatchHeader{"host": []string{"localhost"}},
input: http.Header{},
Expand Down Expand Up @@ -814,6 +829,24 @@ func TestResponseMatcher(t *testing.T) {
hdr: http.Header{"Foo": []string{"bar"}, "Foo2": []string{"baz"}},
expect: true,
},
{
require: ResponseMatcher{
Headers: http.Header{
"Foo": []string{"foo*"},
},
},
hdr: http.Header{"Foo": []string{"foobar"}},
expect: true,
},
{
require: ResponseMatcher{
Headers: http.Header{
"Foo": []string{"foo*"},
},
},
hdr: http.Header{"Foo": []string{"foobar"}},
expect: true,
},
} {
actual := tc.require.Match(tc.status, tc.hdr)
if actual != tc.expect {
Expand Down

0 comments on commit e5bbed1

Please sign in to comment.