Skip to content

Commit

Permalink
Merge pull request #47 from xmidt-org/url-parsing
Browse files Browse the repository at this point in the history
Added a customizable function for parsing the request URL
  • Loading branch information
johnabass authored Nov 21, 2019
2 parents 842d3d1 + 06b0783 commit 0b3d3fb
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 7 deletions.
47 changes: 46 additions & 1 deletion basculehttp/constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/textproto"
"net/url"
"strings"

"github.com/go-kit/kit/log"
Expand All @@ -24,11 +25,33 @@ const (
DefaultHeaderDelimiter = " "
)

// ParseURL is a function that modifies the url given then returns it.
type ParseURL func(*url.URL) (*url.URL, error)

// DefaultParseURLFunc does nothing. It returns the same url it received.
func DefaultParseURLFunc(u *url.URL) (*url.URL, error) {
return u, nil
}

// CreateRemovePrefixURLFunc parses the URL by removing the prefix specified.
func CreateRemovePrefixURLFunc(prefix string, next ParseURL) ParseURL {
return func(u *url.URL) (*url.URL, error) {
escapedPath := u.EscapedPath()
if !strings.HasPrefix(escapedPath, prefix) {
return nil, errors.New("unexpected URL, did not start with expected prefix")
}
u.Path = escapedPath[len(prefix):]
u.RawPath = escapedPath[len(prefix):]
return next(u)
}
}

type constructor struct {
headerName string
headerDelimiter string
authorizations map[bascule.Authorization]TokenFactory
getLogger func(context.Context) bascule.Logger
parseURL ParseURL
onErrorResponse OnErrorResponse
}

