Skip to content

Commit

Permalink
Add RoundTripper method to ghttp.Server
Browse files Browse the repository at this point in the history
  • Loading branch information
Smirl authored and onsi committed Jul 14, 2024
1 parent 0e69083 commit c549e0d
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 29 deletions.
27 changes: 27 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,33 @@ To bring it all together: there are three ways to instruct a `ghttp` server to h

When a `ghttp` server receives a request it first checks against the set of handlers registered via `RouteToHandler` if there is no such handler it proceeds to pop an `AppendHandlers` handler off the stack, if the stack of ordered handlers is empty, it will check whether `GetAllowUnhandledRequests` returns `true` or `false`. If `false` the test fails. If `true`, a response is sent with whatever `GetUnhandledRequestStatusCode` returns.

### Using a RoundTripper to route requests to the test Server

So far you have seen examples of using `server.URL()` to get the string URL of the test server. This is ok if you are testing code where you can pass the URL. In some cases you might need to pass a `http.Client` or similar.

You can use `server.RounderTripper(nil)` to create a `http.RounderTripper` which will redirect requests to the test server.

The method takes another `http.RounderTripper` to make the request to the test server, this allows chaining `http.Transports` or otherwise.

If passed `nil`, then `http.DefaultTransport` is used to make the request.

```go
Describe("The http client", func() {
var server *ghttp.Server
var httpClient *http.Client

BeforeEach(func() {
server = ghttp.NewServer()
httpClient = &http.Client{Transport: server.RounderTripper(nil)}
})

AfterEach(func() {
//shut down the server between tests
server.Close()
})
})
```

## `gbytes`: Testing Streaming Buffers

`gbytes` implements `gbytes.Buffer` - an `io.WriteCloser` that captures all input to an in-memory buffer.
Expand Down
79 changes: 50 additions & 29 deletions ghttp/test_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,26 +186,26 @@ type Server struct {
calls int
}

//Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
// Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
func (s *Server) Start() {
s.HTTPTestServer.Start()
}

//URL() returns a url that will hit the server
// URL() returns a url that will hit the server
func (s *Server) URL() string {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
return s.HTTPTestServer.URL
}

//Addr() returns the address on which the server is listening.
// Addr() returns the address on which the server is listening.
func (s *Server) Addr() string {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
return s.HTTPTestServer.Listener.Addr().String()
}

//Close() should be called at the end of each test. It spins down and cleans up the test server.
// Close() should be called at the end of each test. It spins down and cleans up the test server.
func (s *Server) Close() {
s.rwMutex.Lock()
server := s.HTTPTestServer
Expand All @@ -217,14 +217,14 @@ func (s *Server) Close() {
}
}

//ServeHTTP() makes Server an http.Handler
//When the server receives a request it handles the request in the following order:
// ServeHTTP() makes Server an http.Handler
// When the server receives a request it handles the request in the following order:
//
//1. If the request matches a handler registered with RouteToHandler, that handler is called.
//2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
//3. If all registered handlers have been called then:
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
// 1. If the request matches a handler registered with RouteToHandler, that handler is called.
// 2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
// 3. If all registered handlers have been called then:
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.rwMutex.Lock()
defer func() {
Expand Down Expand Up @@ -280,18 +280,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}

//ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
// ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
func (s *Server) ReceivedRequests() []*http.Request {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.receivedRequests
}

//RouteToHandler can be used to register handlers that will always handle requests that match
//the passed in method and path.
// RouteToHandler can be used to register handlers that will always handle requests that match
// the passed in method and path.
//
//The path may be either a string object or a *regexp.Regexp.
// The path may be either a string object or a *regexp.Regexp.
func (s *Server) RouteToHandler(method string, path interface{}, handler http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()
Expand Down Expand Up @@ -337,25 +337,25 @@ func (s *Server) handlerForRoute(method string, path string) (http.HandlerFunc,
return nil, false
}

//AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
// AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
func (s *Server) AppendHandlers(handlers ...http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.requestHandlers = append(s.requestHandlers, handlers...)
}

//SetHandler overrides the registered handler at the passed in index with the passed in handler
//This is useful, for example, when a server has been set up in a shared context, but must be tweaked
//for a particular test.
// SetHandler overrides the registered handler at the passed in index with the passed in handler
// This is useful, for example, when a server has been set up in a shared context, but must be tweaked
// for a particular test.
func (s *Server) SetHandler(index int, handler http.HandlerFunc) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.requestHandlers[index] = handler
}

