From 01d22e13f20df7cf68be8d5fb8608d6abee059ca Mon Sep 17 00:00:00 2001 From: "W. Trevor King" Date: Sat, 15 Dec 2018 23:56:11 -0800 Subject: [PATCH] pkg/server/api: Return 405s for non-GET requests We don't want folks POSTing, etc. to this server. If they do, this commit will return the appropriate response status code. Also add support for HEAD requests and start setting Content-Length and Content-Type ourselves instead of leaning on Go's autodetection [1]. The application/json media type is from [2]. [1]: https://golang.org/pkg/net/http/#ResponseWriter [2]: https://tools.ietf.org/html/rfc8259#section-1.2 --- pkg/server/api.go | 28 ++++++++- pkg/server/api_test.go | 127 ++++++++++++++++++++++++++++++++++------- 2 files changed, 131 insertions(+), 24 deletions(-) diff --git a/pkg/server/api.go b/pkg/server/api.go index 444a9a540d..542ffaa321 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -80,8 +80,14 @@ func NewServerAPIHandler(s Server) *APIHandler { // ServeHTTP handles the requests for the machine config server // API handler. func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusMethodNotAllowed) + return + } if r.URL.Path == "" { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusBadRequest) return } @@ -92,18 +98,34 @@ func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { conf, err := sh.server.GetConfig(cr) if err != nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusInternalServerError) glog.Errorf("couldn't get config for req: %v, error: %v", cr, err) return } if conf == nil && err == nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusNotFound) return } - encoder := json.NewEncoder(w) - if err := encoder.Encode(conf); err != nil { + data, err := json.Marshal(conf) + if err != nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusInternalServerError) - glog.Errorf("couldn't encode the config for req: %v, error: %v", cr, err) + glog.Errorf("failed to marshal %v config: %v", cr, err) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + return + } + + _, err = w.Write(data) + if err != nil { + glog.Errorf("failed to write %v response: %v", cr, err) } } diff --git a/pkg/server/api_test.go b/pkg/server/api_test.go index d603b7de9f..8b632bf297 100644 --- a/pkg/server/api_test.go +++ b/pkg/server/api_test.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -17,50 +18,134 @@ func (ms *mockServer) GetConfig(pr poolRequest) (*ignv2_2types.Config, error) { return ms.GetConfigFn(pr) } +type checkResponse func(t *testing.T, response *http.Response) + type scenario struct { - name string - expectedStatus int - serverFunc func(poolRequest) (*ignv2_2types.Config, error) + name string + request *http.Request + serverFunc func(poolRequest) (*ignv2_2types.Config, error) + checkResponse checkResponse } func TestAPIHandler(t *testing.T) { scenarios := []scenario{ { - name: "not-found", - expectedStatus: http.StatusNotFound, + name: "get non-config path that does not exist", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/does-not-exist", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return nil, nil }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusNotFound) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, { - name: "internal-server", - expectedStatus: http.StatusInternalServerError, + name: "get config path that does not exist", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/does-not-exist", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return new(ignv2_2types.Config), fmt.Errorf("not acceptable") }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusInternalServerError) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, + }, + { + name: "get config path that exists", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/master", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return new(ignv2_2types.Config), nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusOK) + checkContentType(t, response, "application/json") + checkContentLength(t, response, 114) + checkBodyLength(t, response, 114) + }, + }, + { + name: "head config path that exists", + request: httptest.NewRequest(http.MethodHead, "http://testrequest/config/master", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return new(ignv2_2types.Config), nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusOK) + checkContentType(t, response, "application/json") + checkContentLength(t, response, 114) + checkBodyLength(t, response, 0) + }, + }, + { + name: "post non-config path that does not exist", + request: httptest.NewRequest(http.MethodPost, "http://testrequest/post", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return nil, nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusMethodNotAllowed) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, { - name: "success", - expectedStatus: http.StatusOK, + name: "post config path that exists", + request: httptest.NewRequest(http.MethodPost, "http://testrequest/config/master", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return new(ignv2_2types.Config), nil }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusMethodNotAllowed) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, } - for i := range scenarios { - req := httptest.NewRequest("GET", "http://testrequest/", nil) - w := httptest.NewRecorder() - ms := &mockServer{ - GetConfigFn: scenarios[i].serverFunc, - } - handler := NewServerAPIHandler(ms) - handler.ServeHTTP(w, req) + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + w := httptest.NewRecorder() + ms := &mockServer{ + GetConfigFn: scenario.serverFunc, + } + handler := NewServerAPIHandler(ms) + handler.ServeHTTP(w, scenario.request) - resp := w.Result() + resp := w.Result() + defer resp.Body.Close() + scenario.checkResponse(t, resp) + }) + } +} - if resp.StatusCode != scenarios[i].expectedStatus { - t.Errorf("API Handler test failed for: %s, expected: %d, received: %d", scenarios[i].name, scenarios[i].expectedStatus, resp.StatusCode) - } +func checkStatus(t *testing.T, response *http.Response, expected int) { + if response.StatusCode != expected { + t.Errorf("expected response status %d, received %d", expected, response.StatusCode) + } +} + +func checkContentType(t *testing.T, response *http.Response, expected string) { + actual := response.Header.Get("Content-Type") + if actual != expected { + t.Errorf("expected response Content-Type %q, received %q", expected, actual) + } +} + +func checkContentLength(t *testing.T, response *http.Response, l int) { + if int(response.ContentLength) != l { + t.Errorf("expected response Content-Length %d, received %d", l, int(response.ContentLength)) + } +} + +func checkBodyLength(t *testing.T, response *http.Response, l int) { + body, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if len(body) != l { + t.Errorf("expected response body length %d, received %d", l, len(body)) } }