Skip to content

Commit

Permalink
Introduce forceKill logic (#11)
Browse files Browse the repository at this point in the history
* force kill connections when configured to do so

* add readme

* add example
  • Loading branch information
khous authored Apr 5, 2021
1 parent 8e8ec53 commit bb8f1f0
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 10 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

[![Go Reference](https://pkg.go.dev/badge/github.com/infobloxopen/hotload.svg)](https://pkg.go.dev/github.com/infobloxopen/hotload)
# hotload
Hotload is a golang database/sql that supports dynamic reloading
Expand Down Expand Up @@ -75,3 +76,17 @@ Pth represents a unique string that makes sense to the strategy. For example, pt
point to a path in etcd or a kind/id in k8s.

The hotload project ships with one hotload strategy: fsnotify.

# Force Kill

By default, the hotload driver allows for connections to be closed gracefully by the underlying driver. If your
application holds connections open with long running operations, this will prevent graceful switchover to new datasources.

Adding `forceKill=true` to your DSN will cause the hotload driver to close the underlying connection manually when a
change to the connection information is detected.


For example:
```
db, err := sql.Open("hotload", "fsnotify://postgres/tmp/myconfig.txt?forceKill=true")
```
42 changes: 41 additions & 1 deletion chanGroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,32 @@ package hotload

import (
"context"
"database/sql/driver"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"sync"
)

type testConn struct {
closed bool
}

func (tc *testConn) Open(name string) (driver.Conn, error) {
return tc, nil
}

func (tc *testConn) Prepare(query string) (driver.Stmt, error) {
return nil, nil
}
func (tc *testConn) Begin() (driver.Tx, error) {
return nil, nil
}

func (tc *testConn) Close() error {
tc.closed = true
return nil
}

var _ = Describe("Driver", func() {
var pctx context.Context
var ctx context.Context
Expand Down Expand Up @@ -58,11 +79,30 @@ var _ = Describe("Driver", func() {
})

It("Should mark all connections for reset", func() {
cg.markForReset()
cg.resetConnections()

for _, c := range conns {
Expect(c.reset).To(BeTrue())
}
})

It("Should kill all connections when specified", func() {
cg.forceKill = true
testConns := make([]*testConn, 0)
for _, c := range cg.conns {
tc := &testConn{}
c.conn = tc
testConns = append(testConns, tc)
}
cg.resetConnections()

for _, c := range conns {
Expect(c.killed).To(BeTrue(), "connection should be marked killed")
}

for _, tc := range testConns {
Expect(tc.closed).To(BeTrue(), "Closed() should have been called on the underlying connection")
}
})
})
})
25 changes: 20 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
// managedConn wraps a sql/driver.Conn so that it can be closed by
// a supervising context.
type managedConn struct {
ctx context.Context
conn driver.Conn
reset bool
mu sync.RWMutex
ctx context.Context
conn driver.Conn
reset bool
killed bool
mu sync.RWMutex
}

func newManagedConn(ctx context.Context, conn driver.Conn) *managedConn {
Expand Down Expand Up @@ -112,7 +113,15 @@ func (c *managedConn) ResetSession(ctx context.Context) error {
}

func (c *managedConn) Close() error {
return c.conn.Close()
c.mu.Lock()
defer c.mu.Unlock()
err := c.conn.Close()

if err == nil {
c.killed = true
}

return err
}

func (c *managedConn) GetReset() bool {
Expand All @@ -126,3 +135,9 @@ func (c *managedConn) Reset(v bool) {
defer c.mu.Unlock()
c.reset = v
}

func (c *managedConn) GetKill() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.killed
}
27 changes: 23 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ type Strategy interface {
Watch(ctx context.Context, pth string, options url.Values) (value string, values <-chan string, err error)
}

const forceKill = "forceKill"

var (
ErrUnsupportedStrategy = fmt.Errorf("unsupported hotload strategy")
ErrMalformedConnectionString = fmt.Errorf("malformed hotload connection string")
Expand Down Expand Up @@ -155,6 +157,7 @@ type chanGroup struct {
cancel context.CancelFunc
sqlDriver driver.Driver
mu sync.RWMutex
forceKill bool
conns []*managedConn
}

Expand All @@ -180,15 +183,20 @@ func (cg *chanGroup) valueChanged(v string) {
defer cg.mu.Unlock()
cg.cancel()
cg.ctx, cg.cancel = context.WithCancel(cg.parentCtx)
cg.markForReset()
cg.resetConnections()

cg.value = v

}

func (cg *chanGroup) markForReset() {
func (cg *chanGroup) resetConnections() {
for _, c := range cg.conns {
c.Reset(true)

if cg.forceKill {
// ignore errors from close
c.Close()
}
}

cg.conns = make([]*managedConn, 0)
Expand All @@ -208,6 +216,16 @@ func (cg *chanGroup) Open() (driver.Conn, error) {
return manConn, nil
}

func (cg *chanGroup) parseValues(vs url.Values) {
cg.mu.Lock()
defer cg.mu.Unlock()

if v, ok := vs[forceKill]; ok {
firstValue := v[0]
cg.forceKill = firstValue == "true"
}
}

func (h *hdriver) Open(name string) (driver.Conn, error) {
uri, err := url.Parse(name)
if err != nil {
Expand All @@ -228,8 +246,8 @@ func (h *hdriver) Open(name string) (driver.Conn, error) {
if !ok {
return nil, ErrUnknownDriver
}

value, values, err := strategy.Watch(h.ctx, uri.Path, uri.Query())
queryParams := uri.Query()
value, values, err := strategy.Watch(h.ctx, uri.Path, queryParams)
if err != nil {
return nil, err
}
Expand All @@ -243,6 +261,7 @@ func (h *hdriver) Open(name string) (driver.Conn, error) {
sqlDriver: sqlDriver,
conns: make([]*managedConn, 0),
}
cgroup.parseValues(queryParams)
h.cgroup[name] = cgroup
go cgroup.run()
}
Expand Down

0 comments on commit bb8f1f0

Please sign in to comment.