From 5d7f8da1bdc9e702ab3989963e95c8e0d1278cd5 Mon Sep 17 00:00:00 2001 From: Anshul Pundir Date: Fri, 30 Mar 2018 14:51:13 -0700 Subject: [PATCH] [manager/dispatcher] Use read-write lock for dispatcher context. Signed-off-by: Anshul Pundir --- manager/dispatcher/dispatcher.go | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go index c6530fad76..74db5c767f 100644 --- a/manager/dispatcher/dispatcher.go +++ b/manager/dispatcher/dispatcher.go @@ -133,7 +133,8 @@ type Dispatcher struct { // has finished initializing the dispatcher. wg sync.WaitGroup // This RWMutex synchronizes RPC handlers and the dispatcher stop(). - // The RPC handlers use the read lock while stop() uses the write lock + // Used to serialize read-write access to the dispatcher context. + // Also, the RPC handlers use the read lock while stop() uses the write lock // and acts as a barrier to shutdown. rpcRW sync.RWMutex nodes *nodeStore @@ -265,12 +266,15 @@ func (d *Dispatcher) Run(ctx context.Context) error { d.lastSeenManagers = getWeightedPeers(d.cluster) defer cancel() - d.ctx, d.cancel = context.WithCancel(ctx) - ctx = d.ctx d.wg.Add(1) defer d.wg.Done() d.mu.Unlock() + d.rpcRW.Lock() + d.ctx, d.cancel = context.WithCancel(ctx) + ctx = d.ctx + d.rpcRW.Unlock() + publishManagers := func(peers []*api.Peer) { var mgrs []*api.WeightedPeer for _, p := range peers { @@ -333,7 +337,6 @@ func (d *Dispatcher) Stop() error { log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop") log.Info("dispatcher stopping") - d.cancel() d.mu.Unlock() // The active nodes list can be cleaned out only when all @@ -341,6 +344,7 @@ func (d *Dispatcher) Stop() error { // RPCs that start after rpcRW.Unlock() should find the context // cancelled and should fail organically. d.rpcRW.Lock() + d.cancel() d.nodes.Clean() d.downNodes.Clean() d.rpcRW.Unlock() @@ -364,14 +368,14 @@ func (d *Dispatcher) Stop() error { return nil } -func (d *Dispatcher) isRunningLocked() (context.Context, error) { - d.mu.Lock() +// context returns the dispatcher context. +func (d *Dispatcher) context() (context.Context, error) { + d.rpcRW.RLock() + defer d.rpcRW.RUnlock() if !d.isRunning() { - d.mu.Unlock() return nil, status.Errorf(codes.Aborted, "dispatcher is stopped") } ctx := d.ctx - d.mu.Unlock() return ctx, nil } @@ -510,7 +514,7 @@ func nodeIPFromContext(ctx context.Context) (string, error) { func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) { logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") // prevent register until we're ready to accept it - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return "", err } @@ -565,7 +569,7 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return nil, err } @@ -759,7 +763,7 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -885,7 +889,7 @@ func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatche d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -1080,7 +1084,7 @@ func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error { func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error { logLocal := log.G(d.ctx).WithField("method", "(*Dispatcher).markNodeNotReady") - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -1190,7 +1194,7 @@ func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Sessio d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err }