Skip to content

Commit

Permalink
rollback tests
Browse files Browse the repository at this point in the history
  • Loading branch information
miagilepner committed Aug 25, 2023
1 parent 2869b37 commit c5e1774
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 0 deletions.
224 changes: 224 additions & 0 deletions vault/rollback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package vault

import (
"context"
"fmt"
"strings"
"sync"
"testing"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -81,6 +83,228 @@ func TestRollbackManager(t *testing.T) {
}
}

// TestRollbackManager_UnboundedWorkers adds 10 backends that require a rollback
// operation, with an unbounded number of workers. The test verifies that the 10
// work items will run in parallel
func TestRollbackManager_UnboundedWorkers(t *testing.T) {
core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: -1, RollbackPeriod: time.Millisecond * 10})
view := NewBarrierView(core.barrier, "logical/")

ran := make(chan string)
release := make(chan struct{})
core, _, _ = testCoreUnsealed(t, core)

// create 10 backends
// when a rollback happens, each backend will try to write to an unbuffered
// channel, then wait to be released
for i := 0; i < 10; i++ {
b := &NoopBackend{}
b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
if request.Operation == logical.RollbackOperation {
ran <- request.Path
<-release
}
return nil, nil
}
b.Root = []string{fmt.Sprintf("foo/%d", i)}
meUUID, err := uuid.GenerateUUID()
require.NoError(t, err)
mountEntry := &MountEntry{
Table: mountTableType,
UUID: meUUID,
Accessor: fmt.Sprintf("accessor-%d", i),
NamespaceID: namespace.RootNamespaceID,
namespace: namespace.RootNamespace,
Path: fmt.Sprintf("logical/foo/%d", i),
}
func() {
core.mountsLock.Lock()
defer core.mountsLock.Unlock()
newTable := core.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, mountEntry)
core.mounts = newTable
err = core.router.Mount(b, "logical", mountEntry, view)
require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
}()
}

timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
got := make(map[string]bool)
hasMore := true
for hasMore {
// we're not bounding the number of workers, so we would expect to see
// all 10 writes to the channel from each of the backends. Once that
// happens, close the release channel so that the functions can exit
select {
case <-timeout.Done():
require.Fail(t, "test timed out")
case i := <-ran:
got[i] = true
if len(got) == 10 {
close(release)
hasMore = false
}
}
}
done := make(chan struct{})

// start a goroutine to consume the remaining items from the queued work
go func() {
for {
select {
case <-ran:
case <-done:
}
}
}()
// stop the rollback worker, which will wait for all inflight rollbacks to
// complete
core.rollback.Stop()
close(done)
}

// TestRollbackManager_WorkerPool adds 10 backends that require a rollback
// operation, with 5 workers. The test verifies that the 5 work items can occur
// concurrently, and that the remainder of the work is queued and run when
// workers are available
func TestRollbackManager_WorkerPool(t *testing.T) {
core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10})
view := NewBarrierView(core.barrier, "logical/")

ran := make(chan string)
release := make(chan struct{})
core, _, _ = testCoreUnsealed(t, core)

// create 10 backends
// when a rollback happens, each backend will try to write to an unbuffered
// channel, then wait to be released
for i := 0; i < 10; i++ {
b := &NoopBackend{}
b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
if request.Operation == logical.RollbackOperation {
ran <- request.Path
<-release
}
return nil, nil
}
b.Root = []string{fmt.Sprintf("foo/%d", i)}
meUUID, err := uuid.GenerateUUID()
require.NoError(t, err)
mountEntry := &MountEntry{
Table: mountTableType,
UUID: meUUID,
Accessor: fmt.Sprintf("accessor-%d", i),
NamespaceID: namespace.RootNamespaceID,
namespace: namespace.RootNamespace,
Path: fmt.Sprintf("logical/foo/%d", i),
}
func() {
core.mountsLock.Lock()
defer core.mountsLock.Unlock()
newTable := core.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, mountEntry)
core.mounts = newTable
err = core.router.Mount(b, "logical", mountEntry, view)
require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
}()
}

timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
got := make(map[string]bool)
hasMore := true
for hasMore {
// we're using 5 workers, so we would expect to see 5 writes to the
// channel. Once that happens, close the release channel so that the
// functions can exit and new rollback operations can run
select {
case <-timeout.Done():
require.Fail(t, "test timed out")
case i := <-ran:
got[i] = true
if len(got) == 5 {
close(release)
hasMore = false
}
}
}
done := make(chan struct{})

// all 10 of the backends *should* be queued, but we can add an extra sleep
// just to make sure
time.Sleep(20 * time.Millisecond)

// start a goroutine to consume the remaining items from the queued work
go func() {
for {
select {
case i := <-ran:
got[i] = true
case <-done:
}
}
}()
// stop the rollback worker, which will wait for all inflight rollbacks to
// complete
core.rollback.Stop()
close(done)

// we should have received at least 1 rollback for every backend
require.GreaterOrEqual(t, len(got), 10)
}

// TestRollbackManager_numRollbackWorkers verifies that the number of rollback
// workers is parsed from the configuration, but can be overridden by an
// environment variable. This test cannot be run in parallel because of the
// environment variable
func TestRollbackManager_numRollbackWorkers(t *testing.T) {
testCases := []struct {
name string
configWorkers int
setEnvVar bool
envVar string
wantWorkers int
}{
{
name: "default in config",
configWorkers: RollbackDefaultNumWorkers,
wantWorkers: RollbackDefaultNumWorkers,
},
{
name: "invalid envvar",
configWorkers: RollbackDefaultNumWorkers,
wantWorkers: RollbackDefaultNumWorkers,
setEnvVar: true,
envVar: "invalid",
},
{
name: "envvar overrides config",
configWorkers: RollbackDefaultNumWorkers,
wantWorkers: 0,
setEnvVar: true,
envVar: "0",
},
{
name: "envvar negative",
configWorkers: RollbackDefaultNumWorkers,
wantWorkers: -1,
setEnvVar: true,
envVar: "-1",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.setEnvVar {
t.Setenv(RollbackWorkersEnvVar, tc.envVar)
}
core := &Core{numRollbackWorkers: tc.configWorkers}
r := &RollbackManager{logger: logger.Named("test"), core: core}
require.Equal(t, tc.wantWorkers, r.numRollbackWorkers())
})
}
}

func TestRollbackManager_Join(t *testing.T) {
m, backend := mockRollback(t)
if len(backend.Paths) > 0 {
Expand Down
4 changes: 4 additions & 0 deletions vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core {
if opts.RollbackPeriod != time.Duration(0) {
conf.RollbackPeriod = opts.RollbackPeriod
}
if opts.NumRollbackWorkers != 0 {
conf.NumRollbackWorkers = opts.NumRollbackWorkers
}

conf.ActivityLogConfig = opts.ActivityLogConfig
testApplyEntBaseConfig(conf, opts)
Expand Down Expand Up @@ -303,6 +306,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo
CredentialBackends: credentialBackends,
DisableMlock: true,
Logger: logger,
NumRollbackWorkers: 10,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
}

Expand Down

0 comments on commit c5e1774

Please sign in to comment.