diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go index 12c2a81e34..0af1295e75 100644 --- a/manager/dispatcher/dispatcher.go +++ b/manager/dispatcher/dispatcher.go @@ -125,8 +125,18 @@ type clusterUpdate struct { // Dispatcher is responsible for dispatching tasks and tracking agent health. type Dispatcher struct { - mu sync.Mutex - wg sync.WaitGroup + // mu is a lock to provide mutually exclusive access to dispatcher fields + // e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc. + // Also used to make atomic the setting of the shutdown flag to 'true' and the + // Add() operation on the shutdownWait to make sure that stop() waits for + // all operations to finish and disallow new operations from starting. + mu sync.Mutex + // shutdown is a flag to indicate shutdown and prevent new operations on the dispatcher. + // Set by calling Stop(). + shutdown bool + // shutdownWait is used by stop() to wait for existing operations to finish. + shutdownWait sync.WaitGroup + nodes *nodeStore store *store.MemoryStore lastSeenManagers []*api.WeightedPeer @@ -195,6 +205,12 @@ func getWeightedPeers(cluster Cluster) []*api.WeightedPeer { // Run runs dispatcher tasks which should be run on leader dispatcher. // Dispatcher can be stopped with cancelling ctx or calling Stop(). func (d *Dispatcher) Run(ctx context.Context) error { + // The dispatcher object is not recreated when a node re-gains + // leadership. We need to reset to default state. 
+ d.mu.Lock() + d.shutdown = false + d.mu.Unlock() + ctx = log.WithModule(ctx, "dispatcher") log.G(ctx).Info("dispatcher starting") @@ -249,8 +265,8 @@ func (d *Dispatcher) Run(ctx context.Context) error { defer cancel() d.ctx, d.cancel = context.WithCancel(ctx) ctx = d.ctx - d.wg.Add(1) - defer d.wg.Done() + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() d.mu.Unlock() publishManagers := func(peers []*api.Peer) { @@ -313,11 +329,19 @@ func (d *Dispatcher) Stop() error { return errors.New("dispatcher is already stopped") } - log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop") - log.Info("dispatcher stopping") + // Set shutdown to true. + // This will prevent RPCs that start after stop() is called + // from making progress and essentially puts the dispatcher in drain. + d.shutdown = true + + // Cancel dispatcher context. + // This should also close the streams in Tasks(), Assignments(). d.cancel() d.mu.Unlock() + // Wait for the RPCs that are in-progress to finish. + d.shutdownWait.Wait() + d.nodes.Clean() d.processUpdatesLock.Lock() @@ -328,9 +352,6 @@ func (d *Dispatcher) Stop() error { d.processUpdatesLock.Unlock() d.clusterUpdateQueue.Close() - - d.wg.Wait() - return nil } @@ -478,13 +499,19 @@ func nodeIPFromContext(ctx context.Context) (string, error) { // register is used for registration of node with particular dispatcher. 
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) { - logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") - // prevent register until we're ready to accept it - dctx, err := d.isRunningLocked() - if err != nil { - return "", err + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return "", status.Errorf(codes.Aborted, "dispatcher is stopped") } + dctx := d.ctx + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + + logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") + if err := d.nodes.CheckRateLimit(nodeID); err != nil { return "", err } @@ -532,6 +559,16 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a // UpdateTaskStatus updates status of task. Node should send such updates // on every status change of its tasks. func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) { + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return nil, status.Errorf(codes.Aborted, "dispatcher is stopped") + } + dctx := d.ctx + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -547,11 +584,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat } log := log.G(ctx).WithFields(fields) - dctx, err := d.isRunningLocked() - if err != nil { - return nil, err - } - if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil { return nil, err } @@ -774,6 +806,16 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe defer cancel() for { + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return status.Errorf(codes.Aborted, "dispatcher is stopped") + } + + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil { return err } @@ -919,6 
+961,16 @@ func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatche } for { + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return nil + } + + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + // Check for session expiration if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil { return err @@ -1103,6 +1155,15 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes // Node should send new heartbeat earlier than now + TTL, otherwise it will // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) { + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return nil, status.Errorf(codes.Aborted, "dispatcher is stopped") + } + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -1232,6 +1293,15 @@ func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Sessio } for { + d.mu.Lock() + if d.shutdown { + d.mu.Unlock() + return status.Errorf(codes.Aborted, "dispatcher is stopped") + } + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + d.mu.Unlock() + // After each message send, we need to check the nodes sessionID hasn't // changed. If it has, we will shut down the stream and make the node // re-register.