diff --git a/client_test.go b/client_test.go index 6fd93204..82041abe 100644 --- a/client_test.go +++ b/client_test.go @@ -26,7 +26,7 @@ import ( const ( address = "127.0.0.1" - port = 29562 + port = 3699 username = "root" password = "nebula" @@ -725,9 +725,6 @@ func TestMultiThreads(t *testing.T) { assert.Equal(t, 666, pool.getActiveConnCount(), "Total number of active connections should be 666") assert.Equal(t, 666, len(sessionList), "Total number of sessions should be 666") - // for i := 0; i < len(hostList); i++ { - // assert.Equal(t, 222, pool.GetServerWorkload(i)) - // } for i := 0; i < testPoolConfig.MaxConnPoolSize; i++ { sessionList[i].Release() } diff --git a/connection.go b/connection.go index c191498c..0606fbd7 100644 --- a/connection.go +++ b/connection.go @@ -116,7 +116,8 @@ func (cn *connection) execute(sessionID int64, stmt string) (*graph.ExecutionRes return cn.executeWithParameter(sessionID, stmt, map[string]*nebula.Value{}) } -func (cn *connection) executeWithParameter(sessionID int64, stmt string, params map[string]*nebula.Value) (*graph.ExecutionResponse, error) { +func (cn *connection) executeWithParameter(sessionID int64, stmt string, + params map[string]*nebula.Value) (*graph.ExecutionResponse, error) { resp, err := cn.graph.ExecuteWithParameter(sessionID, []byte(stmt), params) if err != nil { // reopen the connection if timeout diff --git a/connection_pool.go b/connection_pool.go index aeb7aa14..91e4d958 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -129,6 +129,7 @@ func (pool *ConnectionPool) GetSession(username, password string) (*Session, err sessionID: sessID, connection: conn, connPool: pool, + sessPool: nil, log: pool.log, timezoneInfo: timezoneInfo{timezoneOffset, timezoneName}, } diff --git a/session.go b/session.go index 08a1a3d2..13c1e70d 100644 --- a/session.go +++ b/session.go @@ -11,6 +11,7 @@ package nebula_go import ( "fmt" "sync" + "time" "github.com/facebook/fbthrift/thrift/lib/go/thrift" "github.com/vesoft-inc/nebula-go/v3/nebula" @@ -25,8 +26,10 @@ type timezoneInfo struct { type Session struct { sessionID int64 connection *connection - connPool *ConnectionPool + connPool *ConnectionPool // the connection pool which the session belongs to. could be nil if the Session is store in the SessionPool + sessPool *SessionPool // the session pool which the session belongs to. could be nil if the Session is store in the ConnectionPool log Logger + returnedAt time.Time // the timestamp that the session was created or returned. mu sync.Mutex timezoneInfo } @@ -221,8 +224,11 @@ func (session *Session) Release() { if err := session.connection.signOut(session.sessionID); err != nil { session.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) } - // Release connection to pool - session.connPool.release(session.connection) + + // if the session is created from the connection pool, return the connection to the pool + if session.connPool != nil { + session.connPool.release(session.connection) + } session.connection = nil } @@ -231,8 +237,6 @@ func (session *Session) GetSessionID() int64 { } func (session *Session) Ping() error { - // session.mu.Lock() - // defer session.mu.Unlock() if session.connection == nil { return fmt.Errorf("failed to ping: Session has been released") } diff --git a/session_pool.go b/session_pool.go index b09d648d..f358b6e6 100644 --- a/session_pool.go +++ b/session_pool.go @@ -42,7 +42,7 @@ type SessionPool struct { log Logger closed bool cleanerChan chan struct{} //notify when pool is close - rwLock sync.Mutex + rwLock sync.RWMutex sslConfig *tls.Config } @@ -54,6 +54,7 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) { newSessionPool := &SessionPool{ conf: conf, + log: log, } // init the pool @@ -66,20 +67,22 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) { // init initializes the session pool. func (pool *SessionPool) init() error { + pool.rwLock.Lock() + defer pool.rwLock.Unlock() // check the hosts status if err := checkAddresses(pool.conf.timeOut, pool.conf.serviceAddrs, pool.sslConfig); err != nil { return fmt.Errorf("failed to initialize the session pool, %s", err.Error()) } - // create sessions to fulfill the min connection size + // create sessions to fulfill the min pool size for i := 0; i < pool.conf.minSize; i++ { - session, err := pool.getIdleSession() + session, err := pool.newSession() if err != nil { return fmt.Errorf("failed to initialize the session pool, %s", err.Error()) } - pool.rwLock.Lock() - defer pool.rwLock.Unlock() - pool.idleSessions.PushBack(session) + + session.returnedAt = time.Now() + pool.addSessionToList(&pool.idleSessions, session) } return nil @@ -89,6 +92,8 @@ func (pool *SessionPool) init() error { // Notice there are some limitations: // 1. The query should not be a plain space switch statement, e.g. "USE test_space", // but queries like "use space xxx; match (v) return v" are accepted. +// 2. If the query contains statements like "USE ", the space will be set to the +// one in the pool config after the execution of the query. func (pool *SessionPool) Execute(stmt string) (*ResultSet, error) { return pool.ExecuteWithParameter(stmt, map[string]interface{}{}) } @@ -112,39 +117,34 @@ func (pool *SessionPool) ExecuteWithParameter(stmt string, params map[string]int } // Execute the query + pool.rwLock.Lock() + defer pool.rwLock.Unlock() resp, err := session.connection.executeWithParameter(session.sessionID, stmt, paramsMap) if err != nil { return nil, err } + resSet, err := genResultSet(resp, session.timezoneInfo) if err != nil { return nil, err } + // pool.rwLock.Lock() // if the space was changed in after the execution of the given query, // change it back to the default space specified in the pool config if resSet.GetSpaceName() != "" && resSet.GetSpaceName() != pool.conf.spaceName { - stmt = fmt.Sprintf("USE %s", pool.conf.spaceName) - resp, err := session.connection.execute(session.sessionID, stmt) + err := pool.setSessionSpaceToDefault(session) if err != nil { return nil, err } - // if failed to change back to the default space, send a warning log - // and remove the session from the pool because it is malformed. - if resp.ErrorCode != nebula.ErrorCode_SUCCEEDED { - pool.log.Warn(fmt.Sprintf("failed to reset the space of the session: errorCode: %s, errorMsg: %s, session removed", - resp.ErrorCode, resp.ErrorMsg)) - pool.rwLock.Lock() - removeSessionFromList(&pool.activeSessions, session) - pool.rwLock.Unlock() - } } // Return the session to the idle list - pool.rwLock.Lock() - defer pool.rwLock.Unlock() - removeSessionFromList(&pool.activeSessions, session) - pool.idleSessions.PushBack(session) + // TODO(Aiee): Use go routine to avoid blocking + pool.removeSessionFromList(&pool.activeSessions, session) + pool.addSessionToList(&pool.idleSessions, session) + session.returnedAt = time.Now() + return resSet, err } @@ -213,7 +213,10 @@ func (pool *SessionPool) ExecuteJson(stmt string) ([]byte, error) { // ExecuteJson returns the result of the given query as a json string // Date and Datetime will be returned in UTC // The result is a JSON string in the same format as ExecuteJson() +//TODO(Aiee) check the space name func (pool *SessionPool) ExecuteJsonWithParameter(stmt string, params map[string]interface{}) ([]byte, error) { + return nil, fmt.Errorf("not implemented") + // Get a session from the pool session, err := pool.getIdleSession() if err != nil { @@ -229,10 +232,14 @@ func (pool *SessionPool) ExecuteJsonWithParameter(stmt string, params map[string return nil, err } + pool.rwLock.Lock() + defer pool.rwLock.Unlock() resp, err := session.connection.ExecuteJsonWithParameter(session.sessionID, stmt, paramsMap) if err != nil { return nil, err } + + //TODO(Aiee) check the space name return resp, nil } @@ -275,7 +282,15 @@ func (pool *SessionPool) Close() { } } +// GetTotalSessionCount returns the total number of sessions in the pool +func (pool *SessionPool) GetTotalSessionCount() int { + pool.rwLock.RLock() + defer pool.rwLock.RUnlock() + return pool.activeSessions.Len() + pool.idleSessions.Len() +} + // newSession creates a new session and returns it. +// `use ` will be executed so that the new session will be in the default space. func (pool *SessionPool) newSession() (*Session, error) { graphAddr := pool.getNextAddr() cn := connection{ @@ -304,6 +319,7 @@ func (pool *SessionPool) newSession() (*Session, error) { sessionID: sessID, connection: &cn, connPool: nil, + sessPool: pool, log: pool.log, timezoneInfo: timezoneInfo{timezoneOffset, timezoneName}, } @@ -334,23 +350,24 @@ func (pool *SessionPool) getNextAddr() HostAddress { return host } -// getSession returns a available session. +// getSession returns an available session. +// This method should move an available session to the active list and should be MT-safe. func (pool *SessionPool) getIdleSession() (*Session, error) { pool.rwLock.Lock() defer pool.rwLock.Unlock() - // Get a session from the idle queue if possible if pool.idleSessions.Len() > 0 { - session := pool.idleSessions.Remove(pool.idleSessions.Front()).(*Session) - pool.activeSessions.PushBack(session) + session := pool.idleSessions.Front().Value.(*Session) + pool.removeSessionFromList(&pool.idleSessions, session) + pool.addSessionToList(&pool.activeSessions, session) return session, nil - } else if pool.activeSessions.Len() < pool.conf.maxSize { + } else if pool.activeSessions.Len()+pool.idleSessions.Len() < pool.conf.maxSize { // Create a new session if the total number of sessions is less than the max size session, err := pool.newSession() if err != nil { return nil, err } - pool.activeSessions.PushBack(session) + pool.addSessionToList(&pool.activeSessions, session) return session, nil } // There is no available session in the pool and the total session count has reached the limit @@ -391,10 +408,18 @@ func (pool *SessionPool) sessionCleaner() { } closing := pool.timeoutSessionList() - pool.rwLock.Unlock() - for _, s := range closing { - s.Release() + + //release expired session from the pool + for _, session := range closing { + if session.connection == nil { + session.log.Warn("Session has been released") + return + } + if err := session.connection.signOut(session.sessionID); err != nil { + session.log.Warn(fmt.Sprintf("Sign out failed, %s", err.Error())) + } } + pool.rwLock.Unlock() t.Reset(d) } @@ -415,7 +440,7 @@ func (pool *SessionPool) timeoutSessionList() (closing []*Session) { newEle = ele.Next() // Check Session is expired - if !ele.Value.(*connection).returnedAt.Before(expiredSince) { + if !ele.Value.(*Session).returnedAt.Before(expiredSince) { return } closing = append(closing, ele.Value.(*Session)) @@ -441,10 +466,30 @@ func parseParams(params map[string]interface{}) (map[string]*nebula.Value, error } // removeSessionFromIdleList Removes a session from list -func removeSessionFromList(l *list.List, session *Session) { +func (pool *SessionPool) removeSessionFromList(l *list.List, session *Session) { for ele := l.Front(); ele != nil; ele = ele.Next() { if ele.Value.(*Session) == session { l.Remove(ele) } } } + +func (pool *SessionPool) addSessionToList(l *list.List, session *Session) { + l.PushBack(session) +} + +func (pool *SessionPool) setSessionSpaceToDefault(session *Session) error { + stmt := fmt.Sprintf("USE %s", pool.conf.spaceName) + resp, err := session.connection.execute(session.sessionID, stmt) + if err != nil { + return err + } + // if failed to change back to the default space, send a warning log + // and remove the session from the pool because it is malformed. + if resp.ErrorCode != nebula.ErrorCode_SUCCEEDED { + pool.log.Warn(fmt.Sprintf("failed to reset the space of the session: errorCode: %s, errorMsg: %s, session removed", + resp.ErrorCode, resp.ErrorMsg)) + pool.removeSessionFromList(&pool.activeSessions, session) + } + return nil +} diff --git a/session_pool_test.go b/session_pool_test.go index 946085b7..79945f19 100644 --- a/session_pool_test.go +++ b/session_pool_test.go @@ -14,6 +14,7 @@ import ( "fmt" "sync" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -61,8 +62,11 @@ func TestSessionPoolBasic(t *testing.T) { assert.Equal(t, 1, sessionPool.idleSessions.Len(), "Total number of active connections should be 1") } -func TestSessionPoolMultiThread(t *testing.T) { - prepareSpace(t, "client_test") +func TestSessionPoolMultiThreadGetSession(t *testing.T) { + err := prepareSpace(t, "client_test") + if err != nil { + t.Fatal(err) + } defer dropSpace(t, "client_test") hostList := poolAddress @@ -72,72 +76,106 @@ func TestSessionPoolMultiThread(t *testing.T) { } config.maxSize = 333 - // test get idle session - { - // create session pool - sessionPool, err := NewSessionPool(*config, DefaultLogger{}) - if err != nil { - t.Fatal(err) - } - defer sessionPool.Close() - - var wg sync.WaitGroup - sessCh := make(chan *Session) - done := make(chan bool) - wg.Add(sessionPool.conf.maxSize) - - // producer creates sessions - for i := 0; i < sessionPool.conf.maxSize; i++ { - go func(sessCh chan<- *Session, wg *sync.WaitGroup) { - defer wg.Done() - session, err := sessionPool.getIdleSession() - if err != nil { - t.Errorf("fail to create a new session from connection pool, %s", err.Error()) - } - sessCh <- session - }(sessCh, &wg) - } + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + + var wg sync.WaitGroup + sessCh := make(chan *Session) + done := make(chan bool) + wg.Add(sessionPool.conf.maxSize) - // consumer consumes the session created - var sessionList []*Session - go func(sessCh <-chan *Session) { - for session := range sessCh { - sessionList = append(sessionList, session) + // producer create sessions + for i := 0; i < sessionPool.conf.maxSize; i++ { + go func(sessCh chan<- *Session, wg *sync.WaitGroup) { + defer wg.Done() + session, err := sessionPool.getIdleSession() + if err != nil { + t.Errorf("fail to create a new session from connection pool, %s", err.Error()) } - done <- true - }(sessCh) - wg.Wait() - close(sessCh) - <-done - - assert.Equal(t, 333, sessionPool.activeSessions.Len(), "Total number of active connections should be 333") - assert.Equal(t, 333, len(sessionList), "Total number of result returned should be 333") - } - - // test Execute() - { - // create session pool - sessionPool, err := NewSessionPool(*config, DefaultLogger{}) - if err != nil { - t.Fatal(err) - } - defer sessionPool.Close() - - var wg sync.WaitGroup - wg.Add(sessionPool.conf.maxSize) - - for i := 0; i < sessionPool.conf.maxSize; i++ { - go func(wg *sync.WaitGroup) { - defer wg.Done() - _, err := sessionPool.Execute("RETURN 1") - if err != nil { - t.Errorf(err.Error()) - } - }(&wg) + sessCh <- session + }(sessCh, &wg) + } + + // consumer consumes the session created + var sessionList []*Session + go func(sessCh <-chan *Session) { + for session := range sessCh { + sessionList = append(sessionList, session) } - wg.Wait() - assert.Equal(t, 0, sessionPool.activeSessions.Len(), "Total number of active connections should be 0") + done <- true + }(sessCh) + wg.Wait() + close(sessCh) + <-done + + assert.Equalf(t, config.maxSize, sessionPool.activeSessions.Len(), + "Total number of active connections should be %d", config.maxSize) + assert.Equalf(t, config.maxSize, len(sessionList), + "Total number of result returned should be %d", config.maxSize) +} + +func TestSessionPoolMultiThreadExecute(t *testing.T) { + err := prepareSpace(t, "client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace(t, "client_test") + + hostList := poolAddress + config, err := NewSessionPoolConf("root", "nebula", hostList, "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + config.maxSize = 300 + + // create session pool + sessionPool, err := NewSessionPool(*config, DefaultLogger{}) + if err != nil { + t.Fatal(err) } + defer sessionPool.Close() + + var wg sync.WaitGroup + wg.Add(sessionPool.conf.maxSize) + respCh := make(chan *ResultSet) + done := make(chan bool) + + for i := 0; i < sessionPool.conf.maxSize; i++ { + go func(respCh chan<- *ResultSet, wg *sync.WaitGroup) { + defer wg.Done() + resp, err := sessionPool.Execute("SHOW HOSTS") + if err != nil { + t.Errorf(err.Error()) + } + respCh <- resp + }(respCh, &wg) + } + + var respList []*ResultSet + go func(respCh <-chan *ResultSet) { + for resp := range respCh { + respList = append(respList, resp) + } + done <- true + }(respCh) + wg.Wait() + close(respCh) + <-done + + // should generate config.maxSize results + assert.Equalf(t, config.maxSize, len(respList), + "Total number of response should be %d", config.maxSize) + + // should be 0 active sessions because they are put back to idle session list after + // query execution + assert.Equal(t, 0, sessionPool.activeSessions.Len(), + "Total number of active sessions should be 0") + // Note that here the idle session number may not be equal to the max size because once the query execution + // finished, the session will be put back to the idle session list and could be reused by other goroutines. } // This test is used to test if the space bond to session is the same as the space in the session pool config after executing @@ -188,3 +226,56 @@ func TestSessionPoolSpaceChange(t *testing.T) { resultSet.GetErrorCode(), resultSet.GetErrorMsg())) assert.Equal(t, resultSet.GetSpaceName(), "test_space_1", "space name should be test_space_1") } + +func TestIdleSessionCleaner(t *testing.T) { + err := prepareSpace(t, "client_test") + if err != nil { + t.Fatal(err) + } + defer dropSpace(t, "client_test") + + hostAddress := HostAddress{Host: address, Port: port} + idleTimeoutConfig, err := NewSessionPoolConf("root", "nebula", []HostAddress{hostAddress}, "client_test") + if err != nil { + t.Errorf("failed to create session pool config, %s", err.Error()) + } + + idleTimeoutConfig.idleTime = 2 * time.Second + idleTimeoutConfig.minSize = 5 + idleTimeoutConfig.maxSize = 100 + + // create session pool + sessionPool, err := NewSessionPool(*idleTimeoutConfig, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + defer sessionPool.Close() + assert.Equal(t, 5, sessionPool.activeSessions.Len()+sessionPool.idleSessions.Len(), "Total number of sessions should be 5") + + // execute multiple queries so more sessions will be created + var wg sync.WaitGroup + wg.Add(sessionPool.conf.maxSize) + + for i := 0; i < sessionPool.conf.maxSize; i++ { + go func(wg *sync.WaitGroup) { + defer wg.Done() + _, err := sessionPool.Execute("RETURN 1") + if err != nil { + t.Errorf(err.Error()) + } + }(&wg) + } + wg.Wait() + + // wait for sessions to be idle + time.Sleep(idleTimeoutConfig.idleTime) + + // the minimum interval for cleanup is 1 minute, so in CI we need to trigger cleanup manually + sessionPool.cleanerChan <- struct{}{} + time.Sleep(idleTimeoutConfig.idleTime + 1) + + // after cleanup, the total session should be 5 which is the minSize + assert.Truef(t, sessionPool.GetTotalSessionCount() == sessionPool.conf.minSize, + "Total number of session should be %d, but got %d", + sessionPool.conf.minSize, sessionPool.GetTotalSessionCount()) +}