Skip to content

Commit

Permalink
Fix concurrent map access in executor
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-valkov committed Feb 17, 2021
1 parent 182e4c7 commit 01c44be
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions executor/predictor/predictor_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,26 @@ var (
)

type PredictorProcess struct {
Ctx context.Context
Client client.SeldonApiClient
Log logr.Logger
ServerUrl *url.URL
Namespace string
Meta *payload.MetaData
Routing map[string]int32
Ctx context.Context
Client client.SeldonApiClient
Log logr.Logger
ServerUrl *url.URL
Namespace string
Meta *payload.MetaData
Routing map[string]int32
RoutingMutex *sync.RWMutex
}

func NewPredictorProcess(context context.Context, client client.SeldonApiClient, log logr.Logger, serverUrl *url.URL, namespace string, meta map[string][]string) PredictorProcess {
return PredictorProcess{
Ctx: context,
Client: client,
Log: log,
ServerUrl: serverUrl,
Namespace: namespace,
Meta: payload.NewFromMap(meta),
Routing: make(map[string]int32),
Ctx: context,
Client: client,
Log: log,
ServerUrl: serverUrl,
Namespace: namespace,
Meta: payload.NewFromMap(meta),
Routing: make(map[string]int32),
RoutingMutex: &sync.RWMutex{},
}
}

Expand Down Expand Up @@ -89,14 +91,18 @@ func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.S
if err != nil {
return nil, err
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Predict(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else if callTransformInput {
msg, err := p.Client.Chain(p.Ctx, node.Name, msg)
if err != nil {
return nil, err
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.TransformInput(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else {
return msg, nil
Expand Down Expand Up @@ -189,7 +195,9 @@ func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, msg []payload.Seld
}

if callClient {
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Combine(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else {
return msg[0], nil
Expand All @@ -216,20 +224,26 @@ func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload.
}(i, nodeChild, msg)
}
wg.Wait()
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
for i, err := range errs {
if err != nil {
return cmsgs[i], err
}
}
} else if route == -2 {
//Abort and return request
p.RoutingMutex.Lock()
p.Routing[node.Name] = -2
p.RoutingMutex.Unlock()
return msg, nil
} else {
cmsgs = make([]payload.SeldonPayload, 1)
cmsgs[0], err = p.Predict(&node.Children[route], msg)
p.RoutingMutex.Lock()
p.Routing[node.Name] = int32(route)
p.RoutingMutex.Unlock()
if err != nil {
return cmsgs[0], err
}
Expand Down

0 comments on commit 01c44be

Please sign in to comment.