//GetHandler returns the handler registered at the passed in index.
// GetHandler returns the handler registered at the passed in index.
func (s *Server) GetHandler(index int) http.HandlerFunc {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()
Expand All @@ -374,12 +374,12 @@ func (s *Server) Reset() {
s.routedHandlers = nil
}

//WrapHandler combines the passed in handler with the handler registered at the passed in index.
//This is useful, for example, when a server has been set up in a shared context but must be tweaked
//for a particular test.
// WrapHandler combines the passed in handler with the handler registered at the passed in index.
// This is useful, for example, when a server has been set up in a shared context but must be tweaked
// for a particular test.
//
//If the currently registered handler is A, and the new passed in handler is B then
//WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
// If the currently registered handler is A, and the new passed in handler is B then
// WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
func (s *Server) WrapHandler(index int, handler http.HandlerFunc) {
existingHandler := s.GetHandler(index)
s.SetHandler(index, CombineHandlers(existingHandler, handler))
Expand All @@ -392,34 +392,55 @@ func (s *Server) CloseClientConnections() {
s.HTTPTestServer.CloseClientConnections()
}

//SetAllowUnhandledRequests enables the server to accept unhandled requests.
// SetAllowUnhandledRequests enables the server to accept unhandled requests.
func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.AllowUnhandledRequests = allowUnhandledRequests
}

//GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
// GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
func (s *Server) GetAllowUnhandledRequests() bool {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.AllowUnhandledRequests
}

//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
// SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
func (s *Server) SetUnhandledRequestStatusCode(statusCode int) {
s.rwMutex.Lock()
defer s.rwMutex.Unlock()

s.UnhandledRequestStatusCode = statusCode
}

//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
// GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
func (s *Server) GetUnhandledRequestStatusCode() int {
s.rwMutex.RLock()
defer s.rwMutex.RUnlock()

return s.UnhandledRequestStatusCode
}

// RoundTripper returns a RoundTripper which updates requests to point to the server.
// This is useful when you want to use the server as a RoundTripper in an http.Client.
// If rt is nil, http.DefaultTransport is used.
func (s *Server) RoundTripper(rt http.RoundTripper) http.RoundTripper {
if rt == nil {
rt = http.DefaultTransport
}
return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
r.URL.Scheme = "http"
r.URL.Host = s.Addr()
return rt.RoundTrip(r)
})
}

// Helper type for creating a RoundTripper from a function
type RoundTripperFunc func(*http.Request) (*http.Response, error)

func (fn RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return fn(r)
}
59 changes: 59 additions & 0 deletions ghttp/test_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1190,4 +1190,63 @@ var _ = Describe("TestServer", func() {
})
})
})

Describe("RoundTripper", func() {
var called []string
BeforeEach(func() {
called = []string{}
s.RouteToHandler("GET", "/routed", func(w http.ResponseWriter, req *http.Request) {
called = append(called, "get")
})
s.RouteToHandler("POST", "/routed", func(w http.ResponseWriter, req *http.Request) {
called = append(called, "post")
})
})

It("should send http traffic to test server with default transport", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("http://example.com/routed")
client.Post("http://example.com/routed", "application/json", nil)
client.Get("http://foo.bar/routed")
client.Post("http://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send https traffic to test server with default transport", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("https://example.com/routed")
client.Post("https://example.com/routed", "application/json", nil)
client.Get("https://foo.bar/routed")
client.Post("https://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send http traffic to test server with default transport", func() {
transport := http.Transport{}
client := http.Client{Transport: s.RoundTripper(&transport)}
client.Get("http://example.com/routed")
client.Post("http://example.com/routed", "application/json", nil)
client.Get("http://foo.bar/routed")
client.Post("http://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should send http traffic to test server with default transport", func() {
transport := http.Transport{}
client := http.Client{Transport: s.RoundTripper(&transport)}
client.Get("https://example.com/routed")
client.Post("https://example.com/routed", "application/json", nil)
client.Get("https://foo.bar/routed")
client.Post("https://foo.bar/routed", "application/json", nil)
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
})

It("should not change the path of the request", func() {
client := http.Client{Transport: s.RoundTripper(nil)}
client.Get("https://example.com/routed")
Expect(called).Should(Equal([]string{"get"}))
Expect(s.ReceivedRequests()).Should(HaveLen(1))
Expect(s.ReceivedRequests()[0].URL.Path).Should(Equal("/routed"))
})
})
})

0 comments on commit c549e0d

Please sign in to comment.