Skip to content

Commit

Permalink
[CLOUDTRUST-4180] Use Forwarded header instead of X-Forwarded-Proto (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
basbeu authored Jul 14, 2022
1 parent 5342d72 commit bc443a4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
16 changes: 12 additions & 4 deletions http/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"net/http"
"regexp"
"strings"

errorhandler "github.com/cloudtrust/common-service/v2/errors"
"github.com/cloudtrust/common-service/v2/log"
Expand All @@ -30,6 +31,8 @@ type GenericResponse struct {
JSONableResponse interface{}
}

var protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`)

// WriteResponse writes a response for a mime content type
func (r *GenericResponse) WriteResponse(w http.ResponseWriter) {
if r.Headers == nil {
Expand Down Expand Up @@ -121,10 +124,15 @@ func DecodeRequestWithHeaders(_ context.Context, req *http.Request, pathParams m
}

func getScheme(req *http.Request) string {
var xForwardedProtoHeader = req.Header.Get("X-Forwarded-Proto")

if xForwardedProtoHeader != "" {
return xForwardedProtoHeader
var forwardedHeader = req.Header.Get("Forwarded")

if forwardedHeader != "" {
// match should contain at least two elements if the protocol was specified in the Forwarded header.
// The first match (match[0]) will always be the 'proto=' capture, which we ignore.
// In the case of multiple proto parameters we only extract the first.
if match := protoRegex.FindStringSubmatch(forwardedHeader); len(match) > 1 {
return strings.ToLower(match[1])
}
}

if req.TLS == nil {
Expand Down
16 changes: 8 additions & 8 deletions http/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,20 @@ func checkInvalidPathParameter(t *testing.T, url string) {
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
}

func genericTestDecodeRequest(ctx context.Context, tls *tls.ConnectionState, xFwdProto *string, rawQuery string) (map[string]string, error) {
return genericTestDecodeRequestWithHeader(ctx, tls, xFwdProto, rawQuery, nil)
func genericTestDecodeRequest(ctx context.Context, tls *tls.ConnectionState, forwarded *string, rawQuery string) (map[string]string, error) {
return genericTestDecodeRequestWithHeader(ctx, tls, forwarded, rawQuery, nil)
}

func genericTestDecodeRequestWithHeader(ctx context.Context, tls *tls.ConnectionState, xFwdProto *string, rawQuery string, headers map[string]string) (map[string]string, error) {
func genericTestDecodeRequestWithHeader(ctx context.Context, tls *tls.ConnectionState, forwarded *string, rawQuery string, headers map[string]string) (map[string]string, error) {
input := "the body"
var req http.Request
var url url.URL
url.RawQuery = rawQuery
req.Host = "localhost"
req.TLS = tls
req.Header = make(http.Header)
if xFwdProto != nil {
req.Header.Set("X-Forwarded-Proto", *xFwdProto)
if forwarded != nil {
req.Header.Set("Forwarded", *forwarded)
}
req.Body = ioutil.NopCloser(bytes.NewBufferString(input))
req.URL = &url
Expand Down Expand Up @@ -247,14 +247,14 @@ func TestDecodeRequestHTTPS(t *testing.T) {
}

func TestDecodeRequestForwardProto(t *testing.T) {
proto := "ftp"
forwardedHeader := "for=192.0.2.60;proto=http;by=203.0.113.43"

request, _ := genericTestDecodeRequest(context.Background(), nil, &proto, "")
request, _ := genericTestDecodeRequest(context.Background(), nil, &forwardedHeader, "")

// Minimum parameters are scheme, host and body
assert.Equal(t, 3, len(request))
assert.Equal(t, "localhost", request["host"])
assert.Equal(t, proto, request["scheme"])
assert.Equal(t, "http", request["scheme"])
assert.Equal(t, "the body", request["body"])
}

Expand Down

0 comments on commit bc443a4

Please sign in to comment.