diff --git a/conn/pool.go b/conn/pool.go index e4730018402..4270aa9ceee 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "github.com/dgraph-io/badger/y" "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" @@ -47,7 +48,7 @@ type Pool struct { lastEcho time.Time Addr string - ticker *time.Ticker + closer *y.Closer } type Pools struct { @@ -121,7 +122,7 @@ func (p *Pools) Connect(addr string) *Pool { } p.RUnlock() - pool, err := NewPool(addr) + pool, err := newPool(addr) if err != nil { glog.Errorf("Unable to connect to host: %s", addr) return nil @@ -129,18 +130,18 @@ func (p *Pools) Connect(addr string) *Pool { p.Lock() existingPool, has = p.all[addr] + defer p.Unlock() if has { - p.Unlock() + go pool.shutdown() // Not being used, so release the resources. return existingPool } glog.Infof("CONNECTED to %v\n", addr) p.all[addr] = pool - p.Unlock() return pool } -// NewPool creates a new "pool" with one gRPC connection, refcount 0. -func NewPool(addr string) (*Pool, error) { +// newPool creates a new "pool" with one gRPC connection, refcount 0. +func newPool(addr string) (*Pool, error) { conn, err := grpc.Dial(addr, grpc.WithStatsHandler(&ocgrpc.ClientHandler{}), grpc.WithDefaultCallOptions( @@ -151,10 +152,7 @@ func NewPool(addr string) (*Pool, error) { if err != nil { return nil, err } - pl := &Pool{conn: conn, Addr: addr, lastEcho: time.Now()} - - // Initialize ticker before running monitor health. - pl.ticker = time.NewTicker(echoDuration) + pl := &Pool{conn: conn, Addr: addr, lastEcho: time.Now(), closer: y.NewCloser(1)} go pl.MonitorHealth() return pl, nil } @@ -167,7 +165,8 @@ func (p *Pool) Get() *grpc.ClientConn { } func (p *Pool) shutdown() { - p.ticker.Stop() + glog.Warningf("Shutting down extra connection to %s", p.Addr) + p.closer.SignalAndWait() p.conn.Close() } @@ -189,6 +188,15 @@ func (p *Pool) listenToHeartbeat() error { return err } + go func() { + select { + case <-ctx.Done(): + case <-p.closer.HasBeenClosed(): + cancel() + } + }() + + // This loop can block indefinitely as long as it keeps on receiving pings back. for { _, err := stream.Recv() if err != nil { @@ -203,17 +211,24 @@ func (p *Pool) listenToHeartbeat() error { // MonitorHealth monitors the health of the connection via Echo. This function blocks forever. func (p *Pool) MonitorHealth() { + defer p.closer.Done() + var lastErr error - for range p.ticker.C { - err := p.listenToHeartbeat() - if lastErr != nil && err == nil { - glog.Infof("Connection established with %v\n", p.Addr) - } else if err != nil && lastErr == nil { - glog.Warningf("Connection lost with %v. Error: %v\n", p.Addr, err) + for { + select { + case <-p.closer.HasBeenClosed(): + return + default: + err := p.listenToHeartbeat() + if lastErr != nil && err == nil { + glog.Infof("Connection established with %v\n", p.Addr) + } else if err != nil && lastErr == nil { + glog.Warningf("Connection lost with %v. Error: %v\n", p.Addr, err) + } + lastErr = err + // Sleep for a bit before retrying. + time.Sleep(echoDuration) } - lastErr = err - // Sleep for a bit before retrying. - time.Sleep(echoDuration) } }