Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some refactoring in azcore #6982

Merged
merged 1 commit into from
Jan 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions sdk/azcore/policy_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,15 @@ var (
}
)

func (o RetryOptions) defaults() RetryOptions {
// We assume the following:
// 1. o.MaxTries >= 0
// 2. o.TryTimeout, o.RetryDelay, and o.MaxRetryDelay >=0
// 3. o.RetryDelay <= o.MaxRetryDelay
// 4. Both o.RetryDelay and o.MaxRetryDelay must be 0 or neither can be 0

if len(o.StatusCodes) == 0 {
o.StatusCodes = StatusCodesForRetry[:]
// DefaultRetryOptions returns an instance of RetryOptions initialized with default values.
func DefaultRetryOptions() RetryOptions {
return RetryOptions{
StatusCodes: StatusCodesForRetry[:],
MaxTries: defaultMaxTries,
TryTimeout: 1 * time.Minute,
RetryDelay: 4 * time.Second,
MaxRetryDelay: 120 * time.Second,
}

IfDefault := func(current *time.Duration, desired time.Duration) {
if *current == time.Duration(0) {
*current = desired
}
}

// Set defaults if unspecified
if o.MaxTries == 0 {
o.MaxTries = defaultMaxTries
}

IfDefault(&o.TryTimeout, 1*time.Minute)
IfDefault(&o.RetryDelay, 4*time.Second)
IfDefault(&o.MaxRetryDelay, 120*time.Second)

return o
}

func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0
Expand All @@ -108,8 +90,14 @@ func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never
}

// NewRetryPolicy creates a policy object configured using the specified options.
func NewRetryPolicy(o RetryOptions) Policy {
return &retryPolicy{options: o.defaults()} // Force defaults to be calculated
// Pass nil to accept the default values; this is the same as passing the result
// from a call to DefaultRetryOptions().
func NewRetryPolicy(o *RetryOptions) Policy {
if o == nil {
def := DefaultRetryOptions()
o = &def
}
return &retryPolicy{options: *o}
}

type retryPolicy struct {
Expand Down
22 changes: 12 additions & 10 deletions sdk/azcore/policy_retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

const retryDelay = 20 * time.Millisecond
func testRetryOptions() *RetryOptions {
def := DefaultRetryOptions()
def.RetryDelay = 20 * time.Millisecond
return &def
}

func TestRetryPolicySuccess(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{}))
pl := NewPipeline(srv, NewRetryPolicy(nil))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -46,7 +50,7 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -74,7 +78,7 @@ func TestRetryPolicySuccessWithRetry(t *testing.T) {
srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout))
srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
srv.AppendResponse()
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -101,7 +105,7 @@ func TestRetryPolicyFailOnError(t *testing.T) {
defer close()
fakeErr := errors.New("bogus error")
srv.SetError(fakeErr)
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodPost, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -130,7 +134,7 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) {
srv.AppendError(errors.New("bogus error"))
srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -156,7 +160,7 @@ func TestRetryPolicyRequestTimedOut(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetError(errors.New("bogus error"))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{}))
pl := NewPipeline(srv, NewRetryPolicy(nil))
req := NewRequest(http.MethodPost, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -195,9 +199,7 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) {
defer close()
srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout))
srv.AppendError(theErr)
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{
RetryDelay: retryDelay,
}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
_, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err == nil {
t.Fatal("unexpected nil error")
Expand Down
12 changes: 0 additions & 12 deletions sdk/azcore/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,6 @@ func (r *Response) payload() []byte {
return nil
}

// CheckStatusCode returns a RequestError if the Response's status code isn't one of the specified values.
func (r *Response) CheckStatusCode(statusCodes ...int) error {
if !r.HasStatusCode(statusCodes...) {
msg := r.Status
if len(r.payload()) > 0 {
msg = string(r.payload())
}
return newRequestError(msg, r)
}
return nil
}

// HasStatusCode returns true if the Response's status code is one of the specified values.
func (r *Response) HasStatusCode(statusCodes ...int) bool {
if r == nil {
Expand Down
27 changes: 10 additions & 17 deletions sdk/azcore/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func TestResponseUnmarshalXML(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testXML
if err := resp.UnmarshalAsXML(&tx); err != nil {
Expand All @@ -44,15 +44,8 @@ func TestResponseFailureStatusCode(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err = resp.CheckStatusCode(http.StatusOK); err == nil {
t.Fatal("unexpected nil status code error")
}
re, ok := err.(RequestError)
if !ok {
t.Fatal("expected RequestError type")
}
if re.Response().StatusCode != http.StatusForbidden {
t.Fatal("unexpected response")
if resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
}

Expand All @@ -65,8 +58,8 @@ func TestResponseUnmarshalJSON(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testJSON
if err := resp.UnmarshalAsJSON(&tx); err != nil {
Expand All @@ -86,8 +79,8 @@ func TestResponseUnmarshalJSONNoBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if err := resp.UnmarshalAsJSON(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
Expand All @@ -103,8 +96,8 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if err := resp.UnmarshalAsXML(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
Expand Down