From d7f6baccf1c1f8434bb70439193868321a64701e Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Wed, 29 Sep 2021 00:17:36 +0200 Subject: [PATCH 1/5] chore: bump Go version to 1.18 Signed-off-by: Mark Sagi-Kazar --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index bfd854d27..e9bbe6116 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-kit/kit -go 1.17 +go 1.18 require ( github.com/VividCortex/gohistogram v1.0.0 From d0b7cac6e42f1b36c1e7be497f66c47ca3d920d5 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Wed, 29 Sep 2021 00:17:58 +0200 Subject: [PATCH 2/5] feat: add generics to the endpoint layer Signed-off-by: Mark Sagi-Kazar --- endpoint/endpoint.go | 8 ++++---- endpoint/endpoint_example_test.go | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 6e9da3679..a9dc236f6 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -6,20 +6,20 @@ import ( // Endpoint is the fundamental building block of servers and clients. // It represents a single RPC method. -type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error) +type Endpoint[Req any, Resp any] func(ctx context.Context, request Req) (response Resp, err error) // Nop is an endpoint that does nothing and returns a nil error. // Useful for tests. func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } // Middleware is a chainable behavior modifier for endpoints. -type Middleware func(Endpoint) Endpoint +type Middleware[Req any, Resp any] func(Endpoint[Req, Resp]) Endpoint[Req, Resp] // Chain is a helper function for composing middlewares. Requests will // traverse them in the order they're declared. That is, the first middleware // is treated as the outermost middleware. -func Chain(outer Middleware, others ...Middleware) Middleware { - return func(next Endpoint) Endpoint { +func Chain[Req any, Resp any](outer Middleware[Req, Resp], others ...Middleware[Req, Resp]) Middleware[Req, Resp] { + return func(next Endpoint[Req, Resp]) Endpoint[Req, Resp] { for i := len(others) - 1; i >= 0; i-- { // reverse next = others[i](next) } diff --git a/endpoint/endpoint_example_test.go b/endpoint/endpoint_example_test.go index e95062305..c3ba71e5b 100644 --- a/endpoint/endpoint_example_test.go +++ b/endpoint/endpoint_example_test.go @@ -9,9 +9,9 @@ import ( func ExampleChain() { e := endpoint.Chain( - annotate("first"), - annotate("second"), - annotate("third"), + annotate[any, any]("first"), + annotate[any, any]("second"), + annotate[any, any]("third"), )(myEndpoint) if _, err := e(ctx, req); err != nil { @@ -33,13 +33,13 @@ var ( req = struct{}{} ) -func annotate(s string) endpoint.Middleware { - return func(next endpoint.Endpoint) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { +func annotate[Req any, Resp any](s string) endpoint.Middleware[Req, Resp] { + return func(next endpoint.Endpoint[Req, Resp]) endpoint.Endpoint[Req, Resp] { + return endpoint.Endpoint[Req, Resp](func(ctx context.Context, request Req) (Resp, error) { fmt.Println(s, "pre") defer fmt.Println(s, "post") return next(ctx, request) - } + }) } } From b0e403baa98cdfa3e6e762934c19613c8866ed1f Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Wed, 29 Sep 2021 00:29:03 +0200 Subject: [PATCH 3/5] feat: support generic endpoints in http transport Signed-off-by: Mark Sagi-Kazar --- transport/http2/doc.go | 2 + transport/http2/encode_decode.go | 36 ++ transport/http2/example_test.go | 36 ++ transport/http2/request_response_funcs.go | 133 ++++++ .../http2/request_response_funcs_test.go | 30 ++ transport/http2/server.go | 244 ++++++++++ transport/http2/server_test.go | 442 ++++++++++++++++++ 7 files changed, 923 insertions(+) create mode 100644 transport/http2/doc.go create mode 100644 transport/http2/encode_decode.go create mode 100644 transport/http2/example_test.go create mode 100644 transport/http2/request_response_funcs.go create mode 100644 transport/http2/request_response_funcs_test.go create mode 100644 transport/http2/server.go create mode 100644 transport/http2/server_test.go diff --git a/transport/http2/doc.go b/transport/http2/doc.go new file mode 100644 index 000000000..e64010358 --- /dev/null +++ b/transport/http2/doc.go @@ -0,0 +1,2 @@ +// Package http provides a general purpose HTTP binding for endpoints. +package http diff --git a/transport/http2/encode_decode.go b/transport/http2/encode_decode.go new file mode 100644 index 000000000..dc1423fe7 --- /dev/null +++ b/transport/http2/encode_decode.go @@ -0,0 +1,36 @@ +package http + +import ( + "context" + "net/http" +) + +// DecodeRequestFunc extracts a user-domain request object from an HTTP +// request object. It's designed to be used in HTTP servers, for server-side +// endpoints. One straightforward DecodeRequestFunc could be something that +// JSON decodes from the request body to the concrete request type. +type DecodeRequestFunc[Req any] func(context.Context, *http.Request) (request Req, err error) + +// EncodeRequestFunc encodes the passed request object into the HTTP request +// object. It's designed to be used in HTTP clients, for client-side +// endpoints. One straightforward EncodeRequestFunc could be something that JSON +// encodes the object directly to the request body. +type EncodeRequestFunc func(context.Context, *http.Request, interface{}) error + +// CreateRequestFunc creates an outgoing HTTP request based on the passed +// request object. It's designed to be used in HTTP clients, for client-side +// endpoints. It's a more powerful version of EncodeRequestFunc, and can be used +// if more fine-grained control of the HTTP request is required. +type CreateRequestFunc func(context.Context, interface{}) (*http.Request, error) + +// EncodeResponseFunc encodes the passed response object to the HTTP response +// writer. It's designed to be used in HTTP servers, for server-side +// endpoints. One straightforward EncodeResponseFunc could be something that +// JSON encodes the object directly to the response body. +type EncodeResponseFunc[Resp any] func(context.Context, http.ResponseWriter, Resp) error + +// DecodeResponseFunc extracts a user-domain response object from an HTTP +// response object. It's designed to be used in HTTP clients, for client-side +// endpoints. One straightforward DecodeResponseFunc could be something that +// JSON decodes from the response body to the concrete response type. +type DecodeResponseFunc func(context.Context, *http.Response) (response interface{}, err error) diff --git a/transport/http2/example_test.go b/transport/http2/example_test.go new file mode 100644 index 000000000..65397c3be --- /dev/null +++ b/transport/http2/example_test.go @@ -0,0 +1,36 @@ +package http + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" +) + +func ExamplePopulateRequestContext() { + handler := NewServer[any, any]( + func(ctx context.Context, request interface{}) (response interface{}, err error) { + fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string)) + fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string)) + fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string)) + fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string)) + return struct{}{}, nil + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ServerBefore[any, any](PopulateRequestContext), + ) + + server := httptest.NewServer(handler) + defer server.Close() + + req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil) + req.Header.Set("X-Request-Id", "a1b2c3d4e5") + http.DefaultClient.Do(req) + + // Output: + // Method PATCH + // RequestPath /search + // RequestURI /search?q=sympatico + // X-Request-ID a1b2c3d4e5 +} diff --git a/transport/http2/request_response_funcs.go b/transport/http2/request_response_funcs.go new file mode 100644 index 000000000..8f92b3bc7 --- /dev/null +++ b/transport/http2/request_response_funcs.go @@ -0,0 +1,133 @@ +package http + +import ( + "context" + "net/http" +) + +// RequestFunc may take information from an HTTP request and put it into a +// request context. In Servers, RequestFuncs are executed prior to invoking the +// endpoint. In Clients, RequestFuncs are executed after creating the request +// but prior to invoking the HTTP client. +type RequestFunc func(context.Context, *http.Request) context.Context + +// ServerResponseFunc may take information from a request context and use it to +// manipulate a ResponseWriter. ServerResponseFuncs are only executed in +// servers, after invoking the endpoint but prior to writing a response. +type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Context + +// ClientResponseFunc may take information from an HTTP request and make the +// response available for consumption. ClientResponseFuncs are only executed in +// clients, after a request has been made, but prior to it being decoded. +type ClientResponseFunc func(context.Context, *http.Response) context.Context + +// SetContentType returns a ServerResponseFunc that sets the Content-Type header +// to the provided value. +func SetContentType(contentType string) ServerResponseFunc { + return SetResponseHeader("Content-Type", contentType) +} + +// SetResponseHeader returns a ServerResponseFunc that sets the given header. +func SetResponseHeader(key, val string) ServerResponseFunc { + return func(ctx context.Context, w http.ResponseWriter) context.Context { + w.Header().Set(key, val) + return ctx + } +} + +// SetRequestHeader returns a RequestFunc that sets the given header. +func SetRequestHeader(key, val string) RequestFunc { + return func(ctx context.Context, r *http.Request) context.Context { + r.Header.Set(key, val) + return ctx + } +} + +// PopulateRequestContext is a RequestFunc that populates several values into +// the context from the HTTP request. Those values may be extracted using the +// corresponding ContextKey type in this package. +func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context { + for k, v := range map[contextKey]string{ + ContextKeyRequestMethod: r.Method, + ContextKeyRequestURI: r.RequestURI, + ContextKeyRequestPath: r.URL.Path, + ContextKeyRequestProto: r.Proto, + ContextKeyRequestHost: r.Host, + ContextKeyRequestRemoteAddr: r.RemoteAddr, + ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"), + ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"), + ContextKeyRequestAuthorization: r.Header.Get("Authorization"), + ContextKeyRequestReferer: r.Header.Get("Referer"), + ContextKeyRequestUserAgent: r.Header.Get("User-Agent"), + ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"), + ContextKeyRequestAccept: r.Header.Get("Accept"), + } { + ctx = context.WithValue(ctx, k, v) + } + return ctx +} + +type contextKey int + +const ( + // ContextKeyRequestMethod is populated in the context by + // PopulateRequestContext. Its value is r.Method. + ContextKeyRequestMethod contextKey = iota + + // ContextKeyRequestURI is populated in the context by + // PopulateRequestContext. Its value is r.RequestURI. + ContextKeyRequestURI + + // ContextKeyRequestPath is populated in the context by + // PopulateRequestContext. Its value is r.URL.Path. + ContextKeyRequestPath + + // ContextKeyRequestProto is populated in the context by + // PopulateRequestContext. Its value is r.Proto. + ContextKeyRequestProto + + // ContextKeyRequestHost is populated in the context by + // PopulateRequestContext. Its value is r.Host. + ContextKeyRequestHost + + // ContextKeyRequestRemoteAddr is populated in the context by + // PopulateRequestContext. Its value is r.RemoteAddr. + ContextKeyRequestRemoteAddr + + // ContextKeyRequestXForwardedFor is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For"). + ContextKeyRequestXForwardedFor + + // ContextKeyRequestXForwardedProto is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto"). + ContextKeyRequestXForwardedProto + + // ContextKeyRequestAuthorization is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Authorization"). + ContextKeyRequestAuthorization + + // ContextKeyRequestReferer is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Referer"). + ContextKeyRequestReferer + + // ContextKeyRequestUserAgent is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("User-Agent"). + ContextKeyRequestUserAgent + + // ContextKeyRequestXRequestID is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). + ContextKeyRequestXRequestID + + // ContextKeyRequestAccept is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Accept"). + ContextKeyRequestAccept + + // ContextKeyResponseHeaders is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type http.Header, and + // is captured only once the entire response has been written. + ContextKeyResponseHeaders + + // ContextKeyResponseSize is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type int64. + ContextKeyResponseSize +) diff --git a/transport/http2/request_response_funcs_test.go b/transport/http2/request_response_funcs_test.go new file mode 100644 index 000000000..a6cfbfde7 --- /dev/null +++ b/transport/http2/request_response_funcs_test.go @@ -0,0 +1,30 @@ +package http_test + +import ( + "context" + "net/http/httptest" + "testing" + + httptransport "github.com/go-kit/kit/transport/http2" +) + +func TestSetHeader(t *testing.T) { + const ( + key = "X-Foo" + val = "12345" + ) + r := httptest.NewRecorder() + httptransport.SetResponseHeader(key, val)(context.Background(), r) + if want, have := val, r.Header().Get(key); want != have { + t.Errorf("want %q, have %q", want, have) + } +} + +func TestSetContentType(t *testing.T) { + const contentType = "application/json" + r := httptest.NewRecorder() + httptransport.SetContentType(contentType)(context.Background(), r) + if want, have := contentType, r.Header().Get("Content-Type"); want != have { + t.Errorf("want %q, have %q", want, have) + } +} diff --git a/transport/http2/server.go b/transport/http2/server.go new file mode 100644 index 000000000..9594a5ef6 --- /dev/null +++ b/transport/http2/server.go @@ -0,0 +1,244 @@ +package http + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/transport" + "github.com/go-kit/log" +) + +// Server wraps an endpoint and implements http.Handler. +type Server[Req any, Resp any] struct { + e endpoint.Endpoint[Req, Resp] + dec DecodeRequestFunc[Req] + enc EncodeResponseFunc[Resp] + before []RequestFunc + after []ServerResponseFunc + errorEncoder ErrorEncoder + finalizer []ServerFinalizerFunc + errorHandler transport.ErrorHandler +} + +// NewServer constructs a new server, which implements http.Handler and wraps +// the provided endpoint. +func NewServer[Req any, Resp any]( + e endpoint.Endpoint[Req, Resp], + dec DecodeRequestFunc[Req], + enc EncodeResponseFunc[Resp], + options ...ServerOption[Req, Resp], +) *Server[Req, Resp] { + s := &Server[Req, Resp]{ + e: e, + dec: dec, + enc: enc, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, option := range options { + option(s) + } + return s +} + +// ServerOption sets an optional parameter for servers. +type ServerOption[Req any, Resp any] func(*Server[Req, Resp]) + +// ServerBefore functions are executed on the HTTP request object before the +// request is decoded. +func ServerBefore[Req any, Resp any](before ...RequestFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.before = append(s.before, before...) } +} + +// ServerAfter functions are executed on the HTTP response writer after the +// endpoint is invoked, but before anything is written to the client. +func ServerAfter[Req any, Resp any](after ...ServerResponseFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.after = append(s.after, after...) } +} + +// ServerErrorEncoder is used to encode errors to the http.ResponseWriter +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting and response codes. By default, +// errors will be written with the DefaultErrorEncoder. +func ServerErrorEncoder[Req any, Resp any](ee ErrorEncoder) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorEncoder = ee } +} + +// ServerErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. +// Deprecated: Use ServerErrorHandler instead. +func ServerErrorLogger[Req any, Resp any](logger log.Logger) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorHandler = transport.NewLogErrorHandler(logger) } +} + +// ServerErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. +func ServerErrorHandler[Req any, Resp any](errorHandler transport.ErrorHandler) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorHandler = errorHandler } +} + +// ServerFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ServerFinalizer[Req any, Resp any](f ...ServerFinalizerFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.finalizer = append(s.finalizer, f...) } +} + +// ServeHTTP implements http.Handler. +func (s Server[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if len(s.finalizer) > 0 { + iw := &interceptingWriter{w, http.StatusOK, 0} + defer func() { + ctx = context.WithValue(ctx, ContextKeyResponseHeaders, iw.Header()) + ctx = context.WithValue(ctx, ContextKeyResponseSize, iw.written) + for _, f := range s.finalizer { + f(ctx, iw.code, r) + } + }() + w = iw + } + + for _, f := range s.before { + ctx = f(ctx, r) + } + + request, err := s.dec(ctx, r) + if err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } + + response, err := s.e(ctx, request) + if err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } + + for _, f := range s.after { + ctx = f(ctx, w) + } + + if err := s.enc(ctx, w, response); err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } +} + +// ErrorEncoder is responsible for encoding an error to the ResponseWriter. +// Users are encouraged to use custom ErrorEncoders to encode HTTP errors to +// their clients, and will likely want to pass and check for their own error +// types. See the example shipping/handling service. +type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) + +// ServerFinalizerFunc can be used to perform work at the end of an HTTP +// request, after the response has been written to the client. The principal +// intended use is for request logging. In addition to the response code +// provided in the function signature, additional response parameters are +// provided in the context under keys with the ContextKeyResponse prefix. +type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) + +// NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not +// need to be decoded, and simply returns nil, nil. +func NopRequestDecoder(ctx context.Context, r *http.Request) (interface{}, error) { + return nil, nil +} + +// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a +// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as +// a sensible default. If the response implements Headerer, the provided headers +// will be applied to the response. If the response implements StatusCoder, the +// provided StatusCode will be used instead of 200. +func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if headerer, ok := response.(Headerer); ok { + for k, values := range headerer.Headers() { + for _, v := range values { + w.Header().Add(k, v) + } + } + } + code := http.StatusOK + if sc, ok := response.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + if code == http.StatusNoContent { + return nil + } + return json.NewEncoder(w).Encode(response) +} + +// DefaultErrorEncoder writes the error to the ResponseWriter, by default a +// content type of text/plain, a body of the plain text of the error, and a +// status code of 500. If the error implements Headerer, the provided headers +// will be applied to the response. If the error implements json.Marshaler, and +// the marshaling succeeds, a content type of application/json and the JSON +// encoded form of the error will be used. If the error implements StatusCoder, +// the provided StatusCode will be used instead of 500. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { + contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) + if marshaler, ok := err.(json.Marshaler); ok { + if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { + contentType, body = "application/json; charset=utf-8", jsonBody + } + } + w.Header().Set("Content-Type", contentType) + if headerer, ok := err.(Headerer); ok { + for k, values := range headerer.Headers() { + for _, v := range values { + w.Header().Add(k, v) + } + } + } + code := http.StatusInternalServerError + if sc, ok := err.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + w.Write(body) +} + +// StatusCoder is checked by DefaultErrorEncoder. If an error value implements +// StatusCoder, the StatusCode will be used when encoding the error. By default, +// StatusInternalServerError (500) is used. +type StatusCoder interface { + StatusCode() int +} + +// Headerer is checked by DefaultErrorEncoder. If an error value implements +// Headerer, the provided headers will be applied to the response writer, after +// the Content-Type is set. +type Headerer interface { + Headers() http.Header +} + +type interceptingWriter struct { + http.ResponseWriter + code int + written int64 +} + +// WriteHeader may not be explicitly called, so care must be taken to +// initialize w.code to its default value of http.StatusOK. +func (w *interceptingWriter) WriteHeader(code int) { + w.code = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *interceptingWriter) Write(p []byte) (int, error) { + n, err := w.ResponseWriter.Write(p) + w.written += int64(n) + return n, err +} diff --git a/transport/http2/server_test.go b/transport/http2/server_test.go new file mode 100644 index 000000000..8e43ac878 --- /dev/null +++ b/transport/http2/server_test.go @@ -0,0 +1,442 @@ +package http_test + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + httptransport "github.com/go-kit/kit/transport/http2" +) + +func TestServerBadDecode(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerBadEndpoint(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerBadEncode(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerErrorEncoder(t *testing.T) { + errTeapot := errors.New("teapot") + code := func(err error) int { + if errors.Is(err, errTeapot) { + return http.StatusTeapot + } + return http.StatusInternalServerError + } + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + httptransport.ServerErrorEncoder[any, any](func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerHappyPath(t *testing.T) { + step, response := testServer(t) + step() + resp := <-response + defer resp.Body.Close() + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d (%s)", want, have, buf) + } +} + +func TestMultipleServerBefore(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerBefores are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestMultipleServerAfter(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerAfters are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestServerFinalizer(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerFinalizer[any, any](func(ctx context.Context, code int, _ *http.Request) { + if want, have := statusCode, code; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + + responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header) + if want, have := headerVal, responseHeader.Get(headerKey); want != have { + t.Errorf("%s: want %q, have %q", headerKey, want, have) + } + + responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64) + if want, have := int64(len(responseBody)), responseSize; want != have { + t.Errorf("response size: want %d, have %d", want, have) + } + + close(done) + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +type enhancedResponse struct { + Foo string `json:"foo"` +} + +func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } +func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } + +func TestEncodeJSONResponse(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { + t.Errorf("X-Edward: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +type multiHeaderResponse struct{} + +func (_ multiHeaderResponse) Headers() http.Header { + return http.Header{"Vary": []string{"Origin", "User-Agent"}} +} + +func TestAddMultipleHeaders(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} + for k, vls := range resp.Header { + for _, v := range vls { + delete((expect[k]), v) + } + if len(expect[k]) != 0 { + t.Errorf("Header: unexpected header %s: %v", k, expect[k]) + } + } +} + +type multiHeaderResponseError struct { + multiHeaderResponse + msg string +} + +func (m multiHeaderResponseError) Error() string { + return m.msg +} + +func TestAddMultipleHeadersErrorEncoder(t *testing.T) { + errStr := "oh no" + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { + return nil, multiHeaderResponseError{msg: errStr} + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} + for k, vls := range resp.Header { + for _, v := range vls { + delete((expect[k]), v) + } + if len(expect[k]) != 0 { + t.Errorf("Header: unexpected header %s: %v", k, expect[k]) + } + } + if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) { + t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr) + } +} + +type noContentResponse struct{} + +func (e noContentResponse) StatusCode() int { return http.StatusNoContent } + +func TestEncodeNoContent(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusNoContent, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := 0, len(buf); want != have { + t.Errorf("Body: want no content, have %d bytes", have) + } +} + +type enhancedError struct{} + +func (e enhancedError) Error() string { return "enhanced error" } +func (e enhancedError) StatusCode() int { return http.StatusTeapot } +func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } +func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } + +func TestEnhancedError(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { + t.Errorf("X-Enhanced: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +func TestNoOpRequestDecoder(t *testing.T) { + resw := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Error("Failed to create request") + } + handler := httptransport.NewServer( + func(ctx context.Context, request interface{}) (interface{}, error) { + if request != nil { + t.Error("Expected nil request in endpoint when using NopRequestDecoder") + } + return nil, nil + }, + httptransport.NopRequestDecoder, + httptransport.EncodeJSONResponse, + ) + handler.ServeHTTP(resw, req) + if resw.Code != http.StatusOK { + t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code) + } +} + +func testServer(t *testing.T) (step func(), resp <-chan *http.Response) { + var ( + stepch = make(chan bool) + endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } + response = make(chan *http.Response) + handler = httptransport.NewServer[any, any]( + endpoint, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { return ctx }), + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }), + ) + ) + go func() { + server := httptest.NewServer(handler) + defer server.Close() + resp, err := http.Get(server.URL) + if err != nil { + t.Error(err) + return + } + response <- resp + }() + return func() { stepch <- true }, response +} From ae0890296dfc0b910469b013ffd53a504e850929 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Wed, 29 Sep 2021 00:34:48 +0200 Subject: [PATCH 4/5] ci: add ci config for gotip Signed-off-by: Mark Sagi-Kazar --- .github/workflows/tip.yaml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 .github/workflows/tip.yaml diff --git a/.github/workflows/tip.yaml b/.github/workflows/tip.yaml new file mode 100644 index 000000000..3affedb59 --- /dev/null +++ b/.github/workflows/tip.yaml @@ -0,0 +1,24 @@ +name: Tip + +on: + push: + branches: + - generics + pull_request: + +jobs: + build: + name: Build + runs-on: ubuntu-latest + + steps: + - name: Set up gotip + run: | + go install golang.org/dl/gotip@latest + gotip download + + - name: Checkout code + uses: actions/checkout@v2 + + - name: Run tests + run: gotip test -v ./endpoint/ ./transport/http2/ From 20efbde45640596e34d137ecd0f6a768b60c3f48 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Sun, 9 Jan 2022 02:39:46 +0100 Subject: [PATCH 5/5] eliminate endpoint layer Signed-off-by: Mark Sagi-Kazar --- .github/workflows/tip.yaml | 3 + transport/http3/doc.go | 2 + transport/http3/encode_decode.go | 36 ++ transport/http3/endpoint.go | 9 + transport/http3/example_test.go | 36 ++ transport/http3/request_response_funcs.go | 133 ++++++ .../http3/request_response_funcs_test.go | 30 ++ transport/http3/server.go | 243 ++++++++++ transport/http3/server_test.go | 445 ++++++++++++++++++ 9 files changed, 937 insertions(+) create mode 100644 transport/http3/doc.go create mode 100644 transport/http3/encode_decode.go create mode 100644 transport/http3/endpoint.go create mode 100644 transport/http3/example_test.go create mode 100644 transport/http3/request_response_funcs.go create mode 100644 transport/http3/request_response_funcs_test.go create mode 100644 transport/http3/server.go create mode 100644 transport/http3/server_test.go diff --git a/.github/workflows/tip.yaml b/.github/workflows/tip.yaml index 3affedb59..0a10465c0 100644 --- a/.github/workflows/tip.yaml +++ b/.github/workflows/tip.yaml @@ -22,3 +22,6 @@ jobs: - name: Run tests run: gotip test -v ./endpoint/ ./transport/http2/ + + - name: Run tests + run: gotip test -v ./endpoint/ ./transport/http3/ diff --git a/transport/http3/doc.go b/transport/http3/doc.go new file mode 100644 index 000000000..e64010358 --- /dev/null +++ b/transport/http3/doc.go @@ -0,0 +1,2 @@ +// Package http provides a general purpose HTTP binding for endpoints. +package http diff --git a/transport/http3/encode_decode.go b/transport/http3/encode_decode.go new file mode 100644 index 000000000..dc1423fe7 --- /dev/null +++ b/transport/http3/encode_decode.go @@ -0,0 +1,36 @@ +package http + +import ( + "context" + "net/http" +) + +// DecodeRequestFunc extracts a user-domain request object from an HTTP +// request object. It's designed to be used in HTTP servers, for server-side +// endpoints. One straightforward DecodeRequestFunc could be something that +// JSON decodes from the request body to the concrete request type. +type DecodeRequestFunc[Req any] func(context.Context, *http.Request) (request Req, err error) + +// EncodeRequestFunc encodes the passed request object into the HTTP request +// object. It's designed to be used in HTTP clients, for client-side +// endpoints. One straightforward EncodeRequestFunc could be something that JSON +// encodes the object directly to the request body. +type EncodeRequestFunc func(context.Context, *http.Request, interface{}) error + +// CreateRequestFunc creates an outgoing HTTP request based on the passed +// request object. It's designed to be used in HTTP clients, for client-side +// endpoints. It's a more powerful version of EncodeRequestFunc, and can be used +// if more fine-grained control of the HTTP request is required. +type CreateRequestFunc func(context.Context, interface{}) (*http.Request, error) + +// EncodeResponseFunc encodes the passed response object to the HTTP response +// writer. It's designed to be used in HTTP servers, for server-side +// endpoints. One straightforward EncodeResponseFunc could be something that +// JSON encodes the object directly to the response body. +type EncodeResponseFunc[Resp any] func(context.Context, http.ResponseWriter, Resp) error + +// DecodeResponseFunc extracts a user-domain response object from an HTTP +// response object. It's designed to be used in HTTP clients, for client-side +// endpoints. One straightforward DecodeResponseFunc could be something that +// JSON decodes from the response body to the concrete response type. +type DecodeResponseFunc func(context.Context, *http.Response) (response interface{}, err error) diff --git a/transport/http3/endpoint.go b/transport/http3/endpoint.go new file mode 100644 index 000000000..80b0552f3 --- /dev/null +++ b/transport/http3/endpoint.go @@ -0,0 +1,9 @@ +package http + +import ( + "context" +) + +// Endpoint is the fundamental building block of servers and clients. +// It represents a single RPC method. +type Endpoint[Req any, Resp any] func(ctx context.Context, request Req) (response Resp, err error) diff --git a/transport/http3/example_test.go b/transport/http3/example_test.go new file mode 100644 index 000000000..65397c3be --- /dev/null +++ b/transport/http3/example_test.go @@ -0,0 +1,36 @@ +package http + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" +) + +func ExamplePopulateRequestContext() { + handler := NewServer[any, any]( + func(ctx context.Context, request interface{}) (response interface{}, err error) { + fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string)) + fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string)) + fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string)) + fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string)) + return struct{}{}, nil + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ServerBefore[any, any](PopulateRequestContext), + ) + + server := httptest.NewServer(handler) + defer server.Close() + + req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil) + req.Header.Set("X-Request-Id", "a1b2c3d4e5") + http.DefaultClient.Do(req) + + // Output: + // Method PATCH + // RequestPath /search + // RequestURI /search?q=sympatico + // X-Request-ID a1b2c3d4e5 +} diff --git a/transport/http3/request_response_funcs.go b/transport/http3/request_response_funcs.go new file mode 100644 index 000000000..8f92b3bc7 --- /dev/null +++ b/transport/http3/request_response_funcs.go @@ -0,0 +1,133 @@ +package http + +import ( + "context" + "net/http" +) + +// RequestFunc may take information from an HTTP request and put it into a +// request context. In Servers, RequestFuncs are executed prior to invoking the +// endpoint. In Clients, RequestFuncs are executed after creating the request +// but prior to invoking the HTTP client. +type RequestFunc func(context.Context, *http.Request) context.Context + +// ServerResponseFunc may take information from a request context and use it to +// manipulate a ResponseWriter. ServerResponseFuncs are only executed in +// servers, after invoking the endpoint but prior to writing a response. +type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Context + +// ClientResponseFunc may take information from an HTTP request and make the +// response available for consumption. ClientResponseFuncs are only executed in +// clients, after a request has been made, but prior to it being decoded. +type ClientResponseFunc func(context.Context, *http.Response) context.Context + +// SetContentType returns a ServerResponseFunc that sets the Content-Type header +// to the provided value. +func SetContentType(contentType string) ServerResponseFunc { + return SetResponseHeader("Content-Type", contentType) +} + +// SetResponseHeader returns a ServerResponseFunc that sets the given header. +func SetResponseHeader(key, val string) ServerResponseFunc { + return func(ctx context.Context, w http.ResponseWriter) context.Context { + w.Header().Set(key, val) + return ctx + } +} + +// SetRequestHeader returns a RequestFunc that sets the given header. +func SetRequestHeader(key, val string) RequestFunc { + return func(ctx context.Context, r *http.Request) context.Context { + r.Header.Set(key, val) + return ctx + } +} + +// PopulateRequestContext is a RequestFunc that populates several values into +// the context from the HTTP request. Those values may be extracted using the +// corresponding ContextKey type in this package. +func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context { + for k, v := range map[contextKey]string{ + ContextKeyRequestMethod: r.Method, + ContextKeyRequestURI: r.RequestURI, + ContextKeyRequestPath: r.URL.Path, + ContextKeyRequestProto: r.Proto, + ContextKeyRequestHost: r.Host, + ContextKeyRequestRemoteAddr: r.RemoteAddr, + ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"), + ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"), + ContextKeyRequestAuthorization: r.Header.Get("Authorization"), + ContextKeyRequestReferer: r.Header.Get("Referer"), + ContextKeyRequestUserAgent: r.Header.Get("User-Agent"), + ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"), + ContextKeyRequestAccept: r.Header.Get("Accept"), + } { + ctx = context.WithValue(ctx, k, v) + } + return ctx +} + +type contextKey int + +const ( + // ContextKeyRequestMethod is populated in the context by + // PopulateRequestContext. Its value is r.Method. + ContextKeyRequestMethod contextKey = iota + + // ContextKeyRequestURI is populated in the context by + // PopulateRequestContext. Its value is r.RequestURI. + ContextKeyRequestURI + + // ContextKeyRequestPath is populated in the context by + // PopulateRequestContext. Its value is r.URL.Path. + ContextKeyRequestPath + + // ContextKeyRequestProto is populated in the context by + // PopulateRequestContext. Its value is r.Proto. + ContextKeyRequestProto + + // ContextKeyRequestHost is populated in the context by + // PopulateRequestContext. Its value is r.Host. + ContextKeyRequestHost + + // ContextKeyRequestRemoteAddr is populated in the context by + // PopulateRequestContext. Its value is r.RemoteAddr. + ContextKeyRequestRemoteAddr + + // ContextKeyRequestXForwardedFor is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For"). + ContextKeyRequestXForwardedFor + + // ContextKeyRequestXForwardedProto is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto"). + ContextKeyRequestXForwardedProto + + // ContextKeyRequestAuthorization is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Authorization"). + ContextKeyRequestAuthorization + + // ContextKeyRequestReferer is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Referer"). + ContextKeyRequestReferer + + // ContextKeyRequestUserAgent is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("User-Agent"). + ContextKeyRequestUserAgent + + // ContextKeyRequestXRequestID is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). + ContextKeyRequestXRequestID + + // ContextKeyRequestAccept is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Accept"). + ContextKeyRequestAccept + + // ContextKeyResponseHeaders is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type http.Header, and + // is captured only once the entire response has been written. + ContextKeyResponseHeaders + + // ContextKeyResponseSize is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type int64. + ContextKeyResponseSize +) diff --git a/transport/http3/request_response_funcs_test.go b/transport/http3/request_response_funcs_test.go new file mode 100644 index 000000000..a6cfbfde7 --- /dev/null +++ b/transport/http3/request_response_funcs_test.go @@ -0,0 +1,30 @@ +package http_test + +import ( + "context" + "net/http/httptest" + "testing" + + httptransport "github.com/go-kit/kit/transport/http2" +) + +func TestSetHeader(t *testing.T) { + const ( + key = "X-Foo" + val = "12345" + ) + r := httptest.NewRecorder() + httptransport.SetResponseHeader(key, val)(context.Background(), r) + if want, have := val, r.Header().Get(key); want != have { + t.Errorf("want %q, have %q", want, have) + } +} + +func TestSetContentType(t *testing.T) { + const contentType = "application/json" + r := httptest.NewRecorder() + httptransport.SetContentType(contentType)(context.Background(), r) + if want, have := contentType, r.Header().Get("Content-Type"); want != have { + t.Errorf("want %q, have %q", want, have) + } +} diff --git a/transport/http3/server.go b/transport/http3/server.go new file mode 100644 index 000000000..a0cccf526 --- /dev/null +++ b/transport/http3/server.go @@ -0,0 +1,243 @@ +package http + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/go-kit/kit/transport" + "github.com/go-kit/log" +) + +// Server wraps an endpoint and implements http.Handler. +type Server[Req any, Resp any] struct { + e Endpoint[Req, Resp] + dec DecodeRequestFunc[Req] + enc EncodeResponseFunc[Resp] + before []RequestFunc + after []ServerResponseFunc + errorEncoder ErrorEncoder + finalizer []ServerFinalizerFunc + errorHandler transport.ErrorHandler +} + +// NewServer constructs a new server, which implements http.Handler and wraps +// the provided endpoint. +func NewServer[Req any, Resp any]( + e Endpoint[Req, Resp], + dec DecodeRequestFunc[Req], + enc EncodeResponseFunc[Resp], + options ...ServerOption[Req, Resp], +) *Server[Req, Resp] { + s := &Server[Req, Resp]{ + e: e, + dec: dec, + enc: enc, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, option := range options { + option(s) + } + return s +} + +// ServerOption sets an optional parameter for servers. +type ServerOption[Req any, Resp any] func(*Server[Req, Resp]) + +// ServerBefore functions are executed on the HTTP request object before the +// request is decoded. +func ServerBefore[Req any, Resp any](before ...RequestFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.before = append(s.before, before...) } +} + +// ServerAfter functions are executed on the HTTP response writer after the +// endpoint is invoked, but before anything is written to the client. +func ServerAfter[Req any, Resp any](after ...ServerResponseFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.after = append(s.after, after...) } +} + +// ServerErrorEncoder is used to encode errors to the http.ResponseWriter +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting and response codes. By default, +// errors will be written with the DefaultErrorEncoder. +func ServerErrorEncoder[Req any, Resp any](ee ErrorEncoder) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorEncoder = ee } +} + +// ServerErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. +// Deprecated: Use ServerErrorHandler instead. +func ServerErrorLogger[Req any, Resp any](logger log.Logger) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorHandler = transport.NewLogErrorHandler(logger) } +} + +// ServerErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. +func ServerErrorHandler[Req any, Resp any](errorHandler transport.ErrorHandler) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.errorHandler = errorHandler } +} + +// ServerFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ServerFinalizer[Req any, Resp any](f ...ServerFinalizerFunc) ServerOption[Req, Resp] { + return func(s *Server[Req, Resp]) { s.finalizer = append(s.finalizer, f...) } +} + +// ServeHTTP implements http.Handler. +func (s Server[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if len(s.finalizer) > 0 { + iw := &interceptingWriter{w, http.StatusOK, 0} + defer func() { + ctx = context.WithValue(ctx, ContextKeyResponseHeaders, iw.Header()) + ctx = context.WithValue(ctx, ContextKeyResponseSize, iw.written) + for _, f := range s.finalizer { + f(ctx, iw.code, r) + } + }() + w = iw + } + + for _, f := range s.before { + ctx = f(ctx, r) + } + + request, err := s.dec(ctx, r) + if err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } + + response, err := s.e(ctx, request) + if err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } + + for _, f := range s.after { + ctx = f(ctx, w) + } + + if err := s.enc(ctx, w, response); err != nil { + s.errorHandler.Handle(ctx, err) + s.errorEncoder(ctx, err, w) + return + } +} + +// ErrorEncoder is responsible for encoding an error to the ResponseWriter. +// Users are encouraged to use custom ErrorEncoders to encode HTTP errors to +// their clients, and will likely want to pass and check for their own error +// types. See the example shipping/handling service. +type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) + +// ServerFinalizerFunc can be used to perform work at the end of an HTTP +// request, after the response has been written to the client. The principal +// intended use is for request logging. In addition to the response code +// provided in the function signature, additional response parameters are +// provided in the context under keys with the ContextKeyResponse prefix. +type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) + +// NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not +// need to be decoded, and simply returns nil, nil. +func NopRequestDecoder(ctx context.Context, r *http.Request) (interface{}, error) { + return nil, nil +} + +// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a +// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as +// a sensible default. If the response implements Headerer, the provided headers +// will be applied to the response. If the response implements StatusCoder, the +// provided StatusCode will be used instead of 200. +func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if headerer, ok := response.(Headerer); ok { + for k, values := range headerer.Headers() { + for _, v := range values { + w.Header().Add(k, v) + } + } + } + code := http.StatusOK + if sc, ok := response.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + if code == http.StatusNoContent { + return nil + } + return json.NewEncoder(w).Encode(response) +} + +// DefaultErrorEncoder writes the error to the ResponseWriter, by default a +// content type of text/plain, a body of the plain text of the error, and a +// status code of 500. If the error implements Headerer, the provided headers +// will be applied to the response. If the error implements json.Marshaler, and +// the marshaling succeeds, a content type of application/json and the JSON +// encoded form of the error will be used. If the error implements StatusCoder, +// the provided StatusCode will be used instead of 500. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { + contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) + if marshaler, ok := err.(json.Marshaler); ok { + if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { + contentType, body = "application/json; charset=utf-8", jsonBody + } + } + w.Header().Set("Content-Type", contentType) + if headerer, ok := err.(Headerer); ok { + for k, values := range headerer.Headers() { + for _, v := range values { + w.Header().Add(k, v) + } + } + } + code := http.StatusInternalServerError + if sc, ok := err.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + w.Write(body) +} + +// StatusCoder is checked by DefaultErrorEncoder. If an error value implements +// StatusCoder, the StatusCode will be used when encoding the error. By default, +// StatusInternalServerError (500) is used. +type StatusCoder interface { + StatusCode() int +} + +// Headerer is checked by DefaultErrorEncoder. If an error value implements +// Headerer, the provided headers will be applied to the response writer, after +// the Content-Type is set. +type Headerer interface { + Headers() http.Header +} + +type interceptingWriter struct { + http.ResponseWriter + code int + written int64 +} + +// WriteHeader may not be explicitly called, so care must be taken to +// initialize w.code to its default value of http.StatusOK. +func (w *interceptingWriter) WriteHeader(code int) { + w.code = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *interceptingWriter) Write(p []byte) (int, error) { + n, err := w.ResponseWriter.Write(p) + w.written += int64(n) + return n, err +} diff --git a/transport/http3/server_test.go b/transport/http3/server_test.go new file mode 100644 index 000000000..097391ebe --- /dev/null +++ b/transport/http3/server_test.go @@ -0,0 +1,445 @@ +package http_test + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + httptransport "github.com/go-kit/kit/transport/http3" +) + +// nopEndpoint is an endpoint that does nothing and returns a nil error. +// Useful for tests. +func nopEndpoint(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + +func TestServerBadDecode(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerBadEndpoint(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerBadEncode(t *testing.T) { + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") }, + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerErrorEncoder(t *testing.T) { + errTeapot := errors.New("teapot") + code := func(err error) int { + if errors.Is(err, errTeapot) { + return http.StatusTeapot + } + return http.StatusInternalServerError + } + handler := httptransport.NewServer[any, any]( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + httptransport.ServerErrorEncoder[any, any](func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestServerHappyPath(t *testing.T) { + step, response := testServer(t) + step() + resp := <-response + defer resp.Body.Close() + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d (%s)", want, have, buf) + } +} + +func TestMultipleServerBefore(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + nopEndpoint, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerBefores are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestMultipleServerAfter(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + nopEndpoint, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerAfters are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestServerFinalizer(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer[any, any]( + nopEndpoint, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerFinalizer[any, any](func(ctx context.Context, code int, _ *http.Request) { + if want, have := statusCode, code; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + + responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header) + if want, have := headerVal, responseHeader.Get(headerKey); want != have { + t.Errorf("%s: want %q, have %q", headerKey, want, have) + } + + responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64) + if want, have := int64(len(responseBody)), responseSize; want != have { + t.Errorf("response size: want %d, have %d", want, have) + } + + close(done) + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +type enhancedResponse struct { + Foo string `json:"foo"` +} + +func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } +func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } + +func TestEncodeJSONResponse(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { + t.Errorf("X-Edward: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +type multiHeaderResponse struct{} + +func (_ multiHeaderResponse) Headers() http.Header { + return http.Header{"Vary": []string{"Origin", "User-Agent"}} +} + +func TestAddMultipleHeaders(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} + for k, vls := range resp.Header { + for _, v := range vls { + delete((expect[k]), v) + } + if len(expect[k]) != 0 { + t.Errorf("Header: unexpected header %s: %v", k, expect[k]) + } + } +} + +type multiHeaderResponseError struct { + multiHeaderResponse + msg string +} + +func (m multiHeaderResponseError) Error() string { + return m.msg +} + +func TestAddMultipleHeadersErrorEncoder(t *testing.T) { + errStr := "oh no" + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { + return nil, multiHeaderResponseError{msg: errStr} + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} + for k, vls := range resp.Header { + for _, v := range vls { + delete((expect[k]), v) + } + if len(expect[k]) != 0 { + t.Errorf("Header: unexpected header %s: %v", k, expect[k]) + } + } + if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) { + t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr) + } +} + +type noContentResponse struct{} + +func (e noContentResponse) StatusCode() int { return http.StatusNoContent } + +func TestEncodeNoContent(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusNoContent, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := 0, len(buf); want != have { + t.Errorf("Body: want no content, have %d bytes", have) + } +} + +type enhancedError struct{} + +func (e enhancedError) Error() string { return "enhanced error" } +func (e enhancedError) StatusCode() int { return http.StatusTeapot } +func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } +func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } + +func TestEnhancedError(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { + t.Errorf("X-Enhanced: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +func TestNoOpRequestDecoder(t *testing.T) { + resw := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Error("Failed to create request") + } + handler := httptransport.NewServer( + func(ctx context.Context, request interface{}) (interface{}, error) { + if request != nil { + t.Error("Expected nil request in endpoint when using NopRequestDecoder") + } + return nil, nil + }, + httptransport.NopRequestDecoder, + httptransport.EncodeJSONResponse, + ) + handler.ServeHTTP(resw, req) + if resw.Code != http.StatusOK { + t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code) + } +} + +func testServer(t *testing.T) (step func(), resp <-chan *http.Response) { + var ( + stepch = make(chan bool) + endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } + response = make(chan *http.Response) + handler = httptransport.NewServer[any, any]( + endpoint, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + httptransport.ServerBefore[any, any](func(ctx context.Context, r *http.Request) context.Context { return ctx }), + httptransport.ServerAfter[any, any](func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }), + ) + ) + go func() { + server := httptest.NewServer(handler) + defer server.Close() + resp, err := http.Get(server.URL) + if err != nil { + t.Error(err) + return + } + response <- resp + }() + return func() { stepch <- true }, response +}