Expand All @@ -38,6 +61,16 @@ func (c *constructor) decorate(next http.Handler) http.Handler {
if logger == nil {
logger = bascule.GetDefaultLoggerFunc(request.Context())
}

// copy the URL before modifying it
urlVal := *request.URL
u, err := c.parseURL(&urlVal)
if err != nil {
c.error(logger, GetURLFailed, "", emperror.WrapWith(err, "failed to get URL", "URL", request.URL))
WriteResponse(response, http.StatusForbidden, err)
return
}

authorization := request.Header.Get(c.headerName)
if len(authorization) == 0 {
err := errors.New("no authorization header")
Expand Down Expand Up @@ -80,7 +113,7 @@ func (c *constructor) decorate(next http.Handler) http.Handler {
Authorization: key,
Token: token,
Request: bascule.Request{
URL: request.URL.EscapedPath(),
URL: u,
Method: request.Method,
},
},
Expand Down Expand Up @@ -111,6 +144,7 @@ func WithHeaderName(headerName string) COption {
}
}

// WithHeaderDelimiter sets the value expected between the authorization key and token.
func WithHeaderDelimiter(delimiter string) COption {
return func(c *constructor) {
if len(delimiter) > 0 {
Expand All @@ -134,6 +168,16 @@ func WithCLogger(getLogger func(context.Context) bascule.Logger) COption {
}
}

// WithParseURLFunc sets the function to use to make any changes to the URL
// before it is added to the context.
func WithParseURLFunc(parseURL ParseURL) COption {
return func(c *constructor) {
if parseURL != nil {
c.parseURL = parseURL
}
}
}

// WithCErrorResponseFunc sets the function that is called when an error occurs.
func WithCErrorResponseFunc(f OnErrorResponse) COption {
return func(c *constructor) {
Expand All @@ -150,6 +194,7 @@ func NewConstructor(options ...COption) func(http.Handler) http.Handler {
headerDelimiter: DefaultHeaderDelimiter,
authorizations: make(map[bascule.Authorization]TokenFactory),
getLogger: bascule.GetDefaultLoggerFunc,
parseURL: DefaultParseURLFunc,
onErrorResponse: DefaultOnErrorResponse,
}

Expand Down
16 changes: 15 additions & 1 deletion basculehttp/constructor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ import (
func TestConstructor(t *testing.T) {
testHeader := "test header"
testDelimiter := "="

c := NewConstructor(
WithHeaderName(testHeader),
WithHeaderDelimiter(testDelimiter),
WithTokenFactory("Basic", BasicTokenFactory{"codex": "codex"}),
WithCLogger(func(_ context.Context) bascule.Logger {
return bascule.Logger(log.NewJSONLogger(log.NewSyncWriter(os.Stdout)))
}),
WithParseURLFunc(CreateRemovePrefixURLFunc("/test", DefaultParseURLFunc)),
WithCErrorResponseFunc(DefaultOnErrorResponse),
)
c2 := NewConstructor(
Expand All @@ -35,41 +37,53 @@ func TestConstructor(t *testing.T) {
requestHeaderKey string
requestHeaderValue string
expectedStatusCode int
endpoint string
}{
{
description: "Success",
constructor: c,
requestHeaderKey: testHeader,
requestHeaderValue: "Basic=Y29kZXg6Y29kZXg=",
expectedStatusCode: http.StatusOK,
endpoint: "/test",
},
{
description: "URL Parsing Error",
constructor: c,
endpoint: "/blah",
expectedStatusCode: http.StatusForbidden,
},
{
description: "No Authorization Header Error",
constructor: c2,
requestHeaderKey: DefaultHeaderName,
requestHeaderValue: "",
expectedStatusCode: http.StatusForbidden,
endpoint: "/",
},
{
description: "No Space in Auth Header Error",
constructor: c,
requestHeaderKey: testHeader,
requestHeaderValue: "abcd",
expectedStatusCode: http.StatusBadRequest,
endpoint: "/test",
},
{
description: "Key Not Supported Error",
constructor: c2,
requestHeaderKey: DefaultHeaderName,
requestHeaderValue: "abcd ",
expectedStatusCode: http.StatusForbidden,
endpoint: "/test",
},
{
description: "Parse and Validate Error",
constructor: c,
requestHeaderKey: testHeader,
requestHeaderValue: "Basic=AFJDK",
expectedStatusCode: http.StatusForbidden,
endpoint: "/test",
},
}
for _, tc := range tests {
Expand All @@ -78,7 +92,7 @@ func TestConstructor(t *testing.T) {
handler := tc.constructor(next)

writer := httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)
req := httptest.NewRequest("get", tc.endpoint, nil)
req.Header.Add(tc.requestHeaderKey, tc.requestHeaderValue)
handler.ServeHTTP(writer, req)
assert.Equal(tc.expectedStatusCode, writer.Code)
Expand Down
1 change: 1 addition & 0 deletions basculehttp/errorResponseReason.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
InvalidHeader
KeyNotSupported
ParseFailed
GetURLFailed
MissingAuthentication
ChecksNotFound
ChecksFailed
Expand Down
18 changes: 16 additions & 2 deletions basculehttp/errorresponsereason_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion basculehttp/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -29,11 +30,14 @@ func TestListenerDecorator(t *testing.T) {
handler.ServeHTTP(writer, req)
assert.Equal(http.StatusForbidden, writer.Code)

u, err := url.ParseRequestURI("/")
assert.Nil(err)

ctx := bascule.WithAuthentication(context.Background(), bascule.Authentication{
Authorization: "jwt",
Token: bascule.NewToken("", "", bascule.Attributes{}),
Request: bascule.Request{
URL: "/",
URL: u,
Method: "get",
},
})
Expand Down
8 changes: 8 additions & 0 deletions basculehttp/notfoundbehavior_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package bascule

import (
"context"
"net/url"
)

// Authorization represents the authorization mechanism performed on the token,
Expand All @@ -18,7 +19,7 @@ type Authentication struct {
// Request holds request information that may be useful for validating the
// token.
type Request struct {
URL string
URL *url.URL
Method string
}

Expand Down
5 changes: 4 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package bascule

import (
"context"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func TestContext(t *testing.T) {
assert := assert.New(t)
u, err := url.ParseRequestURI("/a/b/c")
assert.Nil(err)
expectedAuth := Authentication{
Authorization: "authorization string",
Token: simpleToken{
Expand All @@ -17,7 +20,7 @@ func TestContext(t *testing.T) {
attributes: map[string]interface{}{"testkey": "testval", "attr": 5},
},
Request: Request{
URL: "/a/b/c",
URL: u,
Method: "GET",
},
}
Expand Down

0 comments on commit 0b3d3fb

Please sign in to comment.