diff --git a/conn.go b/conn.go index 3685330..7393e55 100644 --- a/conn.go +++ b/conn.go @@ -10,11 +10,15 @@ type PoolConn struct { net.Conn hp *heapPool updatedtime time.Time + unusable bool } func (pc *PoolConn) Close() error { - pc.updatedtime = time.Now() + if pc.unusable { + return pc.close() + } + pc.updatedtime = time.Now() if err := pc.hp.put(pc); err != nil { log.Printf("put conn failed:%v\n", err) pc.hp = nil @@ -23,6 +27,10 @@ func (pc *PoolConn) Close() error { return nil } +func (pc *PoolConn) MarkUnusable() { + pc.unusable = true +} + func (pc *PoolConn) close() error { return pc.Conn.Close() } diff --git a/heap.go b/heap.go index c11659a..b0e4fbe 100644 --- a/heap.go +++ b/heap.go @@ -80,7 +80,7 @@ func NewHeapPool(config *PoolConfig) (Pool, error) { maxIdle: maxIdle, idletime: idletime, maxLifetime: maxLifetime, - cleanerCh: make(chan struct{}, 1), + cleanerCh: make(chan struct{}), factory: config.Factory, } @@ -126,6 +126,9 @@ func (hp *heapPool) Close() { hp.mu.Lock() defer hp.mu.Unlock() + if hp.freeConn == nil { + return + } hp.cleanerCh <- struct{}{} hp.factory = nil for hp.freeConn.Len() > 0 { @@ -140,6 +143,9 @@ func (hp *heapPool) put(conn *PoolConn) error { hp.mu.Lock() defer hp.mu.Unlock() + if hp.freeConn == nil { + return ErrClosed + } if hp.freeConn.Len() >= hp.maxCap { return errors.New("pool have been filled") } @@ -151,6 +157,9 @@ func (hp *heapPool) Len() int { hp.mu.Lock() defer hp.mu.Unlock() + if hp.freeConn == nil { + return 0 + } return hp.freeConn.Len() } diff --git a/heap_test.go b/heap_test.go index b44668d..62a3ed0 100644 --- a/heap_test.go +++ b/heap_test.go @@ -75,12 +75,15 @@ func TestPriorityQueue(t *testing.T) { if pc1.updatedtime.Sub(pc2.updatedtime) > 0 { t.Errorf("priority is invalid, older conn should first out") } + pc1.Close() + pc2.Close() + p.Close() } func TestPoolConcurrent(t *testing.T) { p, _ := newHeapPool() - for i := 0; i < 100; i++ { + for i := 0; i < MaxCap+10; i++ { conn, err := p.Get() if err != nil { t.Errorf("Get error: %s", err) @@ -93,7 +96,7 @@ func TestPoolConcurrent(t *testing.T) { time.Sleep(5 * time.Second) if p.Len() != MaxCap { - t.Errorf("Pool length should equals MaxCap, but get:%v", p.Len()) + t.Errorf("Pool length should equals:, but get:%v", MaxCap, p.Len()) } time.Sleep(time.Minute) @@ -103,6 +106,30 @@ func TestPoolConcurrent(t *testing.T) { p.Close() } +func TestPoolConcurrent2(t *testing.T) { + p, _ := newHeapPool() + for i := 0; i < MaxCap; i++ { + conn, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + go func(conn net.Conn, i int) { + time.Sleep(time.Second) + if i >= MaxCap-10 { + conn.(*PoolConn).MarkUnusable() + } + conn.Close() + }(conn, i) + } + + time.Sleep(5 * time.Second) + if p.Len() != MaxCap-10 { + t.Errorf("Pool length should equals:%v, but get:%v", MaxCap-10, p.Len()) + } + + p.Close() +} + func newHeapPool() (Pool, error) { return NewHeapPool(&PoolConfig{ InitialCap: InitialCap,