diff --git a/Makefile b/Makefile index 76301707..9dbf2376 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ down: ssl-test: ssl_test=true go test -v --tags=integration -run TestSslConnection; + ssl_test=true go test -v --tags=integration -run TestSslSessionPool; ssl-test-self-signed: self_signed=true go test -v --tags=integration -run TestSslConnection; diff --git a/connection.go b/connection.go index 63c0aee8..f2c1e29c 100644 --- a/connection.go +++ b/connection.go @@ -157,19 +157,10 @@ func (cn *connection) executeWithParameter(sessionID int64, stmt string, params map[string]*nebula.Value) (*graph.ExecutionResponse, error) { resp, err := cn.graph.ExecuteWithParameter(sessionID, []byte(stmt), params) if err != nil { - // reopen the connection if timeout - if _, ok := err.(thrift.TransportException); ok { - if err.(thrift.TransportException).TypeID() == thrift.TIMED_OUT { - reopenErr := cn.reopen() - if reopenErr != nil { - return nil, reopenErr - } - return cn.graph.ExecuteWithParameter(sessionID, []byte(stmt), params) - } - } + return nil, err } - return resp, err + return resp, nil } func (cn *connection) executeJson(sessionID int64, stmt string) ([]byte, error) { diff --git a/connection_pool.go b/connection_pool.go index f18c8924..d8f891ed 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -129,7 +129,6 @@ func (pool *ConnectionPool) GetSession(username, password string) (*Session, err sessionID: sessID, connection: conn, connPool: pool, - sessPool: nil, log: pool.log, timezoneInfo: timezoneInfo{timezoneOffset, timezoneName}, } diff --git a/nebula-docker-compose/docker-compose-ssl.yaml b/nebula-docker-compose/docker-compose-ssl.yaml index 0fa8b882..86e9d3a5 100644 --- a/nebula-docker-compose/docker-compose-ssl.yaml +++ b/nebula-docker-compose/docker-compose-ssl.yaml @@ -418,6 +418,14 @@ services: -e 'ADD HOSTS "storaged0":44500,"storaged1":44500,"storaged2":44500'`; if [[ $$? == 0 ]];then echo "Hosts added successfully" + sleep 2; + nebula-console -addr graphd0 -port 3699 -u root -p nebula \ + -enable_ssl=true \ + -ssl_root_ca_path /secrets/root.crt \ + -ssl_cert_path /secrets/client.crt \ + -ssl_private_key_path /secrets/client.key \ + --ssl_insecure_skip_verify=true \ + -e 'CREATE SPACE session_pool(partition_num=4, replica_factor=1, vid_type = FIXED_STRING(30))' break; fi; sleep 1; diff --git a/session.go b/session.go index 2683378d..6a7da2af 100644 --- a/session.go +++ b/session.go @@ -11,7 +11,6 @@ package nebula_go import ( "fmt" "sync" - "time" "github.com/vesoft-inc/nebula-go/v3/nebula" graph "github.com/vesoft-inc/nebula-go/v3/nebula/graph" @@ -26,9 +25,7 @@ type Session struct { sessionID int64 connection *connection connPool *ConnectionPool // the connection pool which the session belongs to. could be nil if the Session is store in the SessionPool - sessPool *SessionPool // the session pool which the session belongs to. could be nil if the Session is store in the ConnectionPool log Logger - returnedAt time.Time // the timestamp that the session was created or returned. mu sync.Mutex timezoneInfo } @@ -228,6 +225,10 @@ func (session *Session) GetSessionID() int64 { return session.sessionID } +func IsError(resp *graph.ExecutionResponse) bool { + return resp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED +} + // Ping checks if the session is valid func (session *Session) Ping() error { if session.connection == nil { @@ -246,10 +247,6 @@ func (session *Session) Ping() error { return nil } -func IsError(resp *graph.ExecutionResponse) bool { - return resp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED -} - // construct Slice to nebula.NList func slice2Nlist(list []interface{}) (*nebula.NList, error) { sv := []*nebula.Value{} diff --git a/session_pool.go b/session_pool.go index 75d7035c..5acd6f82 100644 --- a/session_pool.go +++ b/session_pool.go @@ -16,7 +16,6 @@ import ( "time" "github.com/vesoft-inc/nebula-go/v3/nebula" - "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) // SessionPool is a pool that manages sessions internally. @@ -46,6 +45,17 @@ type SessionPool struct { rwLock sync.RWMutex } +// one pureSession binds to one connection and shares the same lifespan. +// If the underlying connection is broken, the session will be removed from the session pool. +type pureSession struct { + sessionID int64 + connection *connection + sessPool *SessionPool + returnedAt time.Time // the timestamp that the session was created or returned. + timezoneInfo + spaceName string +} + // NewSessionPool creates a new session pool with the given configs. // There must be an existing SPACE in the DB. func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) { @@ -67,8 +77,6 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) { // init initializes the session pool. func (pool *SessionPool) init() error { - pool.rwLock.Lock() - defer pool.rwLock.Unlock() // check the hosts status if err := checkAddresses(pool.conf.timeOut, pool.conf.serviceAddrs, pool.conf.sslConfig, pool.conf.useHTTP2); err != nil { return fmt.Errorf("failed to initialize the session pool, %s", err.Error()) @@ -82,7 +90,7 @@ func (pool *SessionPool) init() error { } session.returnedAt = time.Now() - pool.addSessionToList(&pool.idleSessions, session) + pool.addSessionToIdle(session) } return nil @@ -107,42 +115,46 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int } // Get a session from the pool - session, err := pool.getIdleSession() + session, err := pool.getSessionFromIdle() if err != nil { return nil, err } - - // Parse params - paramsMap, err := parseParams(params) - if err != nil { - return nil, err + // if there's no idle session, create a new one + if session == nil { + session, err = pool.newSession() + if err != nil { + return nil, err + } + pool.addSessionToActive(session) + } else { + pool.removeSessionFromIdle(session) + pool.addSessionToActive(session) } // Execute the query - execFunc := func(s *Session) (*graph.ExecutionResponse, error) { - resp, err := s.connection.executeWithParameter(s.sessionID, stmt, paramsMap) + execFunc := func(s *pureSession) (*ResultSet, error) { + rs, err := s.executeWithParameter(stmt, params) if err != nil { return nil, err } - session = s - return resp, nil - } - - resp, err := pool.executeWithRetry(session, execFunc, pool.conf.retryGetSessionTimes) - if err != nil { - return nil, err + return rs, nil } - resSet, err := genResultSet(resp, session.timezoneInfo) + rs, err := pool.executeWithRetry(session, execFunc, pool.conf.retryGetSessionTimes) if err != nil { + session.close() + pool.removeSessionFromActive(session) return nil, err } // if the space was changed after the execution of the given query, // change it back to the default space specified in the pool config - if resSet.GetSpaceName() != "" && resSet.GetSpaceName() != pool.conf.spaceName { - err := pool.setSessionSpaceToDefault(session) + if rs.GetSpaceName() != "" && rs.GetSpaceName() != pool.conf.spaceName { + err := session.setSessionSpaceToDefault() if err != nil { + pool.log.Warn(err.Error()) + session.close() + pool.removeSessionFromActive(session) return nil, err } } @@ -150,7 +162,7 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int // Return the session to the idle list pool.returnSession(session) - return resSet, err + return rs, nil } // ExecuteJson returns the result of the given query as a json string @@ -223,31 +235,6 @@ func (pool *SessionPool) ExecuteJson(stmt string) ([]byte, error) { // TODO(Aiee) check the space name func (pool *SessionPool) ExecuteJsonWithParameter(stmt string, params map[string]interface{}) ([]byte, error) { return nil, fmt.Errorf("not implemented") - - // Get a session from the pool - session, err := pool.getIdleSession() - if err != nil { - return nil, err - } - // check the session is valid - if session.connection == nil { - return nil, fmt.Errorf("failed to execute: Session has been released") - } - // parse params - paramsMap, err := parseParams(params) - if err != nil { - return nil, err - } - - pool.rwLock.Lock() - defer pool.rwLock.Unlock() - resp, err := session.connection.ExecuteJsonWithParameter(session.sessionID, stmt, paramsMap) - if err != nil { - return nil, err - } - - //TODO(Aiee) check the space name - return resp, nil } // Close logs out all sessions and closes bonded connection. @@ -261,22 +248,29 @@ func (pool *SessionPool) Close() { // iterate all sessions for i := 0; i < idleLen; i++ { - session := pool.idleSessions.Front().Value.(*Session) + session := pool.idleSessions.Front().Value.(*pureSession) if session.connection == nil { - session.log.Warn("Session has been released") - } else if err := session.connection.signOut(session.sessionID); err != nil { - session.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) + pool.log.Warn("Session has been released") + pool.idleSessions.Remove(pool.idleSessions.Front()) + continue + } + + if err := session.connection.signOut(session.sessionID); err != nil { + pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) } // close connection session.connection.close() pool.idleSessions.Remove(pool.idleSessions.Front()) } for i := 0; i < activeLen; i++ { - session := pool.activeSessions.Front().Value.(*Session) + session := pool.activeSessions.Front().Value.(*pureSession) if session.connection == nil { - session.log.Warn("Session has been released") - } else if err := session.connection.signOut(session.sessionID); err != nil { - session.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) + pool.log.Warn("Session has been released") + pool.activeSessions.Remove(pool.activeSessions.Front()) + continue + } + if err := session.connection.signOut(session.sessionID); err != nil { + pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) } // close connection session.connection.close() @@ -298,7 +292,7 @@ func (pool *SessionPool) GetTotalSessionCount() int { // newSession creates a new session and returns it. // `use ` will be executed so that the new session will be in the default space. -func (pool *SessionPool) newSession() (*Session, error) { +func (pool *SessionPool) newSession() (*pureSession, error) { graphAddr := pool.getNextAddr() cn := connection{ severAddress: graphAddr, @@ -336,33 +330,34 @@ func (pool *SessionPool) newSession() (*Session, error) { timezoneOffset := authResp.GetTimeZoneOffsetSeconds() timezoneName := authResp.GetTimeZoneName() // Create new session - newSession := Session{ + newSession := pureSession{ sessionID: sessID, connection: &cn, - connPool: nil, sessPool: pool, - log: pool.log, timezoneInfo: timezoneInfo{timezoneOffset, timezoneName}, + spaceName: pool.conf.spaceName, } // Switch to the default space stmt := fmt.Sprintf("USE %s", pool.conf.spaceName) - useSpaceResp, err := newSession.connection.execute(newSession.sessionID, stmt) + useSpaceRs, err := newSession.execute(stmt) if err != nil { return nil, err } - if useSpaceResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED { - newSession.connection.close() + if useSpaceRs.GetErrorCode() != ErrorCode_SUCCEEDED { + newSession.close() return nil, fmt.Errorf("failed to use space %s: %s", - pool.conf.spaceName, useSpaceResp.GetErrorMsg()) + pool.conf.spaceName, useSpaceRs.GetErrorMsg()) } return &newSession, nil } // getNextAddr returns the next address in the address list using simple round robin approach. func (pool *SessionPool) getNextAddr() HostAddress { - if pool.conf.hostIndex == len(pool.conf.serviceAddrs) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + if pool.conf.hostIndex >= len(pool.conf.serviceAddrs) { pool.conf.hostIndex = 0 } host := pool.conf.serviceAddrs[pool.conf.hostIndex] @@ -372,52 +367,42 @@ func (pool *SessionPool) getNextAddr() HostAddress { // getSession returns an available session. // This method should move an available session to the active list and should be MT-safe. -func (pool *SessionPool) getIdleSession() (*Session, error) { +func (pool *SessionPool) getSessionFromIdle() (*pureSession, error) { pool.rwLock.Lock() defer pool.rwLock.Unlock() // Get a session from the idle queue if possible if pool.idleSessions.Len() > 0 { - session := pool.idleSessions.Front().Value.(*Session) - pool.removeSessionFromList(&pool.idleSessions, session) - pool.addSessionToList(&pool.activeSessions, session) + session := pool.idleSessions.Front().Value.(*pureSession) + pool.idleSessions.Remove(pool.idleSessions.Front()) return session, nil } else if pool.activeSessions.Len() < pool.conf.maxSize { - // Create a new session if the total number of sessions is less than the max size - session, err := pool.newSession() - if err != nil { - return nil, err - } - pool.addSessionToList(&pool.activeSessions, session) - return session, nil + return nil, nil } // There is no available session in the pool and the total session count has reached the limit return nil, fmt.Errorf("failed to get session: no session available in the" + " session pool and the total session count has reached the limit") } -// retryGetSession tries to create a new session when the current session is invalid. +// retryGetSession tries to create a new session when: +// 1. the current session is invalid. +// 2. connection is invalid. +// and then change the original session to the new one. func (pool *SessionPool) executeWithRetry( - session *Session, - f func(*Session) (*graph.ExecutionResponse, error), - retry int) (*graph.ExecutionResponse, error) { - pool.rwLock.Lock() - defer pool.rwLock.Unlock() - - resp, err := f(session) - if err != nil { - pool.removeSessionFromList(&pool.activeSessions, session) - return nil, err - } - - if resp.ErrorCode == nebula.ErrorCode_SUCCEEDED { - return resp, nil - } else if ErrorCode(resp.ErrorCode) != ErrorCode_E_SESSION_INVALID { // only retry when the session is invalid - return resp, err + session *pureSession, + f func(*pureSession) (*ResultSet, error), + retry int) (*ResultSet, error) { + rs, err := f(session) + if err == nil { + if rs.GetErrorCode() == ErrorCode_SUCCEEDED { + return rs, nil + } else if rs.GetErrorCode() != ErrorCode_E_SESSION_INVALID { // only retry when the session is invalid + return rs, err + } } - // remove invalid session regardless of the retry is successful or not - defer pool.removeSessionFromList(&pool.activeSessions, session) - // If the session is invalid, close it and get a new session + // If the session is invalid, close it first + session.close() + // get a new session for i := 0; i < retry; i++ { pool.log.Info("retry to get sessions") newSession, err := pool.newSession() @@ -425,15 +410,15 @@ func (pool *SessionPool) executeWithRetry( return nil, err } - pingErr := newSession.Ping() + pingErr := newSession.ping() if pingErr != nil { pool.log.Error("failed to ping the session, error: " + pingErr.Error()) continue } pool.log.Info("retry to get sessions successfully") - pool.addSessionToList(&pool.activeSessions, newSession) + *session = *newSession - return f(newSession) + return f(session) } pool.log.Error(fmt.Sprintf("failed to get session after " + strconv.Itoa(retry) + " retries")) return nil, fmt.Errorf("failed to get session after %d retries", retry) @@ -476,23 +461,23 @@ func (pool *SessionPool) sessionCleaner() { //release expired session from the pool for _, session := range closing { if session.connection == nil { - session.log.Warn("Session has been released") + pool.log.Warn("Session has been released") + pool.rwLock.Unlock() return } if err := session.connection.signOut(session.sessionID); err != nil { - session.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) + pool.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) } // close connection session.connection.close() } pool.rwLock.Unlock() - t.Reset(d) } } // timeoutSessionList returns a list of sessions that have been idle for longer than the idle time. -func (pool *SessionPool) timeoutSessionList() (closing []*Session) { +func (pool *SessionPool) timeoutSessionList() (closing []*pureSession) { if pool.conf.idleTime > 0 { expiredSince := time.Now().Add(-pool.conf.idleTime) var newEle *list.Element = nil @@ -506,10 +491,10 @@ func (pool *SessionPool) timeoutSessionList() (closing []*Session) { newEle = ele.Next() // Check Session is expired - if !ele.Value.(*Session).returnedAt.Before(expiredSince) { + if !ele.Value.(*pureSession).returnedAt.Before(expiredSince) { return } - closing = append(closing, ele.Value.(*Session)) + closing = append(closing, ele.Value.(*pureSession)) pool.idleSessions.Remove(ele) ele = newEle maxCleanSize-- @@ -532,39 +517,137 @@ func parseParams(params map[string]interface{}) (map[string]*nebula.Value, error } // removeSessionFromIdleList Removes a session from list -func (pool *SessionPool) removeSessionFromList(l *list.List, session *Session) { +func (pool *SessionPool) removeSessionFromActive(session *pureSession) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + l := &pool.activeSessions for ele := l.Front(); ele != nil; ele = ele.Next() { - if ele.Value.(*Session) == session { + if ele.Value.(*pureSession) == session { l.Remove(ele) } } } -func (pool *SessionPool) addSessionToList(l *list.List, session *Session) { +func (pool *SessionPool) addSessionToActive(session *pureSession) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + l := &pool.activeSessions + l.PushBack(session) +} + +func (pool *SessionPool) removeSessionFromIdle(session *pureSession) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + l := &pool.idleSessions + for ele := l.Front(); ele != nil; ele = ele.Next() { + if ele.Value.(*pureSession) == session { + l.Remove(ele) + } + } +} + +func (pool *SessionPool) addSessionToIdle(session *pureSession) { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() + l := &pool.idleSessions l.PushBack(session) } // returnSession returns a session from active list to the idle list. -func (pool *SessionPool) returnSession(session *Session) { +func (pool *SessionPool) returnSession(session *pureSession) { pool.rwLock.Lock() defer pool.rwLock.Unlock() - pool.removeSessionFromList(&pool.activeSessions, session) - pool.addSessionToList(&pool.idleSessions, session) + l := &pool.activeSessions + for ele := l.Front(); ele != nil; ele = ele.Next() { + if ele.Value.(*pureSession) == session { + l.Remove(ele) + } + } + l = &pool.idleSessions + l.PushBack(session) session.returnedAt = time.Now() } -func (pool *SessionPool) setSessionSpaceToDefault(session *Session) error { +func (pool *SessionPool) setSessionSpaceToDefault(session *pureSession) error { stmt := fmt.Sprintf("USE %s", pool.conf.spaceName) - resp, err := session.connection.execute(session.sessionID, stmt) + rs, err := session.execute(stmt) if err != nil { return err } + + if rs.GetErrorCode() == ErrorCode_SUCCEEDED { + return nil + } // if failed to change back to the default space, send a warning log // and remove the session from the pool because it is malformed. - if resp.ErrorCode != nebula.ErrorCode_SUCCEEDED { - pool.log.Warn(fmt.Sprintf("failed to reset the space of the session: errorCode: %s, errorMsg: %s, session removed", - resp.ErrorCode, resp.ErrorMsg)) - pool.removeSessionFromList(&pool.activeSessions, session) + pool.log.Warn(fmt.Sprintf("failed to reset the space of the session: errorCode: %d, errorMsg: %s, session removed", + rs.GetErrorCode(), rs.GetErrorMsg())) + session.close() + pool.removeSessionFromActive(session) + return fmt.Errorf("failed to reset the space of the session: errorCode: %d, errorMsg: %s", + rs.GetErrorCode(), rs.GetErrorMsg()) +} + +func (session *pureSession) execute(stmt string) (*ResultSet, error) { + return session.executeWithParameter(stmt, nil) +} + +func (session *pureSession) executeWithParameter(stmt string, params map[string]interface{}) (*ResultSet, error) { + paramsMap, err := parseParams(params) + if err != nil { + return nil, err + } + if session.connection == nil { + return nil, fmt.Errorf("failed to execute: Session has been released") + } + resp, err := session.connection.executeWithParameter(session.sessionID, stmt, paramsMap) + if err != nil { + return nil, err + } + rs, err := genResultSet(resp, session.timezoneInfo) + if err != nil { + return nil, err + } + return rs, nil +} + +func (session *pureSession) close() { + if session.connection != nil { + // ignore signout error + _ = session.connection.signOut(session.sessionID) + session.connection.close() + } + session.connection = nil +} + +// Ping checks if the session is valid +func (session *pureSession) ping() error { + if session.connection == nil { + return fmt.Errorf("failed to ping: Session has been released") + } + // send ping request + rs, err := session.execute(`RETURN "NEBULA GO PING"`) + // check connection level error + if err != nil { + return fmt.Errorf("session ping failed, %s" + err.Error()) + } + // check session level error + if !rs.IsSucceed() { + return fmt.Errorf("session ping failed, %s" + rs.GetErrorMsg()) } return nil } + +func (session *pureSession) setSessionSpaceToDefault() error { + stmt := fmt.Sprintf("USE %s", session.spaceName) + rs, err := session.execute(stmt) + if err != nil { + return err + } + + if rs.GetErrorCode() == ErrorCode_SUCCEEDED { + return nil + } + return fmt.Errorf("failed to reset the space of the session: errorCode: %d, errorMsg: %s", + rs.GetErrorCode(), rs.GetErrorMsg()) +} diff --git a/session_pool_test.go b/session_pool_test.go index 5b36e1ce..d0e10cb6 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -17,6 +17,8 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/vesoft-inc/nebula-go/v3/nebula" + "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) func TestSessionPoolInvalidConfig(t *testing.T) { @@ -110,37 +112,38 @@ func TestSessionPoolMultiThreadGetSession(t *testing.T) { defer sessionPool.Close() var wg sync.WaitGroup - sessCh := make(chan *Session) + rsCh := make(chan *ResultSet, sessionPool.conf.maxSize) done := make(chan bool) wg.Add(sessionPool.conf.maxSize) // producer create sessions for i := 0; i < sessionPool.conf.maxSize; i++ { - go func(sessCh chan<- *Session, wg *sync.WaitGroup) { + go func(sessCh chan<- *ResultSet, wg *sync.WaitGroup) { defer wg.Done() - session, err := sessionPool.getIdleSession() + rs, err := sessionPool.Execute("yield 1") if err != nil { - t.Errorf("fail to create a new session from connection pool, %s", err.Error()) + t.Errorf("fail to execute query, %s", err.Error()) } - sessCh <- session - }(sessCh, &wg) + + rsCh <- rs + }(rsCh, &wg) } // consumer consumes the session created - var sessionList []*Session - go func(sessCh <-chan *Session) { - for session := range sessCh { - sessionList = append(sessionList, session) + var rsList []*ResultSet + go func(rsCh <-chan *ResultSet) { + for session := range rsCh { + rsList = append(rsList, session) } done <- true - }(sessCh) + }(rsCh) wg.Wait() - close(sessCh) + close(rsCh) <-done - assert.Equalf(t, config.maxSize, sessionPool.activeSessions.Len(), + assert.Equalf(t, 0, sessionPool.activeSessions.Len(), "Total number of active connections should be %d", config.maxSize) - assert.Equalf(t, config.maxSize, len(sessionList), + assert.Equalf(t, config.maxSize, len(rsList), "Total number of result returned should be %d", config.maxSize) } @@ -411,3 +414,125 @@ func BenchmarkConcurrency(b *testing.B) { b.Logf("Concurrency: %d, Total time cost: %v", clients, end.Sub(start)) } } + +// retry when return the error code *ErrorCode_E_SESSION_INVALID* +func TestSessionPoolRetry(t *testing.T) { + err := prepareSpace("client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace("client_test") + + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.minSize = 2 + config.maxSize = 2 + config.retryGetSessionTimes = 1 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + testcaes := []struct { + name string + retryFn func(*pureSession) (*ResultSet, error) + retry bool + }{ + { + name: "success", + retryFn: func(s *pureSession) (*ResultSet, error) { + return &ResultSet{ + resp: &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_SUCCEEDED, + }, + }, nil + }, + retry: false, + }, + { + name: "error", + retryFn: func(s *pureSession) (*ResultSet, error) { + return nil, fmt.Errorf("error") + }, + retry: true, + }, + { + name: "invalid session error code", + retryFn: func(s *pureSession) (*ResultSet, error) { + return &ResultSet{ + resp: &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_E_SESSION_INVALID, + }, + }, nil + }, + retry: true, + }, + { + name: "execution error code", + retryFn: func(s *pureSession) (*ResultSet, error) { + return &ResultSet{ + resp: &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_E_EXECUTION_ERROR, + }, + }, nil + }, + retry: false, + }, + } + for _, tc := range testcaes { + session, err := sessionPool.newSession() + if err != nil { + t.Fatal(err) + } + original := *session + _, _ = sessionPool.executeWithRetry(session, tc.retryFn, 2) + if tc.retry { + assert.NotEqual(t, original, *session, fmt.Sprintf("test case: %s", tc.name)) + assert.NotEqual(t, original.connection, nil, fmt.Sprintf("test case: %s", tc.name)) + } else { + assert.Equal(t, original, *session, fmt.Sprintf("test case: %s", tc.name)) + } + } +} + +func TestSessionPoolClose(t *testing.T) { + err := prepareSpace("client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace("client_test") + + hostAddress := HostAddress{Host: address, Port: port} + config, err := NewSessionPoolConf( + "root", + "nebula", + []HostAddress{hostAddress}, + "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.minSize = 2 + config.maxSize = 2 + config.retryGetSessionTimes = 1 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + sessionPool.Close() + + assert.Equal(t, 0, sessionPool.activeSessions.Len(), "Total number of active connections should be 0") + assert.Equal(t, 0, sessionPool.idleSessions.Len(), "Total number of active connections should be 0") + _, err = sessionPool.Execute("SHOW HOSTS;") + assert.Equal(t, err.Error(), "failed to execute: Session pool has been closed", "session pool should be closed") +} diff --git a/ssl_sessionpool_test.go b/ssl_sessionpool_test.go new file mode 100644 index 00000000..87197f6f --- /dev/null +++ b/ssl_sessionpool_test.go @@ -0,0 +1,74 @@ +//go:build integration +// +build integration + +/* + * + * Copyright (c) 2023 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + * + */ + +package nebula_go + +import ( + "testing" + "time" +) + +func TestSslSessionPool(t *testing.T) { + skipSsl(t) + + hostAddress := HostAddress{Host: address, Port: port} + hostList := []HostAddress{} + hostList = append(hostList, hostAddress) + sslConfig, err := GetDefaultSSLConfig( + "./nebula-docker-compose/secrets/root.crt", + "./nebula-docker-compose/secrets/client.crt", + "./nebula-docker-compose/secrets/client.key", + ) + if err != nil { + t.Fatal(err) + } + sslConfig.InsecureSkipVerify = true // This is only used for testing + conf, err := NewSessionPoolConf( + username, + password, + hostList, + "session_pool", + WithMaxSize(10), + WithMinSize(1), + WithTimeOut(0*time.Millisecond), + WithIdleTime(0*time.Millisecond), + WithSSLConfig(sslConfig), + ) + + if err != nil { + t.Fatal(err) + } + pool, err := NewSessionPool(*conf, nebulaLog) + if err != nil { + t.Fatal(err) + } + defer pool.Close() + resp, err := pool.Execute("SHOW HOSTS;") + if err != nil { + t.Fatalf(err.Error()) + return + } + checkResultSet(t, "show hosts", resp) + // Create a new space + resp, err = pool.Execute("CREATE SPACE client_test(partition_num=1024, replica_factor=1, vid_type = FIXED_STRING(30));") + if err != nil { + t.Fatalf(err.Error()) + return + } + checkResultSet(t, "create space", resp) + + resp, err = pool.Execute("DROP SPACE client_test;") + if err != nil { + t.Fatalf(err.Error()) + return + } + checkResultSet(t, "drop space", resp) +}