Skip to content

Commit

Permalink
apiclient: fix http roundtrip (clone body also) (#1758)
Browse files Browse the repository at this point in the history
* apiclient: fix http roundtrip (clone body also)
  • Loading branch information
he2ss authored Dec 14, 2022
1 parent fe23da6 commit 579cecd
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
50 changes: 38 additions & 12 deletions pkg/apiclient/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
type JWTTransport struct {
MachineID *string
Password *strfmt.Password
token string
Token string
Expiration time.Time
Scenarios []string
URL *url.URL
Expand All @@ -88,6 +88,7 @@ type JWTTransport struct {
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
UpdateScenario func() ([]string, error)
NbRetry int
}

func (t *JWTTransport) refreshJwtToken() error {
Expand Down Expand Up @@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error {
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
return errors.Wrap(err, "unable to parse jwt expiration")
}
t.token = response.Token
t.Token = response.Token

log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
return nil
}

// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
if t.NbRetry > 1 {
t.NbRetry = 0
return nil, fmt.Errorf("unable to refresh JWT token multiple times")
}
if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
if err := t.refreshJwtToken(); err != nil {
return nil, err
}
}

if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

// We must make a copy of the Request so
// that we don't modify the Request we were given. This is required by the
// specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
clonedReq := cloneRequest(req)

req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))

if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}
if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized {
if err != nil {
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
t.token = ""
t.Token = ""
return resp, errors.Wrapf(err, "performing jwt auth")
}

if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
t.Token = ""
t.NbRetry++
return t.RoundTrip(clonedReq)
}

t.NbRetry = 0

log.Debugf("resp-jwt: %d", resp.StatusCode)

return resp, nil
}

Expand All @@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request {
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}

if r.Body != nil {
var b bytes.Buffer
b.ReadFrom(r.Body)
r.Body = io.NopCloser(&b)
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
}
return r2
}
2 changes: 1 addition & 1 deletion pkg/apiclient/auth_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
}

_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
assert.Contains(t, err.Error(), "the attachment key provided is not valid")
assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
}
1 change: 1 addition & 0 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) {
UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
NbRetry: 0,
}
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
Expand Down

0 comments on commit 579cecd

Please sign in to comment.