Skip to content

Commit

Permalink
rocksdb: implement DBConnection.Revert
Browse files Browse the repository at this point in the history
  • Loading branch information
roysc committed Oct 5, 2021
1 parent 25b166d commit 9cbf319
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 18 deletions.
89 changes: 71 additions & 18 deletions db/rocksdb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

var (
currentDBFileName string = "current.db"
checkpointFileFormat string = "%020d.db"
)

Expand Down Expand Up @@ -47,8 +48,8 @@ type checkpointCache struct {
}

type cpCacheEntry struct {
cxn *dbConnection
txnCount uint
cxn *dbConnection
openCount uint
}

type dbTxn struct {
Expand Down Expand Up @@ -104,7 +105,16 @@ func NewDB(dir string) (*dbManager, error) {
if mgr.vmgr, err = readVersions(mgr.checkpointsDir()); err != nil {
return nil, err
}
dbPath := filepath.Join(dir, "current.db")
dbPath := filepath.Join(dir, currentDBFileName)
// if the current db file is missing but there are checkpoints, restore it
if mgr.vmgr.Count() > 0 {
if _, err = os.Stat(dbPath); os.IsNotExist(err) {
err = mgr.restoreFromCheckpoint(mgr.vmgr.Last(), dbPath)
if err != nil {
return nil, err
}
}
}
mgr.current, err = gorocksdb.OpenOptimisticTransactionDb(dbo, dbPath)
if err != nil {
return nil, err
Expand Down Expand Up @@ -133,8 +143,8 @@ func readVersions(dir string) (*dbm.VersionManager, error) {
return dbm.NewVersionManager(versions), nil
}

func (mgr *dbManager) checkpointPath(ver uint64) (string, error) {
dbPath := filepath.Join(mgr.checkpointsDir(), fmt.Sprintf(checkpointFileFormat, ver))
func (mgr *dbManager) checkpointPath(version uint64) (string, error) {
dbPath := filepath.Join(mgr.checkpointsDir(), fmt.Sprintf(checkpointFileFormat, version))
if stat, err := os.Stat(dbPath); err != nil {
if errors.Is(err, os.ErrNotExist) {
err = dbm.ErrVersionDoesNotExist
Expand All @@ -146,23 +156,23 @@ func (mgr *dbManager) checkpointPath(ver uint64) (string, error) {
return dbPath, nil
}

func (mgr *dbManager) openCheckpoint(ver uint64) (*dbConnection, error) {
func (mgr *dbManager) openCheckpoint(version uint64) (*dbConnection, error) {
mgr.cpCache.mtx.Lock()
defer mgr.cpCache.mtx.Unlock()
cp, has := mgr.cpCache.cache[ver]
cp, has := mgr.cpCache.cache[version]
if has {
cp.txnCount += 1
cp.openCount += 1
return cp.cxn, nil
}
dbPath, err := mgr.checkpointPath(ver)
dbPath, err := mgr.checkpointPath(version)
if err != nil {
return nil, err
}
db, err := gorocksdb.OpenOptimisticTransactionDb(mgr.opts.dbo, dbPath)
if err != nil {
return nil, err
}
mgr.cpCache.cache[ver] = &cpCacheEntry{cxn: db, txnCount: 1}
mgr.cpCache.cache[version] = &cpCacheEntry{cxn: db, openCount: 1}
return db, nil
}

Expand All @@ -177,18 +187,18 @@ func (mgr *dbManager) Reader() dbm.DBReader {
}
}

func (mgr *dbManager) ReaderAt(ver uint64) (dbm.DBReader, error) {
func (mgr *dbManager) ReaderAt(version uint64) (dbm.DBReader, error) {
mgr.mtx.RLock()
defer mgr.mtx.RUnlock()
db, err := mgr.openCheckpoint(ver)
db, err := mgr.openCheckpoint(version)
if err != nil {
return nil, err
}

return &dbTxn{
txn: db.TransactionBegin(mgr.opts.wo, mgr.opts.txo, nil),
mgr: mgr,
version: ver,
version: version,
}, nil
}

Expand Down Expand Up @@ -236,21 +246,21 @@ func (mgr *dbManager) save(target uint64) (uint64, error) {
return 0, dbm.ErrOpenTransactions
}
newVmgr := mgr.vmgr.Copy()
ver, err := newVmgr.Save(target)
target, err := newVmgr.Save(target)
if err != nil {
return 0, err
}
cp, err := mgr.current.NewCheckpoint()
if err != nil {
return 0, err
}
dir := filepath.Join(mgr.checkpointsDir(), fmt.Sprintf(checkpointFileFormat, ver))
dir := filepath.Join(mgr.checkpointsDir(), fmt.Sprintf(checkpointFileFormat, target))
if err := cp.CreateCheckpoint(dir, 0); err != nil {
return 0, err
}
cp.Destroy()
mgr.vmgr = newVmgr
return ver, nil
return target, nil
}

func (mgr *dbManager) DeleteVersion(ver uint64) error {
Expand All @@ -268,6 +278,49 @@ func (mgr *dbManager) DeleteVersion(ver uint64) error {
return os.RemoveAll(dbPath)
}

func (mgr *dbManager) Revert() (err error) {
mgr.mtx.RLock()
defer mgr.mtx.RUnlock()
if mgr.openWriters > 0 {
return dbm.ErrOpenTransactions
}
last := mgr.vmgr.Last()
if last == 0 {
return dbm.ErrInvalidVersion
}
// Close current connection and replace it with a checkpoint (created from the last checkpoint)
mgr.current.Close()
dbPath := filepath.Join(mgr.dir, currentDBFileName)
err = os.RemoveAll(dbPath)
if err != nil {
return
}
err = mgr.restoreFromCheckpoint(last, dbPath)
if err != nil {
return
}
mgr.current, err = gorocksdb.OpenOptimisticTransactionDb(mgr.opts.dbo, dbPath)
return
}

func (mgr *dbManager) restoreFromCheckpoint(version uint64, path string) error {
cxn, err := mgr.openCheckpoint(version)
if err != nil {
return err
}
defer mgr.cpCache.decrement(version)
cp, err := cxn.NewCheckpoint()
if err != nil {
return err
}
err = cp.CreateCheckpoint(path, 0)
if err != nil {
return err
}
cp.Destroy()
return nil
}

// Close implements DBConnection.
func (mgr *dbManager) Close() error {
mgr.current.Close()
Expand Down Expand Up @@ -421,8 +474,8 @@ func (cpc *checkpointCache) decrement(ver uint64) bool {
if !has {
return false
}
cp.txnCount -= 1
if cp.txnCount == 0 {
cp.openCount -= 1
if cp.openCount == 0 {
cp.cxn.Close()
delete(cpc.cache, ver)
}
Expand Down
27 changes: 27 additions & 0 deletions db/rocksdb/db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rocksdb

import (
"os"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -31,6 +32,32 @@ func TestVersioning(t *testing.T) {
dbtest.DoTestVersioning(t, load)
}

func TestRevert(t *testing.T) {
dbtest.DoTestRevert(t, load, false)
dbtest.DoTestRevert(t, load, true)
}

func TestReloadDB(t *testing.T) {
dbtest.DoTestReloadDB(t, load)
}

// Test that the DB can be reloaded after a failed Revert
func TestRevertRecovery(t *testing.T) {
dir := t.TempDir()
db, err := NewDB(dir)
require.NoError(t, err)
_, err = db.SaveNextVersion()
require.NoError(t, err)
txn := db.Writer()
require.NoError(t, txn.Set([]byte{1}, []byte{1}))
require.NoError(t, txn.Set([]byte{2}, []byte{2}))
require.NoError(t, txn.Commit())

// make checkpoints dir temporarily unreadable to trigger an error
require.NoError(t, os.Chmod(db.checkpointsDir(), 0000))
require.Error(t, db.Revert())

require.NoError(t, os.Chmod(db.checkpointsDir(), 0755))
db, err = NewDB(dir)
require.NoError(t, err)
}

0 comments on commit 9cbf319

Please sign in to comment.