Skip to content

Commit

Permalink
fix: handle X-Forwarded-* headers correctly (#4334)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Haines <andrew@haines.org.nz>
  • Loading branch information
haines authored May 15, 2024
1 parent 9660e4a commit 2da4beb
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 15 deletions.
16 changes: 11 additions & 5 deletions runtime/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM
var pairs []string
for key, vals := range req.Header {
key = textproto.CanonicalMIMEHeaderKey(key)
switch key {
case xForwardedFor, xForwardedHost:
// Handled separately below
continue
}

for _, val := range vals {
// For backwards-compatibility, pass through 'authorization' header with no prefix.
if key == "Authorization" {
Expand Down Expand Up @@ -181,15 +187,15 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}

xff := req.Header.Values(xForwardedFor)
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
if fwd := req.Header.Get(xForwardedFor); fwd == "" {
pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
} else {
pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
}
xff = append(xff, remoteIP)
}
}
if len(xff) > 0 {
pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", "))
}

if timeout != 0 {
ctx, _ = context.WithTimeout(ctx, timeout)
Expand Down
102 changes: 92 additions & 10 deletions runtime/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,20 @@ func TestAnnotateContext_ForwardGrpcBinaryMetadata(t *testing.T) {
}
}

func TestAnnotateContext_XForwardedFor(t *testing.T) {
func TestAnnotateContext_AddsXForwardedHeaders(t *testing.T) {
ctx := context.Background()
expectedRPCName := "/example.Example/Example"
request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err)
}
request.Header.Add("X-Forwarded-For", "192.0.2.100") // client
request.RemoteAddr = "192.0.2.200:12345" // proxy
request.RemoteAddr = "192.0.2.100:12345" // client

annotated, err := runtime.AnnotateContext(ctx, runtime.NewServeMux(), request, expectedRPCName)
serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) {
return key, true
}))

annotated, err := runtime.AnnotateContext(ctx, serveMux, request, expectedRPCName)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
Expand All @@ -135,8 +138,46 @@ func TestAnnotateContext_XForwardedFor(t *testing.T) {
if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["host"] = %v; want %v`, got, want)
}
if got, want := md["x-forwarded-for"], []string{"192.0.2.100"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want)
}
if m, ok := runtime.RPCMethod(annotated); !ok {
t.Errorf("runtime.RPCMethod(annotated) failed with no value; want %s", expectedRPCName)
} else if m != expectedRPCName {
t.Errorf("runtime.RPCMethod(annotated) failed with %s; want %s", m, expectedRPCName)
}
}

func TestAnnotateContext_AppendsToExistingXForwardedHeaders(t *testing.T) {
ctx := context.Background()
expectedRPCName := "/example.Example/Example"
request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err)
}
request.Header.Add("X-Forwarded-Host", "qux.example.com")
request.Header.Add("X-Forwarded-For", "192.0.2.100") // client
request.Header.Add("X-Forwarded-For", "192.0.2.101, 192.0.2.102") // intermediate proxies
request.RemoteAddr = "192.0.2.200:12345" // final proxy

serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) {
return key, true
}))

annotated, err := runtime.AnnotateContext(ctx, serveMux, request, expectedRPCName)
if err != nil {
t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err)
return
}
md, ok := metadata.FromOutgoingContext(annotated)
if !ok || len(md) != emptyForwardMetaCount+1 {
t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md)
}
if got, want := md["x-forwarded-host"], []string{"qux.example.com"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["host"] = %v; want %v`, got, want)
}
// Note: it must be in order client, proxy1, proxy2
if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) {
if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.101, 192.0.2.102, 192.0.2.200"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want)
}
if m, ok := runtime.RPCMethod(annotated); !ok {
Expand Down Expand Up @@ -356,17 +397,20 @@ func TestAnnotateIncomingContext_ForwardGrpcBinaryMetadata(t *testing.T) {
}
}

func TestAnnotateIncomingContext_XForwardedFor(t *testing.T) {
func TestAnnotateIncomingContext_AddsXForwardedHeaders(t *testing.T) {
ctx := context.Background()
expectedRPCName := "/example.Example/Example"
request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err)
}
request.Header.Add("X-Forwarded-For", "192.0.2.100") // client
request.RemoteAddr = "192.0.2.200:12345" // proxy
request.RemoteAddr = "192.0.2.100:12345" // client

annotated, err := runtime.AnnotateIncomingContext(ctx, runtime.NewServeMux(), request, expectedRPCName)
serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) {
return key, true
}))

annotated, err := runtime.AnnotateIncomingContext(ctx, serveMux, request, expectedRPCName)
if err != nil {
t.Errorf("runtime.AnnotateIncomingContext(ctx, %#v) failed with %v; want success", request, err)
return
Expand All @@ -378,8 +422,46 @@ func TestAnnotateIncomingContext_XForwardedFor(t *testing.T) {
if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["host"] = %v; want %v`, got, want)
}
if got, want := md["x-forwarded-for"], []string{"192.0.2.100"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want)
}
if m, ok := runtime.RPCMethod(annotated); !ok {
t.Errorf("runtime.RPCMethod(annotated) failed with no value; want %s", expectedRPCName)
} else if m != expectedRPCName {
t.Errorf("runtime.RPCMethod(annotated) failed with %s; want %s", m, expectedRPCName)
}
}

func TestAnnotateIncomingContext_AppendsToExistingXForwardedHeaders(t *testing.T) {
ctx := context.Background()
expectedRPCName := "/example.Example/Example"
request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil)
if err != nil {
t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err)
}
request.Header.Add("X-Forwarded-Host", "qux.example.com")
request.Header.Add("X-Forwarded-For", "192.0.2.100") // client
request.Header.Add("X-Forwarded-For", "192.0.2.101, 192.0.2.102") // intermediate proxies
request.RemoteAddr = "192.0.2.200:12345" // final proxy

serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) {
return key, true
}))

annotated, err := runtime.AnnotateIncomingContext(ctx, serveMux, request, expectedRPCName)
if err != nil {
t.Errorf("runtime.AnnotateIncomingContext(ctx, %#v) failed with %v; want success", request, err)
return
}
md, ok := metadata.FromIncomingContext(annotated)
if !ok || len(md) != emptyForwardMetaCount+1 {
t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md)
}
if got, want := md["x-forwarded-host"], []string{"qux.example.com"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["host"] = %v; want %v`, got, want)
}
// Note: it must be in order client, proxy1, proxy2
if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) {
if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.101, 192.0.2.102, 192.0.2.200"}; !reflect.DeepEqual(got, want) {
t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want)
}
if m, ok := runtime.RPCMethod(annotated); !ok {
Expand Down

0 comments on commit 2da4beb

Please sign in to comment.