diff --git a/executor/predictor/predictor_process.go b/executor/predictor/predictor_process.go index e8c2e1e551..79146a59d9 100644 --- a/executor/predictor/predictor_process.go +++ b/executor/predictor/predictor_process.go @@ -30,24 +30,26 @@ var ( ) type PredictorProcess struct { - Ctx context.Context - Client client.SeldonApiClient - Log logr.Logger - ServerUrl *url.URL - Namespace string - Meta *payload.MetaData - Routing map[string]int32 + Ctx context.Context + Client client.SeldonApiClient + Log logr.Logger + ServerUrl *url.URL + Namespace string + Meta *payload.MetaData + Routing map[string]int32 + RoutingMutex *sync.RWMutex } func NewPredictorProcess(context context.Context, client client.SeldonApiClient, log logr.Logger, serverUrl *url.URL, namespace string, meta map[string][]string) PredictorProcess { return PredictorProcess{ - Ctx: context, - Client: client, - Log: log, - ServerUrl: serverUrl, - Namespace: namespace, - Meta: payload.NewFromMap(meta), - Routing: make(map[string]int32), + Ctx: context, + Client: client, + Log: log, + ServerUrl: serverUrl, + Namespace: namespace, + Meta: payload.NewFromMap(meta), + Routing: make(map[string]int32), + RoutingMutex: &sync.RWMutex{}, } } @@ -89,14 +91,18 @@ func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.S if err != nil { return nil, err } + p.RoutingMutex.Lock() p.Routing[node.Name] = -1 + p.RoutingMutex.Unlock() return p.Client.Predict(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta) } else if callTransformInput { msg, err := p.Client.Chain(p.Ctx, node.Name, msg) if err != nil { return nil, err } + p.RoutingMutex.Lock() p.Routing[node.Name] = -1 + p.RoutingMutex.Unlock() return p.Client.TransformInput(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta) } else { return msg, nil @@ -189,7 +195,9 @@ func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, msg []payload.Seld } if callClient { + p.RoutingMutex.Lock() p.Routing[node.Name] = -1 + p.RoutingMutex.Unlock() return p.Client.Combine(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta) } else { return msg[0], nil @@ -216,7 +224,9 @@ func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload. }(i, nodeChild, msg) } wg.Wait() + p.RoutingMutex.Lock() p.Routing[node.Name] = -1 + p.RoutingMutex.Unlock() for i, err := range errs { if err != nil { return cmsgs[i], err @@ -224,12 +234,16 @@ func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload. } } else if route == -2 { //Abort and return request + p.RoutingMutex.Lock() p.Routing[node.Name] = -2 + p.RoutingMutex.Unlock() return msg, nil } else { cmsgs = make([]payload.SeldonPayload, 1) cmsgs[0], err = p.Predict(&node.Children[route], msg) + p.RoutingMutex.Lock() p.Routing[node.Name] = int32(route) + p.RoutingMutex.Unlock() if err != nil { return cmsgs[0], err }