diff --git a/session_pool.go b/session_pool.go index 7e61e587..2e6084fc 100644 --- a/session_pool.go +++ b/session_pool.go @@ -425,7 +425,8 @@ func (pool *SessionPool) executeWithRetryLimit(session *pureSession, if rs.GetErrorCode() != ErrorCode_E_SESSION_INVALID { return rs, nil } - if sessionRetryTimes >= sessionRetryLimit { + // exec fn first, so should +1 for validation + if sessionRetryTimes+1 >= sessionRetryLimit { return rs, nil } pool.log.Info(fmt.Sprintf("retry to execute the query %d times for session invalid", sessionRetryTimes+1)) @@ -435,7 +436,7 @@ func (pool *SessionPool) executeWithRetryLimit(session *pureSession, } return pool.executeWithRetryLimit(session, fn, sessionRetryTimes+1, sessionRetryLimit, errRetryTimes, errRetryLimit) } else { - if errRetryTimes >= errRetryLimit { + if errRetryTimes+1 >= errRetryLimit { return rs, err } if err := pool.retryStrategyErr(session); err != nil { diff --git a/session_pool_test.go b/session_pool_test.go index 04ab666d..67c45b50 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -536,7 +536,7 @@ func TestSessionPoolRetry(t *testing.T) { t.Fatal(err) } defer sessionPool.Close() - testcaes := []struct { + testcases := []struct { name string retryFn func(*pureSession) (*ResultSet, error) retry bool @@ -582,7 +582,7 @@ func TestSessionPoolRetry(t *testing.T) { retry: false, }, } - for _, tc := range testcaes { + for _, tc := range testcases { session, err := sessionPool.newSession() if err != nil { t.Fatal(err) @@ -599,6 +599,78 @@ func TestSessionPoolRetry(t *testing.T) { } } +// split retry flags +func TestSessionPoolSplitRetry(t *testing.T) { + err := prepareSpace("client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace("client_test") + testcases := []struct { + retryFn *retryFn + retryTimes int + sessionInvaildLimit int + errLimit int + hasErr bool + hasResult bool + err error + }{ + // retry when return the error code *ErrorCode_E_SESSION_INVALID* + { + retryFn: newRetryFn(&ResultSet{ + resp: &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_E_SESSION_INVALID, + }}, nil), + sessionInvaildLimit: 5, + errLimit: 0, + err: nil, + hasErr: false, + hasResult: true, + retryTimes: 5, + }, + // retry when occur error + { + retryFn: newRetryFn(nil, fmt.Errorf("error")), + sessionInvaildLimit: 0, + errLimit: 4, + err: fmt.Errorf("error"), + hasErr: true, + hasResult: false, + retryTimes: 4, + }, + } + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.minSize = 2 + config.maxSize = 2 + + for _, tc := range testcases { + c := *config + c.retryErrorTimes = tc.errLimit + c.retryGetSessionTimes = tc.sessionInvaildLimit + sessionPool, err := NewSessionPool(c, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + resp, err := sessionPool.executeFn(tc.retryFn.retry) + assert.Equal(t, tc.retryTimes, tc.retryFn.retryTimes) + if tc.hasErr { + assert.EqualError(t, tc.err, err.Error()) + } + if tc.hasResult { + assert.Equal(t, resp.GetErrorCode(), ErrorCode_E_SESSION_INVALID) + } + } +} + type retryFn struct { fn func(*pureSession) (*ResultSet, error) retryTimes int