diff --git a/README.md b/README.md index da0e1b2..f2aa4a0 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,18 @@ defer r.Stop() // Make sure recorder is stopped once done with it ... ``` +## Server Side + +VCR testing can also be used for creating server-side tests. Use the +`recorder.HTTPMiddleware` with an HTTP handler in order to create fixtures from +incoming requests and the handler's responses. Then, these requests can be +replayed and compared against the recorded responses to create a regression test. + +Rather than mocking/recording external HTTP interactions, this will record and +replay _incoming_ interactions with your application's HTTP server. + +See [an example here](./examples/middleware_test.go). + ## License `go-vcr` is Open Source and licensed under the [BSD diff --git a/examples/fixtures/middleware.yaml b/examples/fixtures/middleware.yaml new file mode 100644 index 0000000..29ef1ef --- /dev/null +++ b/examples/fixtures/middleware.yaml @@ -0,0 +1,163 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + transfer_encoding: [] + trailer: {} + host: "" + remote_addr: "" + request_uri: /request1 + body: "" + form: {} + headers: + Accept-Encoding: + - gzip + User-Agent: + - Go-http-client/1.1 + url: http://go-vcr/request1 + method: GET + response: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: false + body: OK + headers: + Content-Type: + - text/plain; charset=utf-8 + Key: + - VALUE + status: 200 OK + code: 200 + duration: 0s + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + transfer_encoding: [] + trailer: {} + host: "" + remote_addr: "" + request_uri: /request2?query=example + body: "" + form: {} + headers: + Accept-Encoding: + - gzip + User-Agent: + - Go-http-client/1.1 + url: http://go-vcr/request2?query=example + method: GET + response: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: false + body: |- + query=example + OK + headers: + Content-Type: + - text/plain; charset=utf-8 + Key: + - VALUE + status: 200 OK + code: 200 + duration: 0s + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 9 + transfer_encoding: [] + trailer: {} + host: "" + remote_addr: "" + request_uri: /postform + body: key=value + form: + key: + - value + headers: + Accept-Encoding: + - gzip + Content-Length: + - "9" + Content-Type: + - application/x-www-form-urlencoded + User-Agent: + - Go-http-client/1.1 + url: http://go-vcr/postform + method: POST + response: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: false + body: key=value + headers: + Content-Type: + - text/plain; charset=utf-8 + Key: + - VALUE + status: 200 OK + code: 200 + duration: 0s + - id: 3 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 15 + transfer_encoding: [] + trailer: {} + host: "" + remote_addr: "" + request_uri: /postdata + body: '{"key":"value"}' + form: {} + headers: + Accept-Encoding: + - gzip + Content-Length: + - "15" + Content-Type: + - application/json + User-Agent: + - Go-http-client/1.1 + url: http://go-vcr/postdata + method: POST + response: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: false + body: '{"key":"value"}' + headers: + Content-Type: + - text/plain; charset=utf-8 + Key: + - VALUE + status: 200 OK + code: 200 + duration: 0s diff --git a/examples/middleware_test.go b/examples/middleware_test.go new file mode 100644 index 0000000..f3f85d8 --- /dev/null +++ b/examples/middleware_test.go @@ -0,0 +1,96 @@ +package vcr_test + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette" + "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" +) + +func TestMiddleware(t *testing.T) { + cassetteName := "fixtures/middleware" + + // In a real-world scenario, the recorder will run outside of unit tests + // since you want to be able to record real application behavior + t.Run("RecordRealInteractionsWithMiddleware", func(t *testing.T) { + rec, err := recorder.New( + recorder.WithCassette(cassetteName), + recorder.WithMode(recorder.ModeRecordOnly), + // Use a BeforeSaveHook to remove host, remote_addr, and duration + // since they change whenever the test runs + recorder.WithHook(func(i *cassette.Interaction) error { + i.Request.Host = "" + i.Request.RemoteAddr = "" + i.Response.Duration = 0 + return nil + }, recorder.BeforeSaveHook), + ) + if err != nil { + t.Errorf("error creating recorder: %v", err) + } + + // Create the server handler with recorder middleware + handler := createHandler(rec.HTTPMiddleware) + defer rec.Stop() + + server := httptest.NewServer(handler) + defer server.Close() + + _, err = http.Get(server.URL + "/request1") + if err != nil { + t.Errorf("error making request: %v", err) + } + + _, err = http.Get(server.URL + "/request2?query=example") + if err != nil { + t.Errorf("error making request: %v", err) + } + + _, err = http.PostForm(server.URL+"/postform", url.Values{"key": []string{"value"}}) + if err != nil { + t.Errorf("error making request: %v", err) + } + + _, err = http.Post(server.URL+"/postdata", "application/json", bytes.NewBufferString(`{"key":"value"}`)) + if err != nil { + t.Errorf("error making request: %v", err) + } + }) + + t.Run("ReplayCassetteAndCompare", func(t *testing.T) { + cassette.TestServerReplay(t, cassetteName, createHandler(nil)) + }) +} + +// createHandler will return an HTTP handler with optional middleware. It will respond to +// simple requests for testing +func createHandler(middleware func(http.Handler) http.Handler) http.Handler { + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("KEY", "VALUE") + + query := r.URL.Query().Encode() + if query != "" { + w.Write([]byte(query + "\n")) + } + + body, _ := io.ReadAll(r.Body) + if len(body) > 0 { + w.Write(body) + } else { + w.Write([]byte("OK")) + } + }) + + if middleware != nil { + handler = middleware(handler).ServeHTTP + } + + mux.Handle("/", handler) + return mux +} diff --git a/pkg/cassette/cassette.go b/pkg/cassette/cassette.go index f04dca1..bd3abfa 100644 --- a/pkg/cassette/cassette.go +++ b/pkg/cassette/cassette.go @@ -324,6 +324,13 @@ func (m *defaultMatcher) matcher(r *http.Request, i Request) bool { return false } + // Only ParseForm for non-GET requests since that would use query params + if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch { + err := r.ParseForm() + if err != nil { + return false + } + } if !m.deepEqualContents(r.Form, i.Form) { return false } diff --git a/pkg/cassette/server_replay.go b/pkg/cassette/server_replay.go new file mode 100644 index 0000000..a18908d --- /dev/null +++ b/pkg/cassette/server_replay.go @@ -0,0 +1,87 @@ +package cassette + +import ( + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" +) + +// ReplayAssertFunc is used to assert the results of replaying a recorded request against a handler. +// It receives the current Interaction and the httptest.ResponseRecorder. +type ReplayAssertFunc func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder) + +// DefaultReplayAssertFunc compares the response status code, body, and headers. +// It can be overridden for more specific tests or to use your preferred assertion libraries +var DefaultReplayAssertFunc ReplayAssertFunc = func(t *testing.T, expected *Interaction, actual *httptest.ResponseRecorder) { + if expected.Response.Code != actual.Result().StatusCode { + t.Errorf("status code does not match: expected=%d actual=%d", expected.Response.Code, actual.Result().StatusCode) + } + + if expected.Response.Body != actual.Body.String() { + t.Errorf("body does not match: expected=%s actual=%s", expected.Response.Body, actual.Body.String()) + } + + if !headersEqual(expected.Response.Headers, actual.Header()) { + t.Errorf("header values do not match. expected=%v actual=%v", expected.Response.Headers, actual.Header()) + } +} + +// TestServerReplay loads a Cassette and replays each Interaction with the provided Handler, then compares the response +func TestServerReplay(t *testing.T, cassetteName string, handler http.Handler) { + t.Helper() + + c, err := Load(cassetteName) + if err != nil { + t.Errorf("unexpected error loading Cassette: %v", err) + } + + if len(c.Interactions) == 0 { + t.Error("no interactions in Cassette") + } + + for _, interaction := range c.Interactions { + t.Run(fmt.Sprintf("Interaction_%d", interaction.ID), func(t *testing.T) { + TestInteractionReplay(t, handler, interaction) + }) + } +} + +// TestInteractionReplay replays an Interaction with the provided Handler and compares the response +func TestInteractionReplay(t *testing.T, handler http.Handler, interaction *Interaction) { + t.Helper() + + req, err := interaction.GetHTTPRequest() + if err != nil { + t.Errorf("unexpected error getting interaction request: %v", err) + } + + if len(req.Form) > 0 { + req.Body = io.NopCloser(strings.NewReader(req.Form.Encode())) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + DefaultReplayAssertFunc(t, interaction, w) +} + +func headersEqual(expected, actual http.Header) bool { + return maps.EqualFunc( + expected, actual, + func(v1, v2 []string) bool { + slices.Sort(v1) + slices.Sort(v2) + + if !slices.Equal(v1, v2) { + return false + } + + return true + }, + ) +} diff --git a/pkg/recorder/middleware.go b/pkg/recorder/middleware.go new file mode 100644 index 0000000..f5ed5c3 --- /dev/null +++ b/pkg/recorder/middleware.go @@ -0,0 +1,63 @@ +package recorder + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" +) + +// HTTPMiddleware intercepts and records all incoming requests and the server's response +func (rec *Recorder) HTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := newPassthrough(w) + + // Tee the body so it can be read by the next handler and by the recorder + body := &bytes.Buffer{} + r.Body = io.NopCloser(io.TeeReader(r.Body, body)) + + next.ServeHTTP(ww, r) + + r.Body = io.NopCloser(body) + + // On the server side, requests do not have Host and Scheme so it must be set + r.URL.Host = "go-vcr" + r.URL.Scheme = "http" + + // copy headers from real response + for k, vv := range ww.real.Header() { + for _, v := range vv { + ww.recorder.Result().Header.Add(k, v) + } + } + + _, _ = rec.executeAndRecord(r, ww.recorder.Result()) + }) +} + +var _ http.ResponseWriter = &passthroughWriter{} + +// passthroughWriter uses the original ResponseWriter and an httptest.ResponseRecorder +// so the middleware can capture response details and passthrough to the client +type passthroughWriter struct { + recorder *httptest.ResponseRecorder + real http.ResponseWriter +} + +func newPassthrough(real http.ResponseWriter) passthroughWriter { + return passthroughWriter{recorder: httptest.NewRecorder(), real: real} +} + +func (p passthroughWriter) Header() http.Header { + return p.real.Header() +} + +func (p passthroughWriter) Write(in []byte) (int, error) { + _, _ = p.recorder.Write(in) + return p.real.Write(in) +} + +func (p passthroughWriter) WriteHeader(statusCode int) { + p.recorder.WriteHeader(statusCode) + p.real.WriteHeader(statusCode) +} diff --git a/pkg/recorder/recorder.go b/pkg/recorder/recorder.go index 9a74395..ec5ff68 100644 --- a/pkg/recorder/recorder.go +++ b/pkg/recorder/recorder.go @@ -385,7 +385,8 @@ func (rec *Recorder) getRoundTripper() http.RoundTripper { } // requestHandler proxies requests to their original destination -func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, error) { +// If serverResponse is provided, this is used for the recording instead of using RoundTrip +func (rec *Recorder) requestHandler(r *http.Request, serverResponse *http.Response) (*cassette.Interaction, error) { if err := r.Context().Err(); err != nil { return nil, err } @@ -452,14 +453,23 @@ func (rec *Recorder) requestHandler(r *http.Request) (*cassette.Interaction, err if r.Body != nil && r.Body != http.NoBody { // Record the request body so we can add it to the cassette r.Body = io.NopCloser(io.TeeReader(r.Body, reqBody)) + if serverResponse != nil { + // when serverResponse is provided by middleware, it has to be read in order + // for reqBody buffer to be populated + _, _ = io.ReadAll(r.Body) + } } // Perform request to it's original destination and record the interactions + // If serverResponse is provided, use it instead var start time.Time start = time.Now() - resp, err := rec.getRoundTripper().RoundTrip(r) - if err != nil { - return nil, err + resp := serverResponse + if resp == nil { + resp, err = rec.getRoundTripper().RoundTrip(r) + if err != nil { + return nil, err + } } requestDuration := time.Since(start) defer resp.Body.Close() @@ -573,6 +583,11 @@ func (rec *Recorder) applyHooks(i *cassette.Interaction, kind HookKind) error { // RoundTrip implements the [http.RoundTripper] interface func (rec *Recorder) RoundTrip(req *http.Request) (*http.Response, error) { + return rec.executeAndRecord(req, nil) +} + +// executeAndRecord is used internally by the HTTPMiddleware to allow recording a response on the server side +func (rec *Recorder) executeAndRecord(req *http.Request, serverResponse *http.Response) (*http.Response, error) { // Passthrough mode, use real transport if rec.mode == ModePassthrough { return rec.getRoundTripper().RoundTrip(req) @@ -585,7 +600,7 @@ func (rec *Recorder) RoundTrip(req *http.Request) (*http.Response, error) { } } - interaction, err := rec.requestHandler(req) + interaction, err := rec.requestHandler(req, serverResponse) if err != nil { return nil, err }