diff --git a/connection_pool.go b/connection_pool.go index 6bbe3324..e7c8260d 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -181,10 +181,17 @@ func (pool *ConnectionPool) getIdleConn() (*connection, error) { // Release connection to pool func (pool *ConnectionPool) release(conn *connection) { + pool.releaseAndBack(conn, true) +} + +func (pool *ConnectionPool) releaseAndBack(conn *connection, pushBack bool) { pool.rwLock.Lock() defer pool.rwLock.Unlock() // Remove connection from active queue and add into idle queue removeFromList(&pool.activeConnectionQueue, conn) + if !pushBack { + return + } conn.release() pool.idleConnectionQueue.PushBack(conn) } diff --git a/session.go b/session.go index 8cf23e43..2683378d 100644 --- a/session.go +++ b/session.go @@ -13,7 +13,6 @@ import ( "sync" "time" - "github.com/facebook/fbthrift/thrift/lib/go/thrift" "github.com/vesoft-inc/nebula-go/v3/nebula" graph "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) @@ -35,14 +34,6 @@ type Session struct { } func (session *Session) reconnectWithExecuteErr(err error) error { - // Reconnect only if the transport is closed - err2, ok := err.(thrift.TransportException) - if !ok { - return err - } - if err2.TypeID() != thrift.END_OF_FILE { - return err - } if _err := session.reConnect(); _err != nil { return fmt.Errorf("failed to reconnect, %s", _err.Error()) } @@ -204,8 +195,7 @@ func (session *Session) reConnect() error { return err } - // Release connection to pool - session.connPool.release(session.connection) + session.connPool.releaseAndBack(session.connection, false) session.connection = newConnection return nil } diff --git a/session_test.go b/session_test.go index bf7e321f..0e94e515 100644 --- a/session_test.go +++ b/session_test.go @@ -9,8 +9,11 @@ package nebula_go import ( + "context" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSession_Execute(t *testing.T) { @@ -42,15 +45,64 @@ func TestSession_Execute(t *testing.T) { t.Fatal(err) } } - go func() { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + go func(ctx context.Context) { for { - f(sess) + select { + case <-ctx.Done(): + break + default: + f(sess) + } } - }() + }(ctx) + go func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + break + default: + f(sess) + } + } + }(ctx) + time.Sleep(300 * time.Millisecond) + +} + +func TestSession_Recover(t *testing.T) { + query := "show hosts" + config := GetDefaultConf() + host := HostAddress{address, port} + pool, err := NewConnectionPool([]HostAddress{host}, config, DefaultLogger{}) + if err != nil { + t.Fatal(err) + } + + sess, err := pool.GetSession("root", "nebula") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, pool.getActiveConnCount()+pool.getIdleConnCount()) go func() { for { - f(sess) + _, _ = sess.Execute(query) } }() - time.Sleep(300 * time.Millisecond) + stopContainer(t, "nebula-docker-compose_graphd_1") + stopContainer(t, "nebula-docker-compose_graphd1_1") + stopContainer(t, "nebula-docker-compose_graphd2_1") + defer func() { + startContainer(t, "nebula-docker-compose_graphd1_1") + startContainer(t, "nebula-docker-compose_graphd2_1") + }() + <-time.After(3 * time.Second) + startContainer(t, "nebula-docker-compose_graphd_1") + <-time.After(3 * time.Second) + _, err = sess.Execute(query) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 1, pool.getActiveConnCount()+pool.getIdleConnCount()) }