Skip to content

Commit

Permalink
refact pkg/apiclient (#2846)
Browse files Browse the repository at this point in the history
* extract resperr.go
* extract method prepareRequest()
* reset token inside mutex
  • Loading branch information
mmetc authored Feb 22, 2024
1 parent 3e3df5e commit 8da490f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 48 deletions.
37 changes: 26 additions & 11 deletions pkg/apiclient/auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,41 @@ func (t *JWTTransport) refreshJwtToken() error {
return nil
}

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
// we use a mutex to avoid this
// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
func (t *JWTTransport) needsTokenRefresh() bool {
return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())
}

// prepareRequest returns a copy of the request with the necessary authentication headers.
func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) {
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless
// and will cause overload on CAPI. We use a mutex to avoid this.
t.refreshTokenMutex.Lock()
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
if err := t.refreshJwtToken(); err != nil {
t.refreshTokenMutex.Unlock()
defer t.refreshTokenMutex.Unlock()

// We bypass the refresh if we are requesting the login endpoint, as it does not require a token,
// and it leads to do 2 requests instead of one (refresh + actual login request).
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() {
if err := t.refreshJwtToken(); err != nil {
return nil, err
}
}
t.refreshTokenMutex.Unlock()

if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))

return req, nil
}

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req, err := t.prepareRequest(req)
if err != nil {
return nil, err
}

if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
Expand All @@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {

if err != nil {
// we had an error (network error for example, or 401 because token is refused), reset the token?
t.Token = ""
t.ResetToken()

return resp, fmt.Errorf("performing jwt auth: %w", err)
}
Expand All @@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() {
t.refreshTokenMutex.Unlock()
}

// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded.
// transport() returns a round tripper that retries once when the status is unauthorized,
// and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
transport := t.Transport
if transport == nil {
Expand Down
36 changes: 0 additions & 36 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"

Expand Down Expand Up @@ -167,44 +165,10 @@ type Response struct {
//...
}

type ErrorResponse struct {
models.ErrorResponse
}

func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}

return err
}

func newResponse(r *http.Response) *Response {
return &Response{Response: r}
}

func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}

errorResponse := &ErrorResponse{}

data, err := io.ReadAll(r.Body)
if err == nil && len(data)>0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}
} else {
errorResponse.Message = new(string)
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
}

return errorResponse
}

type ListOpts struct {
//Page int
//PerPage int
Expand Down
46 changes: 46 additions & 0 deletions pkg/apiclient/resperr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package apiclient

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/crowdsecurity/go-cs-lib/ptr"

"github.com/crowdsecurity/crowdsec/pkg/models"
)

type ErrorResponse struct {
models.ErrorResponse
}

func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}

return err
}

// CheckResponse verifies the API response and builds an appropriate Go error if necessary.
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}

ret := &ErrorResponse{}

data, err := io.ReadAll(r.Body)
if err != nil || len(data) == 0 {
ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode))
return ret
}

if err := json.Unmarshal(data, ret); err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}

return ret
}
1 change: 0 additions & 1 deletion pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
scenario = *decision.Scenario
scope = types.ListOrigin
default:
// XXX: this or nil?
scenario = ""
scope = ""

Expand Down

0 comments on commit 8da490f

Please sign in to comment.