From 579cecde04df7a552d754f99a0c2d687983e1022 Mon Sep 17 00:00:00 2001 From: he2ss Date: Wed, 14 Dec 2022 16:42:46 +0100 Subject: [PATCH] apiclient: fix http roundtrip (clone body also) (#1758) * apiclient: fix http roundtrip (clone body also) --- pkg/apiclient/auth.go | 50 +++++++++++++++++++++++------- pkg/apiclient/auth_service_test.go | 2 +- pkg/apiclient/client.go | 1 + 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 6898ffe2adf..747fa76d3da 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper { type JWTTransport struct { MachineID *string Password *strfmt.Password - token string + Token string Expiration time.Time Scenarios []string URL *url.URL @@ -88,6 +88,7 @@ type JWTTransport struct { // It will default to http.DefaultTransport if nil. Transport http.RoundTripper UpdateScenario func() ([]string, error) + NbRetry int } func (t *JWTTransport) refreshJwtToken() error { @@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error { if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { return errors.Wrap(err, "unable to parse jwt expiration") } - t.token = response.Token + t.Token = response.Token - log.Debugf("token %s will expire on %s", t.token, t.Expiration.String()) + log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) return nil } // RoundTrip implements the RoundTripper interface. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) { + if t.NbRetry > 1 { + t.NbRetry = 0 + return nil, fmt.Errorf("unable to refresh JWT token multiple times") + } + if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) { if err := t.refreshJwtToken(); err != nil { return nil, err } } + if t.UserAgent != "" { + req.Header.Add("User-Agent", t.UserAgent) + } + // We must make a copy of the Request so // that we don't modify the Request we were given. This is required by the // specification of http.RoundTripper. - req = cloneRequest(req) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token)) - log.Debugf("req-jwt: %s %s", req.Method, req.URL.String()) + clonedReq := cloneRequest(req) + + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + if log.GetLevel() >= log.TraceLevel { + //requestToDump := cloneRequest(req) dump, _ := httputil.DumpRequest(req, true) log.Tracef("req-jwt: %s", string(dump)) } - if t.UserAgent != "" { - req.Header.Add("User-Agent", t.UserAgent) - } + // Make the HTTP request. resp, err := t.transport().RoundTrip(req) if log.GetLevel() >= log.TraceLevel { dump, _ := httputil.DumpResponse(resp, true) log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) } - if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized { + if err != nil { /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ - t.token = "" + t.Token = "" return resp, errors.Wrapf(err, "performing jwt auth") } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + t.Token = "" + t.NbRetry++ + return t.RoundTrip(clonedReq) + } + + t.NbRetry = 0 + log.Debugf("resp-jwt: %d", resp.StatusCode) + return resp, nil } @@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request { for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } + + if r.Body != nil { + var b bytes.Buffer + b.ReadFrom(r.Body) + r.Body = io.NopCloser(&b) + r2.Body = io.NopCloser(bytes.NewReader(b.Bytes())) + } return r2 } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index 1eb5138e3de..f671868a82e 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) { } _, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false) - assert.Contains(t, err.Error(), "the attachment key provided is not valid") + assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error()) } diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index f6cc738947b..3229d29b5a5 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) { UserAgent: config.UserAgent, VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, + NbRetry: 0, } tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} if Cert != nil {