diff --git a/apitest.go b/apitest.go index b075ed9..55223fd 100644 --- a/apitest.go +++ b/apitest.go @@ -46,6 +46,7 @@ type APITest struct { mocks []*Mock t TestingT httpClient *http.Client + httpRequest *http.Request transport *Transport meta map[string]interface{} started time.Time @@ -251,6 +252,12 @@ func (a *APITest) Method(method string) *Request { return a.request } +// HttpRequest defines the native `http.Request` +func (a *APITest) HttpRequest(req *http.Request) *Request { + a.httpRequest = req + return a.request +} + // Get is a convenience method for setting the request as http.MethodGet func (a *APITest) Get(url string) *Request { a.request.method = http.MethodGet @@ -878,6 +885,10 @@ func (a *APITest) serveHttp(res *httptest.ResponseRecorder, req *http.Request) { } func (a *APITest) buildRequest() *http.Request { + if a.httpRequest != nil { + return a.httpRequest + } + if len(a.request.formData) > 0 { form := url.Values{} for k := range a.request.formData { diff --git a/apitest_test.go b/apitest_test.go index 8b47b36..7981a3d 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -7,7 +7,9 @@ import ( "io/ioutil" "net/http" "net/http/cookiejar" + "net/http/httptest" "reflect" + "strings" "testing" "time" @@ -27,6 +29,30 @@ func TestApiTest_ResponseBody(t *testing.T) { End() } +func TestApiTest_HttpRequest(t *testing.T) { + handler := http.NewServeMux() + handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + data, _ := ioutil.ReadAll(r.Body) + if string(data) != `hello` { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + if r.Header.Get("key") != "val" { + t.Fatal("expected header key=val") + } + }) + + request := httptest.NewRequest(http.MethodGet, "/hello", strings.NewReader("hello")) + request.Header.Set("key", "val") + + apitest.Handler(handler). + HttpRequest(request). + Expect(t). + Status(http.StatusOK). + End() +} + func TestApiTest_AddsJSONBodyToRequest(t *testing.T) { handler := http.NewServeMux() handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {