Skip to content

Commit

Permalink
Roundtrip should only return err from base().RoundTrip
Browse files Browse the repository at this point in the history
Update some log output
Remove unused test case from transport_test.go
  • Loading branch information
TaoZou1 committed Feb 7, 2022
1 parent 545bdb6 commit 7997527
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 133 deletions.
2 changes: 1 addition & 1 deletion pkg/nsx/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (ep *Endpoint) keepAlive() error {
return err
}
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
err = util.InitErrorFromResponse(ep.Host(), resp)
err = util.InitErrorFromResponse(ep.Host(), resp.StatusCode, body)

if util.ShouldRegenerate(err) {
log.Error(err, "failed to validate API cluster due to an exception that calls for regeneration", "endpoint", ep.Host())
Expand Down
28 changes: 20 additions & 8 deletions pkg/nsx/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package nsx

import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -33,7 +35,7 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
var resp *http.Response
var resul error

err1 := retry.Do(
retry.Do(
func() error {
ep, err := t.selectEndpoint()
if err != nil {
Expand All @@ -54,14 +56,24 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
}
transTime := time.Since(start) - waitTime
ep.adjustRate(waitTime, resp.StatusCode)
log.V(1).Info("RoundTrip got response", "request", r.URL, "method", r.Method, "transTime", transTime)
if err = util.InitErrorFromResponse(ep.Host(), resp); err == nil {
ep.setAliveTime(start.Add(transTime))
log.V(1).Info("RoundTrip request", "request", r.URL, "method", r.Method, "transTime", transTime)
if resp == nil {
return nil
}
log.V(1).Info("request failed", "error", err.Error())

// refresh token here
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
resp.Body = ioutil.NopCloser(bytes.NewReader(body))

if err != nil {
log.Error(err, "failed to extract HTTP body")
return util.CreateGeneralManagerError(ep.Host(), "extract http", err.Error())
}

if err = util.InitErrorFromResponse(ep.Host(), resp.StatusCode, body); err == nil {
ep.setAliveTime(start.Add(transTime))
return nil
}
if util.ShouldRegenerate(err) {
ep.createAuthSession(t.config.ClientCertProvider, t.config.TokenProvider, t.config.Username, t.config.Password, jarCache)
}
Expand All @@ -78,7 +90,7 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
}), retry.LastErrorOnly(true),
)

return resp, err1
return resp, resul
}

func handleRoundTripError(err error, ep *Endpoint) error {
Expand Down Expand Up @@ -117,7 +129,7 @@ func (t *Transport) updateAuthInfo(r *http.Request, ep *Endpoint) {
ep.Unlock()
for _, cookie := range cookies {
if cookie == nil {
log.Error(errors.New("Cookie is nil."), "Update authentication info failed")
log.Error(errors.New("cookie is nil."), "Update authentication info failed")
}
r.Header.Set("Cookie", cookie.String())
}
Expand Down
83 changes: 1 addition & 82 deletions pkg/nsx/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,92 +13,13 @@ import (

"github.com/stretchr/testify/assert"
"github.com/vmware-tanzu/nsx-operator/pkg/nsx/ratelimiter"
"github.com/vmware-tanzu/nsx-operator/pkg/nsx/util"
)

var (
timeout = time.Duration(20)
idleConnTimeout = time.Duration(20)
)

func TestRoundTripConnectionRefused(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello")
}))
defer ts.Close()
a := "127.0.0.1"
config := NewConfig(a, "admin", "passw0rd", "", 10, 3, 20, 20, true, true, true, ratelimiter.AIMD, nil, nil, []string{})
cluster := &Cluster{}
tr := cluster.createTransport(config.TokenProvider, idleConnTimeout)
client := cluster.createHTTPClient(tr, idleConnTimeout)
noBClient := cluster.createNoBalancerClient(timeout, idleConnTimeout)
r := ratelimiter.NewRateLimiter(config.APIRateMode)
eps, _ := cluster.createEndpoints(config.APIManagers, &client, &noBClient, r, nil)
eps[0].status = UP
tr.endpoints = eps
req, err := http.NewRequest("GET", ts.URL, nil)
resp, err := tr.RoundTrip(req)
assert.NotNil(t, err, "Should report error")
_, ok := err.(*util.ServiceClusterUnavailable)
assert.True(t, ok, fmt.Sprintf("Return wrong error type %v", err))
assert.Nil(t, resp, "Resp should be nil")
}

func TestRoundTripDecodeBodyFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello")
}))
defer ts.Close()
index := strings.Index(ts.URL, "//")
a := ts.URL[index+2:]
config := NewConfig(a, "admin", "passw0rd", "", 10, 3, 20, 20, true, true, true, ratelimiter.AIMD, nil, nil, []string{})
cluster := &Cluster{}
tr := cluster.createTransport(config.TokenProvider, timeout)
client := cluster.createHTTPClient(tr, 30)
noBClient := cluster.createNoBalancerClient(30, 20)
r := ratelimiter.NewRateLimiter(config.APIRateMode)
eps, _ := cluster.createEndpoints(config.APIManagers, &client, &noBClient, r, nil)
eps[0].status = UP
tr.endpoints = eps
req, err := http.NewRequest("GET", ts.URL, nil)
_, err = tr.RoundTrip(req)
_, ok := err.(util.ManagerError)
assert.True(t, ok, fmt.Sprintf("Return wrong error type %v", err))
}

