Skip to content

Commit

Permalink
Add locks to streams (#231)
Browse files Browse the repository at this point in the history
* add locks to streams

* refactor from review
  • Loading branch information
ukclivecox authored May 24, 2022
1 parent 660d11c commit ede182a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 89 deletions.
2 changes: 2 additions & 0 deletions scheduler/pkg/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ func (s *Server) AgentEvent(ctx context.Context, message *pb.ModelEventMessage)
desiredState = store.ModelReplicaStateUnknown
}
logger.Infof("Updating state for model %s to %s", message.ModelName, desiredState.String())
s.store.LockModel(message.ModelName)
defer s.store.UnlockModel(message.ModelName)
err := s.store.UpdateModelState(message.ModelName, message.GetModelVersion(), message.ServerName, int(message.ReplicaIdx), &message.AvailableMemoryBytes, expectedState, desiredState, message.GetMessage())
if err != nil {
logger.WithError(err).Infof("Failed Updating state for model %s", message.ModelName)
Expand Down
68 changes: 36 additions & 32 deletions scheduler/pkg/envoy/processor/incremental.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,52 +88,56 @@ func NewIncrementalProcessor(

func (p *IncrementalProcessor) handlePipelinesEvents(event coordinator.PipelineEventMsg) {
logger := p.logger.WithField("func", "handleExperimentEvents")
pip, err := p.pipelineHandler.GetPipeline(event.PipelineName)
if err != nil {
logger.WithError(err).Errorf("Failed to get pipeline %s", event.PipelineName)
} else {
if pip.Deleted {
err := p.removePipeline(pip)
if err != nil {
logger.WithError(err).Errorf("Failed to remove pipeline %s", pip.Name)
}
go func() {
pip, err := p.pipelineHandler.GetPipeline(event.PipelineName)
if err != nil {
logger.WithError(err).Errorf("Failed to get pipeline %s", event.PipelineName)
} else {
err := p.addPipeline(pip)
if err != nil {
logger.WithError(err).Errorf("Dailed to add pipeline %s", pip.Name)
if pip.Deleted {
err := p.removePipeline(pip)
if err != nil {
logger.WithError(err).Errorf("Failed to remove pipeline %s", pip.Name)
}
} else {
err := p.addPipeline(pip)
if err != nil {
logger.WithError(err).Errorf("Dailed to add pipeline %s", pip.Name)
}
}
}
}
}()
}

func (p *IncrementalProcessor) handleExperimentEvents(event coordinator.ExperimentEventMsg) {
logger := p.logger.WithField("func", "handleExperimentEvents")
logger.Debugf("Received sync for experiment %s", event.String())
exp, err := p.experimentServer.GetExperiment(event.ExperimentName)
if err != nil {
logger.WithError(err).Errorf("Failed to get experiment %s", event.ExperimentName)
} else {
if exp.Deleted {
err := p.removeExperiment(exp)
if err != nil {
logger.WithError(err).Errorf("Failed to get experiment %s", event.ExperimentName)
}
go func() {
exp, err := p.experimentServer.GetExperiment(event.ExperimentName)
if err != nil {
logger.WithError(err).Errorf("Failed to get experiment %s", event.ExperimentName)
} else {
if event.UpdatedExperiment {
err := p.experimentSync(exp)
var err2 error
if exp.Deleted {
err := p.removeExperiment(exp)
if err != nil {
logger.WithError(err).Errorf("Failed to process sync for experiment %s", event.String())
err2 = p.experimentServer.SetStatus(event.ExperimentName, false, err.Error())
} else {
err2 = p.experimentServer.SetStatus(event.ExperimentName, true, "experiment active")
logger.WithError(err).Errorf("Failed to get experiment %s", event.ExperimentName)
}
if err2 != nil {
logger.WithError(err2).Errorf("Failed to set experiment activation")
} else {
if event.UpdatedExperiment {
err := p.experimentSync(exp)
var err2 error
if err != nil {
logger.WithError(err).Errorf("Failed to process sync for experiment %s", event.String())
err2 = p.experimentServer.SetStatus(event.ExperimentName, false, err.Error())
} else {
err2 = p.experimentServer.SetStatus(event.ExperimentName, true, "experiment active")
}
if err2 != nil {
logger.WithError(err2).Errorf("Failed to set experiment activation")
}
}
}
}
}
}()
}

func (p *IncrementalProcessor) handleModelEvents(event coordinator.ModelEventMsg) {
Expand Down
56 changes: 31 additions & 25 deletions scheduler/pkg/kafka/dataflow/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ func (c *ChainerServer) SubscribePipelineUpdates(req *chainer.PipelineSubscripti
}

func (c *ChainerServer) StopSendPipelineEvents() {
c.mu.Lock()
defer c.mu.Unlock()
for _, subscription := range c.streams {
close(subscription.fin)
}
Expand Down Expand Up @@ -231,34 +233,38 @@ func (c *ChainerServer) createPipelineMessage(pv *pipeline.PipelineVersion) *cha

func (c *ChainerServer) handlePipelineEvent(event coordinator.PipelineEventMsg) {
logger := c.logger.WithField("func", "handlePipelineEvent")
pv, err := c.pipelineHandler.GetPipelineVersion(event.PipelineName, event.PipelineVersion, event.UID)
if err != nil {
logger.WithError(err).Errorf("Failed to get pipeline from event %s", event.String())
return
}
logger.Debugf("Received event %s with state %s", event.String(), pv.State.Status.String())
switch pv.State.Status {
case pipeline.PipelineCreate:
msg := c.createPipelineMessage(pv)
for _, subscription := range c.streams {
if err := subscription.stream.Send(msg); err != nil {
logger.WithError(err).Errorf("Failed to send msg for pipeline %s", pv.String())
} else {
if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineCreating, ""); err != nil {
logger.WithError(err).Errorf("Failed to set pipeline %s to creating state", pv.String())
go func() {
pv, err := c.pipelineHandler.GetPipelineVersion(event.PipelineName, event.PipelineVersion, event.UID)
if err != nil {
logger.WithError(err).Errorf("Failed to get pipeline from event %s", event.String())
return
}
logger.Debugf("Received event %s with state %s", event.String(), pv.State.Status.String())
c.mu.Lock()
defer c.mu.Unlock()
switch pv.State.Status {
case pipeline.PipelineCreate:
msg := c.createPipelineMessage(pv)
for _, subscription := range c.streams {
if err := subscription.stream.Send(msg); err != nil {
logger.WithError(err).Errorf("Failed to send msg for pipeline %s", pv.String())
} else {
if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineCreating, ""); err != nil {
logger.WithError(err).Errorf("Failed to set pipeline %s to creating state", pv.String())
}
}
}
}
case pipeline.PipelineTerminate:
msg := c.createPipelineMessage(pv)
for _, subscription := range c.streams {
if err := subscription.stream.Send(msg); err != nil {
logger.WithError(err).Errorf("Failed to send msg for pipeline %s", pv.String())
} else {
if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineTerminating, ""); err != nil {
logger.WithError(err).Errorf("Failed to set pipeline %s to terminate state", pv.String())
case pipeline.PipelineTerminate:
msg := c.createPipelineMessage(pv)
for _, subscription := range c.streams {
if err := subscription.stream.Send(msg); err != nil {
logger.WithError(err).Errorf("Failed to send msg for pipeline %s", pv.String())
} else {
if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineTerminating, ""); err != nil {
logger.WithError(err).Errorf("Failed to set pipeline %s to terminate state", pv.String())
}
}
}
}
}
}()
}
41 changes: 25 additions & 16 deletions scheduler/pkg/server/experiment_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ func (s *SchedulerServer) SubscribeExperimentStatus(req *pb.ExperimentSubscripti

fin := make(chan bool)

s.mu.Lock()
s.experimentEventStream.mu.Lock()
s.experimentEventStream.streams[stream] = &ExperimentSubscription{
name: req.GetSubscriberName(),
stream: stream,
fin: fin,
}
s.mu.Unlock()
s.experimentEventStream.mu.Unlock()

ctx := stream.Context()
// Keep this scope alive because once this scope exits - the stream is closed
Expand All @@ -28,9 +28,9 @@ func (s *SchedulerServer) SubscribeExperimentStatus(req *pb.ExperimentSubscripti
return nil
case <-ctx.Done():
logger.Infof("Stream disconnected %s", req.GetSubscriberName())
s.mu.Lock()
s.experimentEventStream.mu.Lock()
delete(s.experimentEventStream.streams, stream)
s.mu.Unlock()
s.experimentEventStream.mu.Unlock()
return nil
}
}
Expand All @@ -50,18 +50,27 @@ func (s *SchedulerServer) handleExperimentEvents(event coordinator.ExperimentEve
logger := s.logger.WithField("func", "handleExperimentEvents")
logger.Debugf("Received experiment event %s", event.String())
if event.Status != nil {
for stream, subscription := range s.experimentEventStream.streams {
err := stream.Send(&pb.ExperimentStatusResponse{
ExperimentName: event.ExperimentName,
Active: event.Status.Active,
CandidatesReady: event.Status.CandidatesReady,
MirrorReady: event.Status.MirrorReady,
StatusDescription: event.Status.StatusDescription,
KubernetesMeta: asKubernetesMeta(event),
})
if err != nil {
logger.WithError(err).Errorf("Failed to send experiment status event to %s for %s", subscription.name, event.String())
}
go func() {
s.sendExperimentStatus(event)
}()
}
}

// sendExperimentStatus broadcasts the status carried by event to every
// subscriber currently registered in s.experimentEventStream.
//
// Events without a Status are ignored, because every field of the response
// other than the experiment name and Kubernetes metadata is read from it.
// The response is built once, before the stream lock is taken, and the same
// message is sent to each subscriber; it is not modified after construction.
// A failed Send is logged and does not stop delivery to the remaining
// subscribers.
func (s *SchedulerServer) sendExperimentStatus(event coordinator.ExperimentEventMsg) {
	logger := s.logger.WithField("func", "sendExperimentStatus")
	if event.Status == nil {
		// Guard here as well as in the caller so a direct call cannot
		// dereference a nil status.
		return
	}
	// The payload does not depend on the subscriber, so build it outside the
	// loop and outside the critical section.
	status := &pb.ExperimentStatusResponse{
		ExperimentName:    event.ExperimentName,
		Active:            event.Status.Active,
		CandidatesReady:   event.Status.CandidatesReady,
		MirrorReady:       event.Status.MirrorReady,
		StatusDescription: event.Status.StatusDescription,
		KubernetesMeta:    asKubernetesMeta(event),
	}
	// Hold the stream lock for the whole iteration so subscribers cannot be
	// added or removed while the map is being ranged over.
	s.experimentEventStream.mu.Lock()
	defer s.experimentEventStream.mu.Unlock()
	for stream, subscription := range s.experimentEventStream.streams {
		if err := stream.Send(status); err != nil {
			logger.WithError(err).Errorf("Failed to send experiment status event to %s for %s", subscription.name, event.String())
		}
	}
}
Expand Down
19 changes: 15 additions & 4 deletions scheduler/pkg/server/pipeline_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ func (s *SchedulerServer) SubscribePipelineStatus(req *pb.PipelineSubscriptionRe

fin := make(chan bool)

s.mu.Lock()
s.pipelineEventStream.mu.Lock()
s.pipelineEventStream.streams[stream] = &PipelineSubscription{
name: req.GetSubscriberName(),
stream: stream,
fin: fin,
}
s.mu.Unlock()
s.pipelineEventStream.mu.Unlock()

ctx := stream.Context()
// Keep this scope alive because once this scope exits - the stream is closed
Expand All @@ -28,9 +28,9 @@ func (s *SchedulerServer) SubscribePipelineStatus(req *pb.PipelineSubscriptionRe
return nil
case <-ctx.Done():
logger.Infof("Stream disconnected %s", req.GetSubscriberName())
s.mu.Lock()
s.pipelineEventStream.mu.Lock()
delete(s.pipelineEventStream.streams, stream)
s.mu.Unlock()
s.pipelineEventStream.mu.Unlock()
return nil
}
}
Expand All @@ -39,6 +39,13 @@ func (s *SchedulerServer) SubscribePipelineStatus(req *pb.PipelineSubscriptionRe
func (s *SchedulerServer) handlePipelineEvents(event coordinator.PipelineEventMsg) {
logger := s.logger.WithField("func", "handlePipelineEvents")
logger.Debugf("Received pipeline event %s", event.String())
go func() {
s.sendPipelineEvents(event)
}()
}

func (s *SchedulerServer) sendPipelineEvents(event coordinator.PipelineEventMsg) {
logger := s.logger.WithField("func", "sendPipelineEvents")
pv, err := s.pipelineHandler.GetPipelineVersion(event.PipelineName, event.PipelineVersion, event.UID)
if err != nil {
logger.WithError(err).Errorf("Failed to get pipeline from event %s", event.String())
Expand All @@ -51,6 +58,8 @@ func (s *SchedulerServer) handlePipelineEvents(event coordinator.PipelineEventMs
PipelineName: pv.Name,
Versions: pipelineVersions,
}
s.pipelineEventStream.mu.Lock()
defer s.pipelineEventStream.mu.Unlock()
for stream, subscription := range s.pipelineEventStream.streams {
if err := stream.Send(status); err != nil {
logger.WithError(err).Errorf("Failed to send pipeline status event to %s for %s", subscription.name, event.String())
Expand All @@ -59,6 +68,8 @@ func (s *SchedulerServer) handlePipelineEvents(event coordinator.PipelineEventMs
}

func (s *SchedulerServer) StopSendPipelineEvents() {
s.pipelineEventStream.mu.Lock()
defer s.pipelineEventStream.mu.Unlock()
for _, subscription := range s.pipelineEventStream.streams {
close(subscription.fin)
}
Expand Down
8 changes: 4 additions & 4 deletions scheduler/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,29 @@ type SchedulerServer struct {
experimentServer experiment.ExperimentServer
pipelineHandler pipeline.PipelineHandler
scheduler scheduler2.Scheduler
mu sync.RWMutex
modelEventStream ModelEventStream
serverEventStream ServerEventStream
experimentEventStream ExperimentEventStream
pipelineEventStream PipelineEventStream
}

// ModelEventStream holds the registered model-status subscriber streams,
// keyed by the gRPC server stream, alongside a mutex.
// NOTE(review): presumably mu guards streams, mirroring the experiment and
// pipeline event streams — confirm against the model-status use sites.
type ModelEventStream struct {
mu sync.Mutex
streams map[pb.Scheduler_SubscribeModelStatusServer]*ModelSubscription
}

// ServerEventStream holds the registered server-status subscriber streams,
// keyed by the gRPC server stream, alongside a mutex.
// NOTE(review): presumably mu guards streams, mirroring the experiment and
// pipeline event streams — confirm against the server-status use sites.
type ServerEventStream struct {
mu sync.Mutex
streams map[pb.Scheduler_SubscribeServerStatusServer]*ServerSubscription
}

// ExperimentEventStream holds the registered experiment-status subscriber
// streams, keyed by the gRPC server stream.
type ExperimentEventStream struct {
// mu is locked around inserts into, deletes from, and iteration over streams.
mu sync.Mutex
// streams maps each subscriber's stream to its subscription record.
streams map[pb.Scheduler_SubscribeExperimentStatusServer]*ExperimentSubscription
}

// PipelineEventStream holds the registered pipeline-status subscriber
// streams, keyed by the gRPC server stream.
type PipelineEventStream struct {
// mu is locked around inserts into, deletes from, and iteration over streams.
mu sync.Mutex
// streams maps each subscriber's stream to its subscription record.
streams map[pb.Scheduler_SubscribePipelineStatusServer]*PipelineSubscription
}

Expand Down Expand Up @@ -520,9 +523,6 @@ func (s *SchedulerServer) PipelineStatus(
logger.Infof("received status request from %s", req.SubscriberName)

if req.Name == nil {
// All pipelines requested
s.mu.RLock()
defer s.mu.RUnlock()

pipelines, err := s.pipelineHandler.GetPipelines()
if err != nil {
Expand Down
Loading

0 comments on commit ede182a

Please sign in to comment.