diff --git a/pkg/server/api.go b/pkg/server/api.go index 444a9a540d..542ffaa321 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -80,8 +80,14 @@ func NewServerAPIHandler(s Server) *APIHandler { // ServeHTTP handles the requests for the machine config server // API handler. func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusMethodNotAllowed) + return + } if r.URL.Path == "" { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusBadRequest) return } @@ -92,18 +98,34 @@ func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { conf, err := sh.server.GetConfig(cr) if err != nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusInternalServerError) glog.Errorf("couldn't get config for req: %v, error: %v", cr, err) return } if conf == nil && err == nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusNotFound) return } - encoder := json.NewEncoder(w) - if err := encoder.Encode(conf); err != nil { + data, err := json.Marshal(conf) + if err != nil { + w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusInternalServerError) - glog.Errorf("couldn't encode the config for req: %v, error: %v", cr, err) + glog.Errorf("failed to marshal %v config: %v", cr, err) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + return + } + + _, err = w.Write(data) + if err != nil { + glog.Errorf("failed to write %v response: %v", cr, err) } } diff --git a/pkg/server/api_test.go b/pkg/server/api_test.go index d603b7de9f..8b632bf297 100644 --- a/pkg/server/api_test.go +++ b/pkg/server/api_test.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -17,50 +18,134 @@ func (ms *mockServer) GetConfig(pr poolRequest) (*ignv2_2types.Config, error) { return ms.GetConfigFn(pr) } +type checkResponse func(t *testing.T, response *http.Response) + type scenario struct { - name string - expectedStatus int - serverFunc func(poolRequest) (*ignv2_2types.Config, error) + name string + request *http.Request + serverFunc func(poolRequest) (*ignv2_2types.Config, error) + checkResponse checkResponse } func TestAPIHandler(t *testing.T) { scenarios := []scenario{ { - name: "not-found", - expectedStatus: http.StatusNotFound, + name: "get non-config path that does not exist", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/does-not-exist", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return nil, nil }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusNotFound) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, { - name: "internal-server", - expectedStatus: http.StatusInternalServerError, + name: "get config path that does not exist", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/does-not-exist", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return new(ignv2_2types.Config), fmt.Errorf("not acceptable") }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusInternalServerError) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, + }, + { + name: "get config path that exists", + request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/master", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return new(ignv2_2types.Config), nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusOK) + checkContentType(t, response, "application/json") + checkContentLength(t, response, 114) + checkBodyLength(t, response, 114) + }, + }, + { + name: "head config path that exists", + request: httptest.NewRequest(http.MethodHead, "http://testrequest/config/master", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return new(ignv2_2types.Config), nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusOK) + checkContentType(t, response, "application/json") + checkContentLength(t, response, 114) + checkBodyLength(t, response, 0) + }, + }, + { + name: "post non-config path that does not exist", + request: httptest.NewRequest(http.MethodPost, "http://testrequest/post", nil), + serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { + return nil, nil + }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusMethodNotAllowed) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, { - name: "success", - expectedStatus: http.StatusOK, + name: "post config path that exists", + request: httptest.NewRequest(http.MethodPost, "http://testrequest/config/master", nil), serverFunc: func(poolRequest) (*ignv2_2types.Config, error) { return new(ignv2_2types.Config), nil }, + checkResponse: func(t *testing.T, response *http.Response) { + checkStatus(t, response, http.StatusMethodNotAllowed) + checkContentLength(t, response, 0) + checkBodyLength(t, response, 0) + }, }, } - for i := range scenarios { - req := httptest.NewRequest("GET", "http://testrequest/", nil) - w := httptest.NewRecorder() - ms := &mockServer{ - GetConfigFn: scenarios[i].serverFunc, - } - handler := NewServerAPIHandler(ms) - handler.ServeHTTP(w, req) + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + w := httptest.NewRecorder() + ms := &mockServer{ + GetConfigFn: scenario.serverFunc, + } + handler := NewServerAPIHandler(ms) + handler.ServeHTTP(w, scenario.request) - resp := w.Result() + resp := w.Result() + defer resp.Body.Close() + scenario.checkResponse(t, resp) + }) + } +} - if resp.StatusCode != scenarios[i].expectedStatus { - t.Errorf("API Handler test failed for: %s, expected: %d, received: %d", scenarios[i].name, scenarios[i].expectedStatus, resp.StatusCode) - } +func checkStatus(t *testing.T, response *http.Response, expected int) { + if response.StatusCode != expected { + t.Errorf("expected response status %d, received %d", expected, response.StatusCode) + } +} + +func checkContentType(t *testing.T, response *http.Response, expected string) { + actual := response.Header.Get("Content-Type") + if actual != expected { + t.Errorf("expected response Content-Type %q, received %q", expected, actual) + } +} + +func checkContentLength(t *testing.T, response *http.Response, l int) { + if int(response.ContentLength) != l { + t.Errorf("expected response Content-Length %d, received %d", l, int(response.ContentLength)) + } +} + +func checkBodyLength(t *testing.T, response *http.Response, l int) { + body, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if len(body) != l { + t.Errorf("expected response body length %d, received %d", l, len(body)) } }