Skip to content

Commit

Permalink
fix #272 (#273)
Browse files Browse the repository at this point in the history
* add test

* fix test case

* fix #272
  • Loading branch information
HarrisChu authored Apr 12, 2023
1 parent 3ed8926 commit fae48de
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
7 changes: 7 additions & 0 deletions connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,17 @@ func (pool *ConnectionPool) getIdleConn() (*connection, error) {

// Release connection to pool
func (pool *ConnectionPool) release(conn *connection) {
pool.releaseAndBack(conn, true)
}

func (pool *ConnectionPool) releaseAndBack(conn *connection, pushBack bool) {
pool.rwLock.Lock()
defer pool.rwLock.Unlock()
// Remove connection from active queue and add into idle queue
removeFromList(&pool.activeConnectionQueue, conn)
if !pushBack {
return
}
conn.release()
pool.idleConnectionQueue.PushBack(conn)
}
Expand Down
12 changes: 1 addition & 11 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"sync"
"time"

"github.com/facebook/fbthrift/thrift/lib/go/thrift"
"github.com/vesoft-inc/nebula-go/v3/nebula"
graph "github.com/vesoft-inc/nebula-go/v3/nebula/graph"
)
Expand All @@ -35,14 +34,6 @@ type Session struct {
}

func (session *Session) reconnectWithExecuteErr(err error) error {
// Reconnect only if the transport is closed
err2, ok := err.(thrift.TransportException)
if !ok {
return err
}
if err2.TypeID() != thrift.END_OF_FILE {
return err
}
if _err := session.reConnect(); _err != nil {
return fmt.Errorf("failed to reconnect, %s", _err.Error())
}
Expand Down Expand Up @@ -204,8 +195,7 @@ func (session *Session) reConnect() error {
return err
}

// Release connection to pool
session.connPool.release(session.connection)
session.connPool.releaseAndBack(session.connection, false)
session.connection = newConnection
return nil
}
Expand Down
62 changes: 57 additions & 5 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
package nebula_go

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestSession_Execute(t *testing.T) {
Expand Down Expand Up @@ -42,15 +45,64 @@ func TestSession_Execute(t *testing.T) {
t.Fatal(err)
}
}
go func() {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
go func(ctx context.Context) {
for {
f(sess)
select {
case <-ctx.Done():
break
default:
f(sess)
}
}
}()
}(ctx)
go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
break
default:
f(sess)
}
}
}(ctx)
time.Sleep(300 * time.Millisecond)

}

func TestSession_Recover(t *testing.T) {
query := "show hosts"
config := GetDefaultConf()
host := HostAddress{address, port}
pool, err := NewConnectionPool([]HostAddress{host}, config, DefaultLogger{})
if err != nil {
t.Fatal(err)
}

sess, err := pool.GetSession("root", "nebula")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 1, pool.getActiveConnCount()+pool.getIdleConnCount())
go func() {
for {
f(sess)
_, _ = sess.Execute(query)
}
}()
time.Sleep(300 * time.Millisecond)
stopContainer(t, "nebula-docker-compose_graphd_1")
stopContainer(t, "nebula-docker-compose_graphd1_1")
stopContainer(t, "nebula-docker-compose_graphd2_1")
defer func() {
startContainer(t, "nebula-docker-compose_graphd1_1")
startContainer(t, "nebula-docker-compose_graphd2_1")
}()
<-time.After(3 * time.Second)
startContainer(t, "nebula-docker-compose_graphd_1")
<-time.After(3 * time.Second)
_, err = sess.Execute(query)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 1, pool.getActiveConnCount()+pool.getIdleConnCount())
}

0 comments on commit fae48de

Please sign in to comment.