diff --git a/Gopkg.lock b/Gopkg.lock index 55cc0c3..2ce37a4 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -15,7 +15,10 @@ [[projects]] name = "github.com/sirupsen/logrus" - packages = ["."] + packages = [ + ".", + "hooks/test" + ] revision = "c155da19408a8799da419ed3eeb0cb5db0ad5dbc" version = "v1.0.5" @@ -67,6 +70,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "4f6c6ebefba028f9350159e530ad9c3e7654005e0361f692eb04ec04af6bc7ec" + inputs-digest = "d21a7336cc2f9c23a2eea8182be8faaafbe1b1cf445e8807777eb352eb32ddf4" solver-name = "gps-cdcl" solver-version = 1 diff --git a/server/handler.go b/server/handler.go index 9484c12..4103948 100644 --- a/server/handler.go +++ b/server/handler.go @@ -8,8 +8,36 @@ import ( "strings" "sync" "time" + + "github.com/sirupsen/logrus" ) +func serviceHandler(l *logrus.Logger, sr ServiceRegistry, validators []Validator) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := r.URL.Path[6:] + if name == "" { + l.Info("no list name defined") + w.WriteHeader(400) + return + } + + if !sr.HasServiceForList(name) { + l.Infof("list '%s' not defined", name) + w.WriteHeader(404) + return + } + + svc := sr.GetServiceForList(name) + hf := createRequestHandler( + l, + svc, + validators, + ) + + hf(w, r) + } +} + func defaultHeaderHandler(h http.Handler) http.HandlerFunc { type kv struct { @@ -51,6 +79,8 @@ func createRequestIDHandler(h http.Handler) http.HandlerFunc { lock.Unlock() ctx := context.WithValue(r.Context(), CtxRequestID, requestID) + w.Header().Set(HeaderRequestID, requestID) + h.ServeHTTP(w, r.WithContext(ctx)) } } diff --git a/server/handler_test.go b/server/handler_test.go new file mode 100644 index 0000000..ed72098 --- /dev/null +++ b/server/handler_test.go @@ -0,0 +1,108 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "bytes" + "encoding/json" + + "github.com/sirupsen/logrus/hooks/test" +) + +func TestServiceHandlerHappyFlow(t *testing.T) { + logger, _ := test.NewNullLogger() + recorder := httptest.NewRecorder() + + { // setup + payload, _ := json.Marshal(tySugRequest{Input: "baz"}) + req := httptest.NewRequest(http.MethodPost, "/list/foo", bytes.NewReader(payload)) + sr := NewServiceRegistry() + sr.Register("foo", stubSvc{FindResult: "bar"}) + + hf := http.HandlerFunc(serviceHandler(logger, sr, nil)) + hf.ServeHTTP(recorder, req) + } + + t.Run("test status code", func(t *testing.T) { + if status := recorder.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + }) + + t.Run("test response result", func(t *testing.T) { + expect := tySugResponse{Result: "bar"} + var result tySugResponse + err := json.Unmarshal(recorder.Body.Bytes(), &result) + + if err != nil { + t.Errorf("unexpected error while unmarshalling the response type %s", err) + } + + if result.Result != expect.Result { + t.Errorf("expected the input to be %s, instead I got %s", expect.Result, result.Result) + } + }) +} + +func TestServiceHandlerInvalidListName(t *testing.T) { + logger, _ := test.NewNullLogger() + + t.Run("status code, nonexistent list name", func(t *testing.T) { + recorder := httptest.NewRecorder() + { // setup + payload, _ := json.Marshal(tySugRequest{Input: "baz"}) + req := httptest.NewRequest(http.MethodPost, "/list/not-existing", bytes.NewReader(payload)) + + sr := NewServiceRegistry() + hf := http.HandlerFunc(serviceHandler(logger, sr, nil)) + hf.ServeHTTP(recorder, req) + } + + if status := recorder.Code; status != http.StatusNotFound { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusNotFound) + } + }) + + t.Run("status code, unspecified list name", func(t *testing.T) { + recorder := httptest.NewRecorder() + { // setup + payload, _ := json.Marshal(tySugRequest{Input: "baz"}) + req := httptest.NewRequest(http.MethodPost, "/list/", bytes.NewReader(payload)) + + sr := NewServiceRegistry() + hf := http.HandlerFunc(serviceHandler(logger, sr, nil)) + hf.ServeHTTP(recorder, req) + } + + if status := recorder.Code; status != http.StatusBadRequest { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusBadRequest) + } + }) +} + +func TestRequestID(t *testing.T) { + recorder1 := httptest.NewRecorder() + recorder2 := httptest.NewRecorder() + + { // setup + req := httptest.NewRequest(http.MethodPost, "/", nil) + h := createRequestIDHandler(http.NewServeMux()) + hf := http.HandlerFunc(h) + + // Recording two requests + hf.ServeHTTP(recorder1, req) + hf.ServeHTTP(recorder2, req) + } + + rid1 := recorder1.Result().Header.Get(HeaderRequestID) + rid2 := recorder2.Result().Header.Get(HeaderRequestID) + if rid1 == rid2 { + t.Errorf("Did not expect the request ID's to be identical: %s vs %s", rid1, rid2) + } + + if rid1 == "" { + t.Errorf("Did not expect the request ID to be empty.") + } +} diff --git a/server/http.go b/server/http.go index 52d6585..6e2cc0b 100644 --- a/server/http.go +++ b/server/http.go @@ -34,6 +34,11 @@ const ( CtxRequestID contextKey = iota ) +// Header constants +const ( + HeaderRequestID = "X-Request-ID" +) + type tySugRequest struct { Input string `json:"input"` } @@ -85,29 +90,7 @@ func NewHTTP(sr ServiceRegistry, mux *http.ServeMux, options ...Option) TySugSer } mux.HandleFunc("/", http.NotFound) - mux.HandleFunc("/list/", func(w http.ResponseWriter, r *http.Request) { - name := r.URL.Path[6:] - if name == "" { - tySug.Logger.Info("no list name defined") - w.WriteHeader(400) - return - } - - if !sr.HasServiceForList(name) { - tySug.Logger.Infof("list '%s' not defined", name) - w.WriteHeader(404) - return - } - - svc := sr.GetServiceForList(name) - hf := createRequestHandler( - tySug.Logger, - svc, - tySug.validators, - ) - - hf(w, r) - }) + mux.HandleFunc("/list/", serviceHandler(tySug.Logger, sr, tySug.validators)) tySug.server = &http.Server{ ReadHeaderTimeout: 2 * time.Second, @@ -217,6 +200,11 @@ func getRequestFromHTTPRequest(r *http.Request) (tySugRequest, error) { var req tySugRequest var maxSizePlusOne int64 = maxBodySize + 1 + + if r.Body == nil { + return req, ErrMissingBody + } + b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxSizePlusOne)) if err != nil { if err == io.EOF { diff --git a/server/option.go b/server/option.go index 503e09f..d2bce62 100644 --- a/server/option.go +++ b/server/option.go @@ -5,6 +5,8 @@ import ( "fmt" + "errors" + "github.com/NYTimes/gziphandler" "github.com/rs/cors" "github.com/sirupsen/logrus" @@ -48,6 +50,10 @@ func WithInputLimitValidator(inputMax int) Option { return fmt.Errorf("WithInputLimitValidator input exceeds server specified maximum of %d bytes", inputMax) } + if len(TSRequest.Input) == 0 { + return errors.New("WithInputLimitValidator input may not be empty") + } + return nil }) } diff --git a/server/service_registry_test.go b/server/service_registry_test.go index dce2a6f..e42f0db 100644 --- a/server/service_registry_test.go +++ b/server/service_registry_test.go @@ -6,10 +6,13 @@ import ( ) type stubSvc struct { + FindResult string + FindScore float64 + FindExact bool } -func (stubSvc) Find(ctx context.Context, input string) (string, float64, bool) { - return "", 0, true +func (svc stubSvc) Find(ctx context.Context, input string) (string, float64, bool) { + return svc.FindResult, svc.FindScore, svc.FindExact } func TestHasServiceForList(t *testing.T) {