-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] Support custom retry logic per method #1081
Changes from all commits
be93274
da7ab8c
f62a5f0
1e51e8b
6b66911
5bfa780
5be83f5
b947958
ab7ed04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,11 @@ import ( | |
"fmt" | ||
"net/http" | ||
"net/url" | ||
"regexp" | ||
"time" | ||
|
||
"github.com/databricks/databricks-sdk-go/apierr" | ||
"github.com/databricks/databricks-sdk-go/common" | ||
"github.com/databricks/databricks-sdk-go/credentials" | ||
"github.com/databricks/databricks-sdk-go/httpclient" | ||
"github.com/databricks/databricks-sdk-go/useragent" | ||
|
@@ -73,17 +75,22 @@ func (c *Config) NewApiClient() (*httpclient.ApiClient, error) { | |
return nil | ||
}, | ||
}, | ||
TransientErrors: []string{ | ||
"REQUEST_LIMIT_EXCEEDED", // This is temporary workaround for SCIM API returning 500. Remove when it's fixed | ||
}, | ||
ErrorMapper: apierr.GetAPIError, | ||
ErrorRetriable: func(ctx context.Context, err error) bool { | ||
var apiErr *apierr.APIError | ||
if errors.As(err, &apiErr) { | ||
return apiErr.IsRetriable(ctx) | ||
} | ||
return false | ||
}, | ||
ErrorRetriable: httpclient.CombineRetriers( | ||
func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { | ||
var apiErr *apierr.APIError | ||
if errors.As(err, &apiErr) { | ||
return apiErr.IsRetriable(ctx) | ||
} | ||
return false | ||
}, | ||
httpclient.RetryUrlErrors, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved this out of the ApiClient to have a single codesite where the retry logic is defined for the client. The downside is that you need to add this explicitly in your ErrorRetriable if you don't specify DefaultErrorRetriable. Happy to make this a default behavior, let me know what you think. |
||
httpclient.RetryTransientErrors([]string{"REQUEST_LIMIT_EXCEEDED"}), | ||
httpclient.RetryMatchedRequests([]httpclient.RestApiMatcher{ | ||
// Get Permissions API can be retried on 504 | ||
{Method: http.MethodGet, Path: *regexp.MustCompile(`/api/2.0/permissions/[^/]+/[^/]+`)}, | ||
}, httpclient.RetryOnGatewayTimeout), | ||
), | ||
}), nil | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package config | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/databricks/databricks-sdk-go/httpclient" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type hc func(r *http.Request) (*http.Response, error) | ||
|
||
func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { | ||
return cb(r) | ||
} | ||
|
||
func (cb hc) SkipRetryOnIO() bool { | ||
return true | ||
} | ||
|
||
func TestApiClient_RetriesGetPermissionsOnGatewayTimeout(t *testing.T) { | ||
requestCount := 0 | ||
c := &Config{ | ||
HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { | ||
initialRequestCount := requestCount | ||
requestCount++ | ||
if initialRequestCount == 0 { | ||
return &http.Response{ | ||
Request: r, | ||
StatusCode: http.StatusGatewayTimeout, | ||
Body: io.NopCloser(strings.NewReader( | ||
fmt.Sprintf(`{"error_code":"TEMPORARILY_UNAVAILABLE", "message":"The service at %s is taking too long to process your request. Please try again later or try a faster operation."}`, r.URL))), | ||
}, nil | ||
} | ||
return &http.Response{ | ||
Request: r, | ||
StatusCode: http.StatusOK, | ||
Body: io.NopCloser(strings.NewReader(`{"permissions": ["can_run_queries"]}`)), | ||
}, nil | ||
}), | ||
} | ||
client, err := c.NewApiClient() | ||
require.NoError(t, err) | ||
ctx := context.Background() | ||
var res map[string][]string | ||
err = client.Do(ctx, "GET", "/api/2.0/permissions/object/id", httpclient.WithResponseUnmarshal(&res)) | ||
assert.NoError(t, err) | ||
assert.Equal(t, map[string][]string{"permissions": {"can_run_queries"}}, res) | ||
} | ||
Comment on lines
+1
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd recommend structuring the unit tests differently by having this test focused on how This test could look like the following (I did not verify that the code works): type mock struct {
MaxFails int // number of times the failed Response is returned
FailResponse *http.Response // response to return in case of fail
FailError error // error to return in case of fail
NumCalls int // total number of calls
}
func (m *mock) RoundTrip(r *http.Request) (*http.Response, error) {
m.NumCalls++
if m.NumCalls <= m.MaxFails {
return m.FailResponse, n.FailError
}
return &http.Response{
Request: r,
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
}, nil
}
func (m *mock) SkipRetryOnIO() bool {
return true
}
func TestApiClient_Do_retries(t *testing.T) {
testCases := []struct{
desc string
config *Config
errorRetrier ErrorRetrier
wantNumCalls int
} {
{
desc: "nil retrier",
mock: &mock{
MaxFails: 1,
FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}
}
wantNumCalls: 1,
},
{
desc: "no retry",
mock: &mock{
MaxFails: 1,
FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}
}
errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool {
return false
},
wantNumCalls: 1,
},
{
desc: "retry 1 time",
mock: &mock{
MaxFails: 1,
FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}
}
errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool {
return true
},
wantNumCalls: 2,
},
{
desc: "retry 2 times",
mock: &mock{
MaxFails: 2,
FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}
}
errorRetrier: func(_ context.Context, _ *http.Request, _ *common.ResponseWrapper, _ error) bool {
return true
},
wantNumCalls: 3,
},
{
desc: "retry 3 times",
mock: &mock{
MaxFails: 3,
FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}
}
errorRetrier: func(_ context.Context, _ *http.Request, _ *common.ResponseWrapper, _ error) bool {
return true
},
wantNumCalls: 4,
},
}
func _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cfg := &Config{HTTPTransport: tc.mock}
client, err := cfg.NewApiClient()
client.ErrorRetrier = tc.errorRetrier
err = client.Do(context.Background(), "GET", "test-path")
gotNumCalls = tc.mock.NumCalls
if gotNumCalls != tc.wantNumCalls {
t.Errorf("got %d calls, want %d", gotNumCalls, tc.wantNumCalls)
}
})
}
} Please feel free to ignore this comment if this is too much work or if the ApiClient cannot be instrumented that easily. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It definitely can be instrumented this way, and this is a nice test case to use (I'll adapt it and include it in this PR). However, I did want to specifically test the get permissions pathway. Essentially, this tests that "the client returned by Config.GetApiClient() correctly implements retry on 504." I will add more test cases here though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Sounds good to me as long as this complements the overall testing of the retry logic. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,6 @@ import ( | |
"net/http" | ||
"net/url" | ||
"runtime" | ||
"strings" | ||
"time" | ||
|
||
"github.com/databricks/databricks-sdk-go/common" | ||
|
@@ -28,16 +27,25 @@ type ClientConfig struct { | |
AuthVisitor RequestVisitor | ||
Visitors []RequestVisitor | ||
|
||
RetryTimeout time.Duration | ||
// The maximum amount of time to retry requests that return retriable errors. | ||
// If unset, the default is 5 minutes. | ||
RetryTimeout time.Duration | ||
|
||
// Returns the amount of time to wait after the given attempt. | ||
RetryBackoff retries.BackoffFunc | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is added here and exposed initially for testing, though it should be safe for users to configure this if they need to modify the backoff behavior. |
||
|
||
HTTPTimeout time.Duration | ||
InsecureSkipVerify bool | ||
DebugHeaders bool | ||
DebugTruncateBytes int | ||
RateLimitPerSecond int | ||
|
||
ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error | ||
ErrorRetriable func(ctx context.Context, err error) bool | ||
TransientErrors []string | ||
// ErrorMapper converts the API response into a Go error if the response is an error. | ||
ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error | ||
|
||
// ErrorRetriable determines if the API request should be retried. It is not | ||
// called if the context is cancelled or if the request succeeded. | ||
ErrorRetriable ErrorRetrier | ||
|
||
Transport http.RoundTripper | ||
} | ||
|
@@ -130,7 +138,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio | |
// merge client-wide and request-specific visitors | ||
visitors = append(visitors, o.in) | ||
} | ||
|
||
} | ||
// Use default AuthVisitor if none is provided | ||
if authVisitor == nil { | ||
|
@@ -170,45 +177,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio | |
return nil | ||
} | ||
|
||
func (c *ApiClient) isRetriable(ctx context.Context, err error) bool { | ||
if c.config.ErrorRetriable(ctx, err) { | ||
return true | ||
} | ||
if isRetriableUrlError(err) { | ||
// all IO errors are retriable | ||
logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) | ||
return true | ||
} | ||
message := err.Error() | ||
// Handle transient errors for retries | ||
for _, substring := range c.config.TransientErrors { | ||
if strings.Contains(message, substring) { | ||
logger.Debugf(ctx, "Attempting retry because of %#v", substring) | ||
return true | ||
} | ||
} | ||
// some API's recommend retries on HTTP 500, but we'll add that later | ||
return false | ||
} | ||
|
||
// Common error-handling logic for all responses that may need to be retried. | ||
// | ||
// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed | ||
// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt. | ||
// | ||
// Always returns nil for the first parameter as there is no meaningful response body to return in the error case. | ||
// | ||
// If it is certain that an error should not be retried, use failRequest() instead. | ||
func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) { | ||
if !c.isRetriable(ctx, err) { | ||
return nil, retries.Halt(err) | ||
} | ||
if resetErr := body.Reset(); resetErr != nil { | ||
return nil, retries.Halt(resetErr) | ||
} | ||
return nil, retries.Continue(err) | ||
} | ||
|
||
// Fails the request with a retries.Err to halt future retries. | ||
func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) { | ||
err = fmt.Errorf("%s: %w", msg, err) | ||
|
@@ -299,7 +267,16 @@ func (c *ApiClient) attempt( | |
|
||
// proactively release the connections in HTTP connection pool | ||
c.httpClient.CloseIdleConnections() | ||
return c.handleError(ctx, err, requestBody) | ||
|
||
// Non-retriable errors can be returned immediately. | ||
if !c.config.ErrorRetriable(ctx, request, &responseWrapper, err) { | ||
return nil, retries.Halt(err) | ||
} | ||
// Retriable errors may require the request body to be reset. | ||
if resetErr := requestBody.Reset(); resetErr != nil { | ||
return nil, retries.Halt(resetErr) | ||
} | ||
return nil, retries.Continue(err) | ||
} | ||
} | ||
|
||
|
@@ -331,16 +308,24 @@ func (c *ApiClient) recordRequestLog( | |
func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) { | ||
ctx := request.Context() | ||
requestURL := request.URL.String() | ||
resp, err := retries.Poll(ctx, c.config.RetryTimeout, | ||
c.attempt(ctx, request.Method, requestURL, common.RequestBody{ | ||
Reader: request.Body, | ||
// DO NOT DECODE BODY, because it may contain sensitive payload, | ||
// like Azure Service Principal in a multipart/form-data body. | ||
DebugBytes: []byte("<http.RoundTripper>"), | ||
}, func(r *http.Request) error { | ||
r.Header = request.Header | ||
return nil | ||
})) | ||
retrier := makeRetrier[common.ResponseWrapper](c.config) | ||
resp, err := retrier.Run( | ||
ctx, | ||
func(ctx context.Context) (*common.ResponseWrapper, error) { | ||
resp, err := c.attempt(ctx, request.Method, requestURL, common.RequestBody{ | ||
Reader: request.Body, | ||
// DO NOT DECODE BODY, because it may contain sensitive payload, | ||
// like Azure Service Principal in a multipart/form-data body. | ||
DebugBytes: []byte("<http.RoundTripper>"), | ||
}, func(r *http.Request) error { | ||
r.Header = request.Header | ||
return nil | ||
})() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return resp, nil | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
@@ -365,8 +350,16 @@ func (c *ApiClient) perform( | |
requestBody common.RequestBody, | ||
visitors ...RequestVisitor, | ||
) (*common.ResponseWrapper, error) { | ||
resp, err := retries.Poll(ctx, c.config.RetryTimeout, | ||
c.attempt(ctx, method, requestURL, requestBody, visitors...)) | ||
retrier := makeRetrier[common.ResponseWrapper](c.config) | ||
resp, err := retrier.Run( | ||
ctx, | ||
func(ctx context.Context) (*common.ResponseWrapper, error) { | ||
resp, err := c.attempt(ctx, method, requestURL, requestBody, visitors...)() | ||
if err != nil { | ||
return resp, err | ||
} | ||
return resp, nil | ||
}) | ||
var timedOut *retries.ErrTimedOut | ||
if errors.As(err, &timedOut) { | ||
// TODO: check if we want to unwrap this error here | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add this back in if desired, just a small formatting change.