diff --git a/vault/rollback.go b/vault/rollback.go index 9148335289b1..9dad02acdf7d 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -59,6 +59,8 @@ type RollbackManager struct { type rollbackState struct { lastError error sync.WaitGroup + cancelLockGrabCtx context.Context + cancelLockGrabCtxCancel context.CancelFunc } // NewRollbackManager is used to create a new rollback manager @@ -132,24 +134,31 @@ func (m *RollbackManager) triggerRollbacks() { } fullPath := e.namespace.Path + path - m.inflightLock.RLock() - _, ok := m.inflight[fullPath] - m.inflightLock.RUnlock() - if !ok { - m.startRollback(ctx, fullPath, true) - } + // Start a rollback if necessary + m.startOrLookupRollback(ctx, fullPath, true) } } -// startRollback is used to start an async rollback attempt. +// startOrLookupRollback is used to start an async rollback attempt. // This must be called with the inflightLock held. -func (m *RollbackManager) startRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { - rs := &rollbackState{} - rs.Add(1) - m.inflightAll.Add(1) +func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { m.inflightLock.Lock() + defer m.inflightLock.Unlock() + rsInflight, ok := m.inflight[fullPath] + if ok { + return rsInflight + } + + cancelCtx, cancelFunc := context.WithCancel(context.Background()) + rs := &rollbackState{ + cancelLockGrabCtx: cancelCtx, + cancelLockGrabCtxCancel: cancelFunc, + } + + // If no inflight rollback is already running, kick one off m.inflight[fullPath] = rs - m.inflightLock.Unlock() + rs.Add(1) + m.inflightAll.Add(1) go m.attemptRollback(ctx, fullPath, rs, grabStatelock) return rs } @@ -184,17 +193,39 @@ func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, Path: ns.TrimmedPath(fullPath), } + releaseLock := true if grabStatelock { + doneCh := make(chan struct{}) + defer close(doneCh) + + stopCh := make(chan struct{}) + go func() { + defer close(stopCh) + + select { + case <-m.shutdownCh: + case <-rs.cancelLockGrabCtx.Done(): + case <-doneCh: + } + }() + // Grab the statelock or stop - if stopped := grabLockOrStop(m.core.stateLock.RLock, m.core.stateLock.RUnlock, m.shutdownCh); stopped { - return errors.New("rollback shutting down") + if stopped := grabLockOrStop(m.core.stateLock.RLock, m.core.stateLock.RUnlock, stopCh); stopped { + // If we stopped due to shutdown, return. Otherwise another thread + // is holding the lock for us, continue on. + select { + case <-m.shutdownCh: + return errors.New("rollback shutting down") + default: + releaseLock = false + } } } var cancelFunc context.CancelFunc ctx, cancelFunc = context.WithTimeout(ctx, DefaultMaxRequestDuration) _, err = m.router.Route(ctx, req) - if grabStatelock { + if grabStatelock && releaseLock { m.core.stateLock.RUnlock() } cancelFunc() @@ -216,7 +247,8 @@ func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, // Rollback is used to trigger an immediate rollback of the path, // or to join an existing rollback operation if in flight. Caller should have -// core's statelock held +// core's statelock held (write OR read). If an already inflight rollback is +// happening this function will simply wait for it to complete func (m *RollbackManager) Rollback(ctx context.Context, path string) error { ns, err := namespace.FromContext(ctx) if err != nil { @@ -224,15 +256,18 @@ func (m *RollbackManager) Rollback(ctx context.Context, path string) error { } fullPath := ns.Path + path - // Check for an existing attempt and start one if none - m.inflightLock.RLock() - rs, ok := m.inflight[fullPath] - m.inflightLock.RUnlock() - if !ok { - rs = m.startRollback(ctx, fullPath, false) - } + // Check for an existing attempt or start one if none + rs := m.startOrLookupRollback(ctx, fullPath, false) + + // Since we have the statelock held, tell any inflight rollback to give up + // trying to aquire it. This will prevent deadlocks in the case where we + // have the write lock. In the case where it was waiting to grab + // a read lock it will then simply continue with the rollback + // operation under the protection of our write lock. + rs.cancelLockGrabCtxCancel() - // Wait for the attempt to finish + // It's safe to do this, since the other thread either already has the lock + // held, or we just canceled it above. rs.Wait() // Return the last error