func TestRoundTripAuthFailed(t *testing.T) {
assert := assert.New(t)
result := `{"module_name":"common-services","error_message":"The credentials were incorrect or the account specified has been locked","error_code":403}`
healthresult := `{
"healthy" : true,
"components_health" : "POLICY:UP, SEARCH:UP, MANAGER:UP, NODE_MGMT:UP, UI:UP"
}`
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Index(r.URL.Path, "reverse-proxy/node/health") > 1 {
w.WriteHeader(http.StatusOK)
w.Write([]byte(healthresult))
} else {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(result))
}
}))
defer ts.Close()
index := strings.Index(ts.URL, "//")
a := ts.URL[index+2:]
config := NewConfig(a, "admin", "passw0rd", "", 10, 3, 20, 20, true, true, true, ratelimiter.AIMD, nil, nil, []string{})
cluster, err := NewCluster(config)
assert.Nil(err, fmt.Sprintf("Create cluster error %v", err))
cluster.endpoints[0], _ = NewEndpoint(ts.URL, &cluster.client, &cluster.noBalancerClient, cluster.endpoints[0].ratelimiter, nil)
cluster.endpoints[0].keepAlive()
tr := cluster.transport
req, err := http.NewRequest("GET", ts.URL, nil)
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
_, err = tr.RoundTrip(req)
_, ok := err.(*util.InvalidCredentials)
assert.True(ok, fmt.Sprintf("Wrong error type %v", err))
}

func TestRoundTripRetry(t *testing.T) {
assert := assert.New(t)
result := `{"module_name":"common-services","error_message":"The credentials were incorrect or the account specified has been locked","error_code":98}`
Expand Down Expand Up @@ -128,9 +49,7 @@ func TestRoundTripRetry(t *testing.T) {
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
_, err = tr.RoundTrip(req)
log.V(1).Info("", "errorType", err)
_, ok := err.(*util.CannotConnectToServer)
assert.True(ok, fmt.Sprintf("Wrong error type %v", err))
assert.Equal(err, nil)
}

func TestSelectEndpoint(t *testing.T) {
Expand Down
24 changes: 4 additions & 20 deletions pkg/nsx/util/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
package util

import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
Expand Down Expand Up @@ -98,11 +96,8 @@ func ShouldRegenerate(err error) bool {
}

// InitErrorFromResponse returns error based on http.Response
func InitErrorFromResponse(host string, resp *http.Response) error {
if resp == nil {
return nil
}
detail, err := extractHTTPDetail(host, resp)
func InitErrorFromResponse(host string, statusCode int, body []byte) error {
detail, err := extractHTTPDetailFromBody(host, statusCode, body)
if err != nil {
return err
}
Expand All @@ -112,22 +107,10 @@ func InitErrorFromResponse(host string, resp *http.Response) error {
return httpErrortoNSXError(&detail)
}

func extractHTTPDetail(host string, resp *http.Response) (ErrorDetail, error) {
ed := ErrorDetail{StatusCode: resp.StatusCode}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Error(err, "failed to extract HTTP detail")
return ed, CreateGeneralManagerError(host, "extract http", err.Error())
}
//TODO, log some fields of response
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
return extractHTTPDetailFromBody(host, resp.StatusCode, body)
}

func extractHTTPDetailFromBody(host string, statusCode int, body []byte) (ErrorDetail, error) {
ec := ErrorDetail{StatusCode: statusCode}
if len(body) == 0 {
log.V(1).Info("aborting HTTP detail extraction since body len is 0")
log.V(1).Info("body length is 0")
return ec, nil
}
var res responseBody
Expand All @@ -137,6 +120,7 @@ func extractHTTPDetailFromBody(host string, statusCode int, body []byte) (ErrorD
}

ec.ErrorCode = res.ErrorCode
log.V(1).Info("http response", "status code", statusCode, "body", res)
var msg []string
for _, a := range res.RelatedErr {
ec.RelatedErrorCodes = append(ec.RelatedErrorCodes, a.ErrorCode)
Expand Down
24 changes: 2 additions & 22 deletions pkg/nsx/util/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package util
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -46,24 +45,6 @@ func TestHttpErrortoNSXError(t *testing.T) {

}

func TestExtractHTTPDetail(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "hello, world")
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
_, err := extractHTTPDetail("10.0.0.1", resp)
if err != nil {
if _, ok := err.(ManagerError); !ok {
t.Errorf("Extract wrong error type : %v", err)
}
}
assert.NotNil(t, err, "Extract wrong error type")

}

func TestInitErrorFromResponse(t *testing.T) {
assert := assert.New(t)
result := `{
Expand Down Expand Up @@ -100,9 +81,8 @@ func TestInitErrorFromResponse(t *testing.T) {
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
err := InitErrorFromResponse("10.0.0.1", resp)
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
err := InitErrorFromResponse("10.0.0.1", resp.StatusCode, body)

assert.Equal(err, nil, "Read resp body error")
assert.Equal(string(body), result, "Read resp body error")
Expand Down

0 comments on commit 7997527

Please sign in to comment.