Skip to content

Commit

Permalink
cockroachdb: add high-availability support
Browse files Browse the repository at this point in the history
This commit adds high-availability support to the CockroachDB backend. The
locking strategy implemented is heavily influenced by the very similar
Postgres backend.
  • Loading branch information
DuskEagle committed Mar 28, 2022
1 parent f5ded12 commit 66e3575
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 36 deletions.
3 changes: 3 additions & 0 deletions changelog/12965.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
cockroachdb: add high-availability support
```
83 changes: 66 additions & 17 deletions physical/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@ var (
)

const (
defaultTableName = "vault_kv_store"
defaultTableName = "vault_kv_store"
defaultHATableName = "vault_ha_locks"
)

// CockroachDBBackend Backend is a physical backend that stores data
// within a CockroachDB database.
type CockroachDBBackend struct {
table string
client *sql.DB
rawStatements map[string]string
statements map[string]*sql.Stmt
logger log.Logger
permitPool *physical.PermitPool
table string
haTable string
client *sql.DB
rawStatements map[string]string
statements map[string]*sql.Stmt
rawHAStatements map[string]string
haStatements map[string]*sql.Stmt
logger log.Logger
permitPool *physical.PermitPool
haEnabled bool
}

// NewCockroachDBBackend constructs a CockroachDB backend using the given
Expand All @@ -51,6 +56,8 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
return nil, fmt.Errorf("missing connection_url")
}

haEnabled := conf["ha_enabled"] == "true"

dbTable := conf["table"]
if dbTable == "" {
dbTable = defaultTableName
Expand All @@ -61,6 +68,16 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
return nil, fmt.Errorf("invalid table: %w", err)
}

dbHATable, ok := conf["ha_table"]
if !ok {
dbHATable = defaultHATableName
}

err = validateDBTable(dbHATable)
if err != nil {
return nil, fmt.Errorf("invalid HA table: %w", err)
}

maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
Expand All @@ -79,17 +96,30 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
return nil, fmt.Errorf("failed to connect to cockroachdb: %w", err)
}

// Create the required table if it doesn't exists.
// Create the required tables if they don't exist.
createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
" (path STRING, value BYTES, PRIMARY KEY (path))"
if _, err := db.Exec(createQuery); err != nil {
return nil, fmt.Errorf("failed to create mysql table: %w", err)
return nil, fmt.Errorf("failed to create CockroachDB table: %w", err)
}
if haEnabled {
createHATableQuery := "CREATE TABLE IF NOT EXISTS " + dbHATable +
"(ha_key TEXT NOT NULL, " +
" ha_identity TEXT NOT NULL, " +
" ha_value TEXT, " +
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
");"
if _, err := db.Exec(createHATableQuery); err != nil {
return nil, fmt.Errorf("failed to create CockroachDB HA table: %w", err)
}
}

// Setup the backend
c := &CockroachDBBackend{
table: dbTable,
client: db,
table: dbTable,
haTable: dbHATable,
client: db,
rawStatements: map[string]string{
"put": "INSERT INTO " + dbTable + " VALUES($1, $2)" +
" ON CONFLICT (path) DO " +
Expand All @@ -99,26 +129,45 @@ func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.
"list": "SELECT path FROM " + dbTable + " WHERE path LIKE $1",
},
statements: make(map[string]*sql.Stmt),
logger: logger,
permitPool: physical.NewPermitPool(maxParInt),
rawHAStatements: map[string]string{
"get": "SELECT ha_value FROM " + dbHATable + " WHERE NOW() <= valid_until AND ha_key = $1",
"upsert": "INSERT INTO " + dbHATable + " as t (ha_identity, ha_key, ha_value, valid_until)" +
" VALUES ($1, $2, $3, NOW() + $4) " +
" ON CONFLICT (ha_key) DO " +
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4) " +
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
" (t.ha_identity = $1 AND t.ha_key = $2) ",
"delete": "DELETE FROM " + dbHATable + " WHERE ha_key = $1",
},
haStatements: make(map[string]*sql.Stmt),
logger: logger,
permitPool: physical.NewPermitPool(maxParInt),
haEnabled: haEnabled,
}

// Prepare all the statements required
for name, query := range c.rawStatements {
if err := c.prepare(name, query); err != nil {
if err := c.prepare(c.statements, name, query); err != nil {
return nil, err
}
}
if haEnabled {
for name, query := range c.rawHAStatements {
if err := c.prepare(c.haStatements, name, query); err != nil {
return nil, err
}
}
}
return c, nil
}

// prepare is a helper to prepare a query for future execution
func (c *CockroachDBBackend) prepare(name, query string) error {
// prepare is a helper to prepare a query for future execution.
func (c *CockroachDBBackend) prepare(statementMap map[string]*sql.Stmt, name, query string) error {
stmt, err := c.client.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare %q: %w", name, err)
}
c.statements[name] = stmt
statementMap[name] = stmt
return nil
}

Expand Down
201 changes: 201 additions & 0 deletions physical/cockroachdb/cockroachdb_ha.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package cockroachdb

import (
"database/sql"
"fmt"
"sync"
"time"

"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/physical"
)

// Timing parameters applied to every lock handed out by LockWith.
const (
	// CockroachDBLockRetryInterval is how long a waiter pauses after a
	// failed attempt to take the lock before trying again.
	CockroachDBLockRetryInterval = 1 * time.Second

	// CockroachDBLockRenewInterval is how long the lock holder waits between
	// two consecutive renewals of its lock.
	CockroachDBLockRenewInterval = 5 * time.Second

	// CockroachDBLockTTLSeconds is the lifetime of a lock, in seconds. It
	// matches the default that the Consul API uses (15 seconds) and is used
	// as part of SQL commands to set/extend the lock expiry time relative to
	// the database clock.
	CockroachDBLockTTLSeconds = 15
)

// Compile-time checks: the build fails if either type stops implementing
// the interface it is assigned to here.
var _ physical.HABackend = (*CockroachDBBackend)(nil)

var _ physical.Lock = (*CockroachDBLock)(nil)

// CockroachDBLock is a physical.Lock backed by one row of the backend's HA
// table. The row is written through the backend's prepared "upsert" HA
// statement and stays owned by this instance (identified by identity) for as
// long as it is re-written before its expiry.
type CockroachDBLock struct {
	backend  *CockroachDBBackend
	key      string
	value    string
	identity string
	lock     sync.Mutex

	renewTicker *time.Ticker

	// stopRenew is installed by a successful Lock. Calling it asks the
	// renewal goroutine to exit and blocks until it has done so. It may be
	// called any number of times. It is nil until the lock has been acquired.
	stopRenew func()

	// ttlSeconds is how long a lock is valid for.
	ttlSeconds int

	// renewInterval is how much time to wait between lock renewals. must be << ttl.
	renewInterval time.Duration

	// retryInterval is how much time to wait between attempts to grab the lock.
	retryInterval time.Duration
}

// HAEnabled reports whether high-availability mode was switched on in the
// backend configuration.
func (c *CockroachDBBackend) HAEnabled() bool {
	return c.haEnabled
}

// LockWith returns a lock for key that will publish value while held. Every
// returned lock carries a freshly generated UUID as its identity and the
// package default TTL, renew and retry intervals. Nothing is written to the
// database until Lock is called.
func (c *CockroachDBBackend) LockWith(key, value string) (physical.Lock, error) {
	identity, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	return &CockroachDBLock{
		backend:       c,
		key:           key,
		value:         value,
		identity:      identity,
		ttlSeconds:    CockroachDBLockTTLSeconds,
		renewInterval: CockroachDBLockRenewInterval,
		retryInterval: CockroachDBLockRetryInterval,
	}, nil
}

// Lock tries to acquire the lock by repeatedly trying to create a record in the
// CockroachDB table. It will block until either the stop channel is closed or
// the lock could be acquired successfully. The returned channel will be closed
// once the lock in the CockroachDB table cannot be renewed, either due to an
// error speaking to CockroachDB or because someone else has taken it.
//
// It returns (nil, err) when an acquisition attempt fails with an error and
// (nil, nil) when stopCh is closed before the lock was obtained.
func (l *CockroachDBLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	l.lock.Lock()
	defer l.lock.Unlock()

	var (
		success = make(chan struct{})
		errors  = make(chan error, 1)
		leader  = make(chan struct{})
	)
	go l.tryToLock(stopCh, success, errors)

	select {
	case <-success:
		// After acquiring it successfully, we must renew the lock periodically.
		l.startRenewal(leader)
	case err := <-errors:
		return nil, err
	case <-stopCh:
		return nil, nil
	}

	return leader, nil
}

// startRenewal launches the goroutine that keeps the freshly acquired lock
// alive and installs l.stopRenew so that Unlock can shut it down again.
// leader is closed by that goroutine if a renewal fails.
func (l *CockroachDBLock) startRenewal(leader chan struct{}) {
	var (
		ticker   = time.NewTicker(l.renewInterval)
		stop     = make(chan struct{})
		finished = make(chan struct{})
		once     sync.Once
	)

	l.renewTicker = ticker
	l.stopRenew = func() {
		// once makes repeated Unlock calls harmless (no double close).
		once.Do(func() { close(stop) })
		<-finished
	}

	go func() {
		defer close(finished)
		l.periodicallyRenewLock(ticker, stop, leader)
	}()
}

// Unlock releases the lock by deleting the lock record from the
// CockroachDB table.
//
// The renewal goroutine is stopped, and waited for, before the record is
// deleted. Stopping only the ticker is not enough: Ticker.Stop does not close
// the ticker channel, so a goroutine ranging over it would never return, and
// a tick that was already pending could re-create the record right after the
// delete. The wait happens before a permit is taken because an in-flight
// renewal needs a permit of its own to finish.
func (l *CockroachDBLock) Unlock() error {
	if stop := l.stopRenew; stop != nil {
		stop()
	}
	if l.renewTicker != nil {
		l.renewTicker.Stop()
	}

	c := l.backend
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	_, err := c.haStatements["delete"].Exec(l.key)
	return err
}

// Value checks whether or not the lock is held by any instance of CockroachDBLock,
// including this one, and returns the current value.
//
// It runs the backend's prepared "get" HA statement for the lock key. A
// missing row is reported as (false, "", nil); any other query error is
// returned as is.
func (l *CockroachDBLock) Value() (bool, string, error) {
	c := l.backend
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	var result string
	err := c.haStatements["get"].QueryRow(l.key).Scan(&result)

	switch err {
	case nil:
		return true, result, nil
	case sql.ErrNoRows:
		return false, "", nil
	default:
		return false, "", err
	}
}

// tryToLock tries to create a new item in CockroachDB every `retryInterval`.
// As long as the item cannot be created (because it already exists), it will
// be retried. If the operation fails due to an error, it is sent to the errors
// channel. When the lock could be acquired successfully, the success channel
// is closed. The function returns after the first error, after success, or
// once stop is closed.
func (l *CockroachDBLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
	ticker := time.NewTicker(l.retryInterval)
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			gotlock, err := l.writeItem()
			switch {
			case err != nil:
				// Send to the error channel and don't block if full.
				select {
				case errors <- err:
				default:
				}
				return
			case gotlock:
				close(success)
				return
			}
		}
	}
}

// periodicallyRenewLock re-writes the lock record on every tick of ticker.
// If a renewal returns an error or reports that the lock is no longer ours,
// done is closed and the loop ends. If stop is closed the loop ends without
// closing done. The ticker is stopped on every exit path.
func (l *CockroachDBLock) periodicallyRenewLock(ticker *time.Ticker, stop <-chan struct{}, done chan struct{}) {
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
		}

		// select picks at random when a tick and a stop request are both
		// ready, so check stop once more before writing to the database.
		select {
		case <-stop:
			return
		default:
		}

		gotlock, err := l.writeItem()
		if err != nil || !gotlock {
			close(done)
			return
		}
	}
}

// writeItem attempts to put/update the CockroachDB item by executing the
// backend's prepared "upsert" HA statement with this lock's identity, key,
// value and TTL. Returns true if the lock was obtained, false if not.
// If false error may be nil or non-nil: nil indicates simply that someone
// else has the lock, whereas non-nil means that something unexpected happened.
func (l *CockroachDBLock) writeItem() (bool, error) {
	c := l.backend
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	// The statement computes the expiry as NOW() + $4, so $4 is an interval.
	// Send it as text with an explicit unit instead of a bare integer, whose
	// acceptance as an interval depends on how the driver encodes it.
	ttl := fmt.Sprintf("%d seconds", l.ttlSeconds)

	sqlResult, err := c.haStatements["upsert"].Exec(l.identity, l.key, l.value, ttl)
	if err != nil {
		return false, err
	}
	if sqlResult == nil {
		return false, fmt.Errorf("empty SQL response received")
	}

	ar, err := sqlResult.RowsAffected()
	if err != nil {
		return false, err
	}
	// Exactly one affected row means the record was inserted or updated for
	// this identity; zero means the conflict clause's WHERE did not match.
	return ar == 1, nil
}
Loading

0 comments on commit 66e3575

Please sign in to comment.