Skip to content

Commit

Permalink
pkg/server/api: Return 405s for non-GET requests
Browse files Browse the repository at this point in the history
We don't want folks POSTing, etc. to this server.  If they do, this
commit will return the appropriate response status code.

Also add support for HEAD requests and start setting Content-Length
and Content-Type ourselves instead of leaning on Go's autodetection
[1].  The application/json media type is from [2].

[1]: https://golang.org/pkg/net/http/#ResponseWriter
[2]: https://tools.ietf.org/html/rfc8259#section-1.2
  • Loading branch information
wking committed Dec 18, 2018
1 parent a5d0713 commit 01d22e1
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 24 deletions.
28 changes: 25 additions & 3 deletions pkg/server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ func NewServerAPIHandler(s Server) *APIHandler {
// ServeHTTP handles the requests for the machine config server
// API handler.
func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

if r.URL.Path == "" {
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusBadRequest)
return
}
Expand All @@ -92,18 +98,34 @@ func (sh *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

conf, err := sh.server.GetConfig(cr)
if err != nil {
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusInternalServerError)
glog.Errorf("couldn't get config for req: %v, error: %v", cr, err)
return
}
if conf == nil && err == nil {
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusNotFound)
return
}

encoder := json.NewEncoder(w)
if err := encoder.Encode(conf); err != nil {
data, err := json.Marshal(conf)
if err != nil {
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusInternalServerError)
glog.Errorf("couldn't encode the config for req: %v, error: %v", cr, err)
glog.Errorf("failed to marshal %v config: %v", cr, err)
return
}

w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.Header().Set("Content-Type", "application/json")
if r.Method == http.MethodHead {
w.WriteHeader(http.StatusOK)
return
}

_, err = w.Write(data)
if err != nil {
glog.Errorf("failed to write %v response: %v", cr, err)
}
}
127 changes: 106 additions & 21 deletions pkg/server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -17,50 +18,134 @@ func (ms *mockServer) GetConfig(pr poolRequest) (*ignv2_2types.Config, error) {
return ms.GetConfigFn(pr)
}

type checkResponse func(t *testing.T, response *http.Response)

type scenario struct {
name string
expectedStatus int
serverFunc func(poolRequest) (*ignv2_2types.Config, error)
name string
request *http.Request
serverFunc func(poolRequest) (*ignv2_2types.Config, error)
checkResponse checkResponse
}

func TestAPIHandler(t *testing.T) {
scenarios := []scenario{
{
name: "not-found",
expectedStatus: http.StatusNotFound,
name: "get non-config path that does not exist",
request: httptest.NewRequest(http.MethodGet, "http://testrequest/does-not-exist", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return nil, nil
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusNotFound)
checkContentLength(t, response, 0)
checkBodyLength(t, response, 0)
},
},
{
name: "internal-server",
expectedStatus: http.StatusInternalServerError,
name: "get config path that does not exist",
request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/does-not-exist", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return new(ignv2_2types.Config), fmt.Errorf("not acceptable")
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusInternalServerError)
checkContentLength(t, response, 0)
checkBodyLength(t, response, 0)
},
},
{
name: "get config path that exists",
request: httptest.NewRequest(http.MethodGet, "http://testrequest/config/master", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return new(ignv2_2types.Config), nil
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusOK)
checkContentType(t, response, "application/json")
checkContentLength(t, response, 114)
checkBodyLength(t, response, 114)
},
},
{
name: "head config path that exists",
request: httptest.NewRequest(http.MethodHead, "http://testrequest/config/master", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return new(ignv2_2types.Config), nil
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusOK)
checkContentType(t, response, "application/json")
checkContentLength(t, response, 114)
checkBodyLength(t, response, 0)
},
},
{
name: "post non-config path that does not exist",
request: httptest.NewRequest(http.MethodPost, "http://testrequest/post", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return nil, nil
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusMethodNotAllowed)
checkContentLength(t, response, 0)
checkBodyLength(t, response, 0)
},
},
{
name: "success",
expectedStatus: http.StatusOK,
name: "post config path that exists",
request: httptest.NewRequest(http.MethodPost, "http://testrequest/config/master", nil),
serverFunc: func(poolRequest) (*ignv2_2types.Config, error) {
return new(ignv2_2types.Config), nil
},
checkResponse: func(t *testing.T, response *http.Response) {
checkStatus(t, response, http.StatusMethodNotAllowed)
checkContentLength(t, response, 0)
checkBodyLength(t, response, 0)
},
},
}

for i := range scenarios {
req := httptest.NewRequest("GET", "http://testrequest/", nil)
w := httptest.NewRecorder()
ms := &mockServer{
GetConfigFn: scenarios[i].serverFunc,
}
handler := NewServerAPIHandler(ms)
handler.ServeHTTP(w, req)
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
w := httptest.NewRecorder()
ms := &mockServer{
GetConfigFn: scenario.serverFunc,
}
handler := NewServerAPIHandler(ms)
handler.ServeHTTP(w, scenario.request)

resp := w.Result()
resp := w.Result()
defer resp.Body.Close()
scenario.checkResponse(t, resp)
})
}
}

if resp.StatusCode != scenarios[i].expectedStatus {
t.Errorf("API Handler test failed for: %s, expected: %d, received: %d", scenarios[i].name, scenarios[i].expectedStatus, resp.StatusCode)
}
func checkStatus(t *testing.T, response *http.Response, expected int) {
if response.StatusCode != expected {
t.Errorf("expected response status %d, received %d", expected, response.StatusCode)
}
}

func checkContentType(t *testing.T, response *http.Response, expected string) {
actual := response.Header.Get("Content-Type")
if actual != expected {
t.Errorf("expected response Content-Type %q, received %q", expected, actual)
}
}

func checkContentLength(t *testing.T, response *http.Response, l int) {
if int(response.ContentLength) != l {
t.Errorf("expected response Content-Length %d, received %d", l, int(response.ContentLength))
}
}

func checkBodyLength(t *testing.T, response *http.Response, l int) {
body, err := ioutil.ReadAll(response.Body)
if err != nil {
t.Fatal(err)
}
if len(body) != l {
t.Errorf("expected response body length %d, received %d", l, len(body))
}
}

0 comments on commit 01d22e1

Please sign in to comment.