From 0b1144488f4f7f9b75e749e78500bc03536bab36 Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Thu, 27 Apr 2023 19:05:52 +0800 Subject: [PATCH] lightning: fix pd retry and add ut for it (#43432) close pingcap/tidb#43400 --- br/pkg/lightning/common/retry.go | 8 +++---- br/pkg/pdutil/pd.go | 40 ++++++++++++++++++++++++-------- br/pkg/pdutil/pd_serial_test.go | 26 ++++++++++++++++++++- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/br/pkg/lightning/common/retry.go b/br/pkg/lightning/common/retry.go index 293985097bb77..f6db6cda86407 100644 --- a/br/pkg/lightning/common/retry.go +++ b/br/pkg/lightning/common/retry.go @@ -17,6 +17,7 @@ import ( "context" "database/sql" "database/sql/driver" + goerrors "errors" "io" "net" "os" @@ -101,11 +102,8 @@ func isSingleRetryableError(err error) bool { if nerr.Timeout() { return true } - if cause, ok := nerr.(*net.OpError); ok { - syscallErr, ok := cause.Unwrap().(*os.SyscallError) - if ok { - return syscallErr.Err == syscall.ECONNREFUSED || syscallErr.Err == syscall.ECONNRESET - } + if syscallErr, ok := goerrors.Unwrap(err).(*os.SyscallError); ok { + return syscallErr.Err == syscall.ECONNREFUSED || syscallErr.Err == syscall.ECONNRESET } return false case *mysql.MySQLError: diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 26677fc89dd87..c0546a1801843 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -11,9 +11,12 @@ import ( "fmt" "io" "math" + "net" "net/http" "net/url" + "os" "strings" + "syscall" "time" "github.com/coreos/go-semver/semver" @@ -165,22 +168,39 @@ func pdRequestWithCode( if err != nil { return 0, nil, errors.Trace(err) } - resp, err := cli.Do(req) - if err != nil { - return 0, nil, errors.Trace(err) - } + var resp *http.Response count := 0 for { + resp, err = cli.Do(req) //nolint:bodyclose count++ - if count > pdRequestRetryTime || resp.StatusCode < 500 { + failpoint.Inject("InjectClosed", func(v failpoint.Value) { + if failType, ok := v.(int); ok && count <= pdRequestRetryTime-1 { + resp = nil + switch failType { + case 0: + err = &net.OpError{ + Op: "read", + Err: os.NewSyscallError("connect", syscall.ECONNREFUSED), + } + default: + err = &url.Error{ + Op: "read", + Err: os.NewSyscallError("connect", syscall.ECONNREFUSED), + } + } + } + }) + if count > pdRequestRetryTime || (resp != nil && resp.StatusCode < 500) || + (err != nil && !common.IsRetryableError(err)) { break } - _ = resp.Body.Close() - time.Sleep(pdRequestRetryInterval()) - resp, err = cli.Do(req) - if err != nil { - return 0, nil, errors.Trace(err) + if resp != nil { + _ = resp.Body.Close() } + time.Sleep(pdRequestRetryInterval()) + } + if err != nil { + return 0, nil, errors.Trace(err) } defer func() { _ = resp.Body.Close() diff --git a/br/pkg/pdutil/pd_serial_test.go b/br/pkg/pdutil/pd_serial_test.go index 09f70f1d78476..32a415ed8800d 100644 --- a/br/pkg/pdutil/pd_serial_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -199,10 +199,34 @@ func TestPDRequestRetry(t *testing.T) { } w.WriteHeader(http.StatusOK) })) - defer ts.Close() taddr = ts.URL _, reqErr = pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) require.Error(t, reqErr) + ts.Close() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/br/pkg/pdutil/InjectClosed", + fmt.Sprintf("return(%d)", 0))) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/pdutil/InjectClosed")) + }() + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + taddr = ts.URL + _, reqErr = pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) + require.NoError(t, reqErr) + ts.Close() + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/pdutil/InjectClosed")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/br/pkg/pdutil/InjectClosed", + fmt.Sprintf("return(%d)", 1))) + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + taddr = ts.URL + _, reqErr = pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) + require.NoError(t, reqErr) } func TestPDResetTSCompatibility(t *testing.T) {