diff --git a/README.md b/README.md index 0666c5e..7da1198 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ + [![Go Reference](https://pkg.go.dev/badge/github.com/infobloxopen/hotload.svg)](https://pkg.go.dev/github.com/infobloxopen/hotload) # hotload Hotload is a golang database/sql that supports dynamic reloading @@ -75,3 +76,17 @@ Pth represents a unique string that makes sense to the strategy. For example, pt point to a path in etcd or a kind/id in k8s. The hotload project ships with one hotload strategy: fsnotify. + +# Force Kill + +By default, the hotload driver allows for connections to be closed gracefully by the underlying driver. If your +application holds connections open with long running operations, this will prevent graceful switchover to new datasources. + +Adding `forceKill=true` to your DSN will cause the hotload driver to close the underlying connection manually when a +change to the connection information is detected. + + +For example: +``` +db, err := sql.Open("hotload", "fsnotify://postgres/tmp/myconfig.txt?forceKill=true") +``` diff --git a/chanGroup_test.go b/chanGroup_test.go index 246d394..faf3a63 100644 --- a/chanGroup_test.go +++ b/chanGroup_test.go @@ -2,11 +2,32 @@ package hotload import ( "context" + "database/sql/driver" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "sync" ) +type testConn struct { + closed bool +} + +func (tc *testConn) Open(name string) (driver.Conn, error) { + return tc, nil +} + +func (tc *testConn) Prepare(query string) (driver.Stmt, error) { + return nil, nil +} +func (tc *testConn) Begin() (driver.Tx, error) { + return nil, nil +} + +func (tc *testConn) Close() error { + tc.closed = true + return nil +} + var _ = Describe("Driver", func() { var pctx context.Context var ctx context.Context @@ -58,11 +79,30 @@ var _ = Describe("Driver", func() { }) It("Should mark all connections for reset", func() { - cg.markForReset() + cg.resetConnections() for _, c := range conns { Expect(c.reset).To(BeTrue()) } }) + + It("Should kill all connections when specified", func() { + cg.forceKill = true + testConns := make([]*testConn, 0) + for _, c := range cg.conns { + tc := &testConn{} + c.conn = tc + testConns = append(testConns, tc) + } + cg.resetConnections() + + for _, c := range conns { + Expect(c.killed).To(BeTrue(), "connection should be marked killed") + } + + for _, tc := range testConns { + Expect(tc.closed).To(BeTrue(), "Closed() should have been called on the underlying connection") + } + }) }) }) diff --git a/conn.go b/conn.go index 10c79b6..f34372d 100644 --- a/conn.go +++ b/conn.go @@ -9,10 +9,11 @@ import ( // managedConn wraps a sql/driver.Conn so that it can be closed by // a supervising context. type managedConn struct { - ctx context.Context - conn driver.Conn - reset bool - mu sync.RWMutex + ctx context.Context + conn driver.Conn + reset bool + killed bool + mu sync.RWMutex } func newManagedConn(ctx context.Context, conn driver.Conn) *managedConn { @@ -112,7 +113,15 @@ func (c *managedConn) ResetSession(ctx context.Context) error { } func (c *managedConn) Close() error { - return c.conn.Close() + c.mu.Lock() + defer c.mu.Unlock() + err := c.conn.Close() + + if err == nil { + c.killed = true + } + + return err } func (c *managedConn) GetReset() bool { @@ -126,3 +135,9 @@ func (c *managedConn) Reset(v bool) { defer c.mu.Unlock() c.reset = v } + +func (c *managedConn) GetKill() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.killed +} diff --git a/driver.go b/driver.go index a6406ec..e03c7d3 100644 --- a/driver.go +++ b/driver.go @@ -63,6 +63,8 @@ type Strategy interface { Watch(ctx context.Context, pth string, options url.Values) (value string, values <-chan string, err error) } +const forceKill = "forceKill" + var ( ErrUnsupportedStrategy = fmt.Errorf("unsupported hotload strategy") ErrMalformedConnectionString = fmt.Errorf("malformed hotload connection string") @@ -155,6 +157,7 @@ type chanGroup struct { cancel context.CancelFunc sqlDriver driver.Driver mu sync.RWMutex + forceKill bool conns []*managedConn } @@ -180,15 +183,20 @@ func (cg *chanGroup) valueChanged(v string) { defer cg.mu.Unlock() cg.cancel() cg.ctx, cg.cancel = context.WithCancel(cg.parentCtx) - cg.markForReset() + cg.resetConnections() cg.value = v } -func (cg *chanGroup) markForReset() { +func (cg *chanGroup) resetConnections() { for _, c := range cg.conns { c.Reset(true) + + if cg.forceKill { + // ignore errors from close + c.Close() + } } cg.conns = make([]*managedConn, 0) @@ -208,6 +216,16 @@ func (cg *chanGroup) Open() (driver.Conn, error) { return manConn, nil } +func (cg *chanGroup) parseValues(vs url.Values) { + cg.mu.Lock() + defer cg.mu.Unlock() + + if v, ok := vs[forceKill]; ok { + firstValue := v[0] + cg.forceKill = firstValue == "true" + } +} + func (h *hdriver) Open(name string) (driver.Conn, error) { uri, err := url.Parse(name) if err != nil { @@ -228,8 +246,8 @@ func (h *hdriver) Open(name string) (driver.Conn, error) { if !ok { return nil, ErrUnknownDriver } - - value, values, err := strategy.Watch(h.ctx, uri.Path, uri.Query()) + queryParams := uri.Query() + value, values, err := strategy.Watch(h.ctx, uri.Path, queryParams) if err != nil { return nil, err } @@ -243,6 +261,7 @@ func (h *hdriver) Open(name string) (driver.Conn, error) { sqlDriver: sqlDriver, conns: make([]*managedConn, 0), } + cgroup.parseValues(queryParams) h.cgroup[name] = cgroup go cgroup.run() }