Skip to content

Commit

Permalink
httptransport: exit goroutine in error helper
Browse files Browse the repository at this point in the history
This change makes the `apiError` helper slightly tricksy in return for
making it impossible to misuse. Callers don't have to return after
calling `apiError;` instead the call just never returns.

Signed-off-by: Hank Donnay <hdonnay@redhat.com>
  • Loading branch information
hdonnay committed Oct 20, 2023
1 parent 41cda1f commit bddbc57
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 58 deletions.
1 change: 0 additions & 1 deletion httptransport/concurrentlimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ func (l *limitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Msg("rate limited HTTP request")

apiError(ctx, w, http.StatusTooManyRequests, "server handling too many requests")
return
}
defer sem.Release(1)
}
Expand Down
3 changes: 0 additions & 3 deletions httptransport/discoveryhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,13 @@ func DiscoveryHandler() http.Handler {
ctx := r.Context()
if r.Method != http.MethodGet {
apiError(ctx, w, http.StatusMethodNotAllowed, "endpoint only allows GET")
return
}
switch err := pickContentType(w, r, allow); {
case errors.Is(err, nil):
case errors.Is(err, ErrMediaType):
apiError(ctx, w, http.StatusUnsupportedMediaType, "unable to negotiate common media type for %v", allow)
return
default:
apiError(ctx, w, http.StatusInternalServerError, "unexpected error: %v", err)
return
}
w.Header().Set("etag", openapiJSONEtag)
var err error
Expand Down
15 changes: 11 additions & 4 deletions httptransport/discoveryhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,20 @@ func TestDiscoveryFailure(t *testing.T) {
h := DiscoveryHandler()

r := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/openapi/v1", nil).WithContext(ctx)
req.Header.Set("Accept", "application/yaml")
h.ServeHTTP(r, req)
// Needed because handlers exit the goroutine.
done := make(chan struct{})
go func() {
defer close(done)
req := httptest.NewRequest("GET", "/openapi/v1", nil).WithContext(ctx)
req.Header.Set("Accept", "application/yaml")
h.ServeHTTP(r, req)
}()
<-done

resp := r.Result()
t.Log(resp.Status)
if got, want := resp.StatusCode, http.StatusUnsupportedMediaType; got != want {
t.Fatalf("got status code: %v want status code: %v", got, want)
t.Errorf("got status code: %v want status code: %v", got, want)
}
}

Expand Down
4 changes: 4 additions & 0 deletions httptransport/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ import (
"encoding/json"
"fmt"
"net/http"
"runtime"

"github.com/quay/zlog"
)

// ApiError writes an untyped (that is, "application/json") error with the
// provided HTTP status code and message.
//
// ApiError does not return, but instead causes the goroutine to exit.
func apiError(ctx context.Context, w http.ResponseWriter, code int, f string, v ...interface{}) {
const errheader = `Clair-Error`
h := w.Header()
Expand Down Expand Up @@ -55,4 +58,5 @@ func apiError(ctx context.Context, w http.ResponseWriter, code int, f string, v
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
runtime.Goexit()
}
25 changes: 0 additions & 25 deletions httptransport/indexer_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ func (h *IndexerV1) indexReport(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete:
default:
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}
defer r.Body.Close()
dec := codec.GetDecoder(r.Body)
Expand All @@ -89,16 +88,13 @@ func (h *IndexerV1) indexReport(w http.ResponseWriter, r *http.Request) {
state, err := h.srv.State(ctx)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not retrieve indexer state: %v", err)
return
}
var m claircore.Manifest
if err := dec.Decode(&m); err != nil {
apiError(ctx, w, http.StatusBadRequest, "failed to deserialize manifest: %v", err)
return
}
if m.Hash.String() == "" || len(m.Layers) == 0 {
apiError(ctx, w, http.StatusBadRequest, "bogus manifest")
return
}
next := path.Join(r.URL.Path, m.Hash.String())

Expand All @@ -117,10 +113,8 @@ func (h *IndexerV1) indexReport(w http.ResponseWriter, r *http.Request) {
case errors.Is(err, nil):
case errors.Is(err, tarfs.ErrFormat):
apiError(ctx, w, http.StatusBadRequest, "failed to start scan: %v", err)
return
default:
apiError(ctx, w, http.StatusInternalServerError, "failed to start scan: %v", err)
return
}

w.Header().Set("etag", validator)
Expand All @@ -134,12 +128,10 @@ func (h *IndexerV1) indexReport(w http.ResponseWriter, r *http.Request) {
var ds []claircore.Digest
if err := dec.Decode(&ds); err != nil {
apiError(ctx, w, http.StatusBadRequest, "failed to deserialize bulk delete: %v", err)
return
}
ds, err := h.srv.DeleteManifests(ctx, ds...)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not delete manifests: %v", err)
return
}
zlog.Debug(ctx).
Int("count", len(ds)).
Expand All @@ -164,12 +156,10 @@ func (h *IndexerV1) indexReportOne(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete:
default:
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}
d, err := getDigest(w, r)
if err != nil {
apiError(ctx, w, http.StatusBadRequest, "malformed path: %v", err)
return
}
switch r.Method {
case http.MethodGet:
Expand All @@ -178,16 +168,13 @@ func (h *IndexerV1) indexReportOne(w http.ResponseWriter, r *http.Request) {
case errors.Is(err, nil): // OK
case errors.Is(err, ErrMediaType):
apiError(ctx, w, http.StatusUnsupportedMediaType, "unable to negotiate common media type for %v", allow)
return
default:
apiError(ctx, w, http.StatusBadRequest, "malformed request: %v", err)
return
}

state, err := h.srv.State(ctx)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not retrieve indexer state: %v", err)
return
}
validator := `"` + state + `"`
if unmodified(r, validator) {
Expand All @@ -198,11 +185,9 @@ func (h *IndexerV1) indexReportOne(w http.ResponseWriter, r *http.Request) {
report, ok, err := h.srv.IndexReport(ctx, d)
if !ok {
apiError(ctx, w, http.StatusNotFound, "index report not found")
return
}
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not retrieve index report: %v", err)
return
}

w.Header().Add("etag", validator)
Expand All @@ -213,7 +198,6 @@ func (h *IndexerV1) indexReportOne(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete:
if _, err := h.srv.DeleteManifests(ctx, d); err != nil {
apiError(ctx, w, http.StatusInternalServerError, "unable to delete manifest: %v", err)
return
}
w.WriteHeader(http.StatusNoContent)
}
Expand All @@ -223,22 +207,18 @@ func (h *IndexerV1) indexState(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.Method != http.MethodGet {
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}
allow := []string{"application/vnd.clair.indexstate.v1+json", "application/json"}
switch err := pickContentType(w, r, allow); {
case errors.Is(err, nil): // OK
case errors.Is(err, ErrMediaType):
apiError(ctx, w, http.StatusUnsupportedMediaType, "unable to negotiate common media type for %v", allow)
return
default:
apiError(ctx, w, http.StatusBadRequest, "malformed request: %v", err)
return
}
s, err := h.srv.State(ctx)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not retrieve indexer state: %v", err)
return
}

tag := `"` + s + `"`
Expand All @@ -264,17 +244,14 @@ func (h *IndexerV1) affectedManifests(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.Method != http.MethodPost {
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}
allow := []string{"application/vnd.clair.affectedmanifests.v1+json", "application/json"}
switch err := pickContentType(w, r, allow); {
case errors.Is(err, nil): // OK
case errors.Is(err, ErrMediaType):
apiError(ctx, w, http.StatusUnsupportedMediaType, "unable to negotiate common media type for %v", allow)
return
default:
apiError(ctx, w, http.StatusBadRequest, "malformed request: %v", err)
return
}

var vulnerabilities struct {
Expand All @@ -284,13 +261,11 @@ func (h *IndexerV1) affectedManifests(w http.ResponseWriter, r *http.Request) {
defer codec.PutDecoder(dec)
if err := dec.Decode(&vulnerabilities); err != nil {
apiError(ctx, w, http.StatusBadRequest, "failed to deserialize vulnerabilities: %v", err)
return
}

affected, err := h.srv.AffectedManifests(ctx, vulnerabilities.V)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not retrieve affected manifests: %v", err)
return
}

defer writerError(w, &err)
Expand Down
18 changes: 0 additions & 18 deletions httptransport/matcher_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ func (h *MatcherV1) vulnerabilityReport(w http.ResponseWriter, r *http.Request)

if r.Method != http.MethodGet {
apiError(ctx, w, http.StatusMethodNotAllowed, "endpoint only allows GET")
return
}
ctx, done := context.WithCancel(ctx)
defer done()
Expand All @@ -95,18 +94,15 @@ func (h *MatcherV1) vulnerabilityReport(w http.ResponseWriter, r *http.Request)
manifestStr := path.Base(r.URL.Path)
if manifestStr == "" {
apiError(ctx, w, http.StatusBadRequest, "malformed path. provide a single manifest hash")
return
}
manifest, err := claircore.ParseDigest(manifestStr)
if err != nil {
apiError(ctx, w, http.StatusBadRequest, "malformed path: %v", err)
return
}

initd, err := h.srv.Initialized(ctx)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, err.Error())
return
}
if !initd {
w.WriteHeader(http.StatusAccepted)
Expand All @@ -117,18 +113,15 @@ func (h *MatcherV1) vulnerabilityReport(w http.ResponseWriter, r *http.Request)
// check err first
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "experienced a server side error: %v", err)
return
}
// now check bool only after confirming no err
if !ok {
apiError(ctx, w, http.StatusNotFound, "index report for manifest %q not found", manifest.String())
return
}

vulnReport, err := h.srv.Scan(ctx, indexReport)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "failed to start scan: %v", err)
return
}

