Skip to content

Commit

Permalink
add RequestData in Context option for HTTP client (#747)
Browse files Browse the repository at this point in the history
Signed-off-by: Pablo Mercado <odacremolbap@gmail.com>
  • Loading branch information
odacremolbap authored Dec 9, 2021
1 parent 77f73c2 commit 4b69880
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 0 deletions.
48 changes: 48 additions & 0 deletions v2/protocol/http/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2021 The CloudEvents Authors
SPDX-License-Identifier: Apache-2.0
*/

package http

import (
"context"

nethttp "net/http"
"net/url"
)

type requestKey struct{}

// RequestData holds the http.Request information subset that can be
// used to retrieve HTTP information for an incoming CloudEvent.
type RequestData struct {
URL *url.URL
Header nethttp.Header
RemoteAddr string
Host string
}

// WithRequestDataAtContext uses the http.Request to add RequestData
// information to the Context.
func WithRequestDataAtContext(ctx context.Context, r *nethttp.Request) context.Context {
if r == nil {
return ctx
}

return context.WithValue(ctx, requestKey{}, &RequestData{
URL: r.URL,
Header: r.Header,
RemoteAddr: r.RemoteAddr,
Host: r.Host,
})
}

// RequestDataFromContext retrieves RequestData from the Context.
// If not set nil is returned.
func RequestDataFromContext(ctx context.Context) *RequestData {
if req := ctx.Value(requestKey{}); req != nil {
return req.(*RequestData)
}
return nil
}
126 changes: 126 additions & 0 deletions v2/protocol/http/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
Copyright 2021 The CloudEvents Authors
SPDX-License-Identifier: Apache-2.0
*/

package http

import (
"context"
nethttp "net/http"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

const (
tMethod = nethttp.MethodPost
)

func TestWithRequest(t *testing.T) {
testCases := map[string]struct {
request *nethttp.Request

expectedRequest *RequestData
}{
"request": {
request: newRequest("http://testhost:8080/test/path.json"),
expectedRequest: &RequestData{
Host: "testhost:8080",
URL: newURL("http://testhost:8080/test/path.json"),
Header: nethttp.Header{},
},
},
"request with headers": {
request: newRequest("http://testhost:8080/test/path.json",
requestOptionAddHeader("key1", "value1"),
requestOptionAddHeader("key2", "value2.1"),
requestOptionAddHeader("key2", "value2.2"),
),
expectedRequest: &RequestData{
Host: "testhost:8080",
URL: newURL("http://testhost:8080/test/path.json"),
Header: nethttp.Header{
"Key1": []string{"value1"},
"Key2": []string{"value2.1", "value2.2"},
},
},
},
"request with host header": {
request: newRequest("http://testhost:8080/test/path.json",
requestOptionHostHeader("alternative.host"),
),
expectedRequest: &RequestData{
Host: "alternative.host",
URL: newURL("http://testhost:8080/test/path.json"),
Header: nethttp.Header{},
},
},
"request with remote address": {
request: newRequest("http://testhost:8080/test/path.json",
requestOptionRemoteAddr("requester.address"),
),
expectedRequest: &RequestData{
Host: "testhost:8080",
URL: newURL("http://testhost:8080/test/path.json"),
Header: nethttp.Header{},
RemoteAddr: "requester.address",
},
},
"nil request": {
request: nil,
expectedRequest: nil,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ctx := WithRequestDataAtContext(context.TODO(), tc.request)

req := RequestDataFromContext(ctx)
assert.Equal(t, req, tc.expectedRequest)
})
}
}

type requestOption func(*nethttp.Request)

func newRequest(url string, opts ...requestOption) *nethttp.Request {
r, err := nethttp.NewRequest(tMethod, url, nil)
if err != nil {
panic(err)
}

for _, opt := range opts {
opt(r)
}

return r
}

func requestOptionAddHeader(key, value string) requestOption {
return func(r *nethttp.Request) {
r.Header.Add(key, value)
}
}

func requestOptionHostHeader(host string) requestOption {
return func(r *nethttp.Request) {
r.Host = host
}
}

func requestOptionRemoteAddr(addr string) requestOption {
return func(r *nethttp.Request) {
r.RemoteAddr = addr
}
}

func newURL(u string) *url.URL {
parsed, err := url.Parse(u)
if err != nil {
panic(err)
}
return parsed
}
12 changes: 12 additions & 0 deletions v2/protocol/http/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,15 @@ func WithRateLimiter(rl RateLimiter) Option {
return nil
}
}

// WithRequestDataAtContextMiddleware adds to the Context RequestData.
// This enables a user's dispatch handler to inspect HTTP request information by
// retrieving it from the Context.
func WithRequestDataAtContextMiddleware() Option {
return WithMiddleware(func(next nethttp.Handler) nethttp.Handler {
return nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) {
ctx := WithRequestDataAtContext(r.Context(), r)
next.ServeHTTP(w, r.WithContext(ctx))
})
})
}
95 changes: 95 additions & 0 deletions v2/protocol/http/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,98 @@ func TestWithIsRetriableFunc(t *testing.T) {
})
}
}

func TestWithRequestDataAtContextMiddleware(t *testing.T) {
const tURL = "https://testhost:8080/test/path"
const tRemoteAddr = "remote.address:1234"
const tHost = "testhost:8080"

tHeader := http.Header{
"Key": []string{"value"},
}

u, err := url.Parse(tURL)
if err != nil {
t.Fatal(err)
}

tRequestData := &RequestData{
Host: tHost,
URL: u,
RemoteAddr: tRemoteAddr,
Header: tHeader,
}

testCases := map[string]struct {
t *Protocol
options []Option
wantApplyOptionsErr string
expectMiddlewareCount int
expectRequestData *RequestData
}{
"nil protocol": {
wantApplyOptionsErr: "http middleware option can not set nil protocol",
options: []Option{WithRequestDataAtContextMiddleware()},
},
"protocol with RequestData middleware": {
t: &Protocol{},
options: []Option{WithRequestDataAtContextMiddleware()},
expectMiddlewareCount: 1,
expectRequestData: tRequestData,
},
"protocol without RequestData middleware": {
t: &Protocol{},
expectMiddlewareCount: 0,
expectRequestData: nil,
},
}
for n, tc := range testCases {
t.Run(n, func(t *testing.T) {
err := tc.t.applyOptions(tc.options...)

if tc.wantApplyOptionsErr != "" {
if err == nil || err.Error() != tc.wantApplyOptionsErr {
t.Fatalf("Expected error '%s'. Actual '%v'", tc.wantApplyOptionsErr, err)
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
return
}

if len(tc.t.middleware) != tc.expectMiddlewareCount {
t.Fatalf("Expected number of registered middleware %d. Actual '%v'", tc.expectMiddlewareCount, len(tc.t.middleware))
return
}

ms := mockOptionsServer{
handler: func(w http.ResponseWriter, r *http.Request) {
rd := RequestDataFromContext(r.Context())
require.Equal(t, rd, tc.expectRequestData)
},
}

handler := attachMiddleware(ms, tc.t.middleware)

req, err := http.NewRequest("POST", tURL, nil)
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = tRemoteAddr
req.Header = tHeader

handler.ServeHTTP(nil, req)
})
}
}

// mockOptionsServer implements http.Handler passing
// unmodified calls to the internal handler.
type mockOptionsServer struct {
handler http.HandlerFunc
}

func (m mockOptionsServer) ServeHTTP(res http.ResponseWriter, req *http.Request) {
m.handler(res, req)
}

0 comments on commit 4b69880

Please sign in to comment.