w.Header().Set("content-type", "application/json")
Expand All @@ -146,7 +139,6 @@ func (h *MatcherV1) updateDiffHandler(w http.ResponseWriter, r *http.Request) {

if r.Method != http.MethodGet {
apiError(ctx, w, http.StatusMethodNotAllowed, "endpoint only allows GET")
return
}
// prev param is optional.
var prev uuid.UUID
Expand All @@ -155,7 +147,6 @@ func (h *MatcherV1) updateDiffHandler(w http.ResponseWriter, r *http.Request) {
prev, err = uuid.Parse(param)
if err != nil {
apiError(ctx, w, http.StatusBadRequest, "could not parse \"prev\" query param into uuid")
return
}
}

Expand All @@ -164,17 +155,14 @@ func (h *MatcherV1) updateDiffHandler(w http.ResponseWriter, r *http.Request) {
var param string
if param = r.URL.Query().Get("cur"); param == "" {
apiError(ctx, w, http.StatusBadRequest, "\"cur\" query param is required")
return
}
if cur, err = uuid.Parse(param); err != nil {
apiError(ctx, w, http.StatusBadRequest, "could not parse \"cur\" query param into uuid")
return
}

diff, err := h.srv.UpdateDiff(ctx, prev, cur)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not get update operations: %v", err)
return
}

defer writerError(w, &err)()
Expand All @@ -191,7 +179,6 @@ func (h *MatcherV1) updateOperationHandlerGet(w http.ResponseWriter, r *http.Req
case http.MethodGet:
default:
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}

kind := driver.VulnerabilityKind
Expand All @@ -202,7 +189,6 @@ func (h *MatcherV1) updateOperationHandlerGet(w http.ResponseWriter, r *http.Req
// Leave as default
default:
apiError(ctx, w, http.StatusBadRequest, "unknown kind: %q", k)
return
}

// handle conditional request. this is an optimization
Expand All @@ -226,7 +212,6 @@ func (h *MatcherV1) updateOperationHandlerGet(w http.ResponseWriter, r *http.Req
}
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not get update operations: %v", err)
return
}

defer writerError(w, &err)()
Expand All @@ -242,7 +227,6 @@ func (h *MatcherV1) updateOperationHandlerDelete(w http.ResponseWriter, r *http.
case http.MethodDelete:
default:
apiError(ctx, w, http.StatusMethodNotAllowed, "method disallowed: %s", r.Method)
return
}

path := r.URL.Path
Expand All @@ -251,13 +235,11 @@ func (h *MatcherV1) updateOperationHandlerDelete(w http.ResponseWriter, r *http.
if err != nil {
zlog.Warn(ctx).Err(err).Msg("could not deserialize manifest")
apiError(ctx, w, http.StatusBadRequest, "could not deserialize manifest: %v", err)
return
}

_, err = h.srv.DeleteUpdateOperations(ctx, uuid)
if err != nil {
apiError(ctx, w, http.StatusInternalServerError, "could not get update operations: %v", err)
return
}
}

Expand Down
Loading

0 comments on commit bddbc57

Please sign in to comment.