Skip to content

Commit

Permalink
Merge pull request #1684 from cliveseldon/1611_tf_protocol_path
Browse files Browse the repository at this point in the history
Allow non-model specific predict for Tensorflow protocol
  • Loading branch information
ukclivecox authored Apr 14, 2020
2 parents c949408 + 7de281f commit be4c2c7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
4 changes: 4 additions & 0 deletions doc/source/graph/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ For Seldon graphs the protocol will work as expected for single model graphs for
* Sending the response from the first as a request to the second. This will be done automatically when you define a chain of models as a Seldon graph. It is up to the user to ensure the response of each chained model can be fed as a request to the next in the chain.
* Only Predict calls can be handled in multiple model chaining.


General considerations:

* Seldon components marked as MODEL, INPUT_TRANSFORMER and OUTPUT_TRANSFORMER will allow a PredictionService Predict method to be called.
* GetModelStatus for any model in the graph is available.
* GetModelMetadata for any model in the graph is available.
* Combining and Routing with the Tensorflow protocol is not presently supported.
* `status` and `metadata` calls can be made for any model in the graph.
* a non-standard Seldon extension is available to call predict on the graph as a whole: `/v1/models/:predict`.
* The name of the model in the `graph` section of the SeldonDeployment spec must match the name of the model loaded onto the Tensorflow Server.


29 changes: 12 additions & 17 deletions executor/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ func (r *SeldonRestApi) respondWithSuccess(w http.ResponseWriter, code int, payl
}

func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, payload payload.SeldonPayload, err error) {
w.Header().Set("Content-Type", payload.GetContentType())

if serr, ok := err.(*httpStatusError); ok {
w.WriteHeader(serr.StatusCode)
Expand All @@ -100,12 +99,14 @@ func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, payload payload.
}

if payload != nil && payload.GetPayload() != nil {
w.Header().Set("Content-Type", payload.GetContentType())
err := r.Client.Marshall(w, payload)
if err != nil {
r.Log.Error(err, "Failed to write response")
}
} else {
errPayload := r.Client.CreateErrorPayload(err)
w.Header().Set("Content-Type", errPayload.GetContentType())
err = r.Client.Marshall(w, errPayload)
if err != nil {
r.Log.Error(err, "Failed to write error payload")
Expand Down Expand Up @@ -149,6 +150,7 @@ func (r *SeldonRestApi) Initialise() {

case api.ProtocolTensorflow:
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
r.Router.NewRoute().Path("/v1/models/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions)) // Nonstandard path - Seldon extension
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/metadata").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
}
Expand All @@ -163,8 +165,6 @@ type CloudeventHeaderMiddleware struct {
func (h *CloudeventHeaderMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Checking if request is cloudevent based on specname being present
fmt.Println(r.Header)
fmt.Println(w.Header())
if _, ok := r.Header[CLOUDEVENTS_HEADER_SPECVERSION_NAME]; ok {
puid := r.Header.Get(payload.SeldonPUIDHeader)
w.Header().Set(CLOUDEVENTS_HEADER_ID_NAME, puid)
Expand Down Expand Up @@ -207,16 +207,6 @@ func (r *SeldonRestApi) alive(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK)
}

// getGraphNodeForModelName resolves the model named in the request's
// {ModelHttpPathVariable} path segment to its node in the predictive unit
// graph, or returns an error when no node with that name exists.
func getGraphNodeForModelName(req *http.Request, graph *v1.PredictiveUnit) (*v1.PredictiveUnit, error) {
	modelName := mux.Vars(req)[ModelHttpPathVariable]
	graphNode := v1.GetPredictiveUnit(graph, modelName)
	if graphNode == nil {
		return nil, fmt.Errorf("Failed to find model %s", modelName)
	}
	return graphNode, nil
}

func setupTracing(ctx context.Context, req *http.Request, spanName string) (context.Context, opentracing.Span) {
tracer := opentracing.GlobalTracer()
spanCtx, _ := tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
Expand Down Expand Up @@ -299,10 +289,15 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) {

var graphNode *v1.PredictiveUnit
if r.Protocol == api.ProtocolTensorflow {
graphNode, err = getGraphNodeForModelName(req, r.predictor.Graph)
if err != nil {
r.respondWithError(w, nil, err)
return
vars := mux.Vars(req)
modelName := vars[ModelHttpPathVariable]
if modelName != "" {
if graphNode = v1.GetPredictiveUnit(r.predictor.Graph, modelName); graphNode == nil {
r.respondWithError(w, nil, fmt.Errorf("Failed to find model %s", modelName))
return
}
} else {
graphNode = r.predictor.Graph
}
} else {
graphNode = r.predictor.Graph
Expand Down
60 changes: 60 additions & 0 deletions executor/api/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,63 @@ func TestPredictErrorWithServer(t *testing.T) {
g.Expect(err).Should(BeNil())
g.Expect(string(b)).To(Equal(errorPredictResponse))
}

// TestTensorflowModel verifies the non-standard Seldon extension route
// /v1/models/:predict (no model name) under the Tensorflow protocol: the
// request should be routed to the graph root and, with the echoing test
// client, the response body should equal the request body.
func TestTensorflowModel(t *testing.T) {
	t.Logf("Started")
	g := NewGomegaWithT(t)

	modelType := v1.MODEL
	spec := v1.PredictorSpec{
		Name: "p",
		Graph: &v1.PredictiveUnit{
			Name: "mymodel",
			Type: &modelType,
			Endpoint: &v1.Endpoint{
				ServiceHost: "foo",
				ServicePort: 9000,
				Type:        v1.REST,
			},
		},
	}

	// Named serverURL to avoid shadowing the net/url package.
	serverURL, _ := url.Parse("http://localhost")
	restAPI := NewServerRestApi(&spec, &test.SeldonMessageTestClient{}, false, serverURL, "default", api.ProtocolTensorflow, "test", "/metrics")
	restAPI.Initialise()

	body := ` {"instances":[[1,2,3]]}`
	request, _ := http.NewRequest("POST", "/v1/models/:predict", strings.NewReader(body))
	recorder := httptest.NewRecorder()
	restAPI.Router.ServeHTTP(recorder, request)

	// 200 with the original payload echoed back proves the model-less
	// predict route matched and dispatched to the graph.
	g.Expect(recorder.Code).To(Equal(200))
	g.Expect(recorder.Body.String()).To(Equal(body))
}

// TestTensorflowModelBadModelName verifies that a Tensorflow-protocol
// predict call naming a model absent from the graph
// (/v1/models/xyz/:predict) is rejected with a 500 whose error payload
// carries the test client's content type.
func TestTensorflowModelBadModelName(t *testing.T) {
	t.Logf("Started")
	g := NewGomegaWithT(t)

	modelType := v1.MODEL
	spec := v1.PredictorSpec{
		Name: "p",
		Graph: &v1.PredictiveUnit{
			Name: "mymodel",
			Type: &modelType,
			Endpoint: &v1.Endpoint{
				ServiceHost: "foo",
				ServicePort: 9000,
				Type:        v1.REST,
			},
		},
	}

	// Named serverURL to avoid shadowing the net/url package.
	serverURL, _ := url.Parse("http://localhost")
	restAPI := NewServerRestApi(&spec, &test.SeldonMessageTestClient{}, false, serverURL, "default", api.ProtocolTensorflow, "test", "/metrics")
	restAPI.Initialise()

	body := ` {"instances":[[1,2,3]]}`
	// "xyz" matches no node in the graph, so the lookup must fail.
	request, _ := http.NewRequest("POST", "/v1/models/xyz/:predict", strings.NewReader(body))
	recorder := httptest.NewRecorder()
	restAPI.Router.ServeHTTP(recorder, request)

	g.Expect(recorder.Code).To(Equal(500))
	g.Expect(recorder.Header().Get("Content-Type")).To(Equal(test.TestContentType))
}
9 changes: 4 additions & 5 deletions executor/api/test/seldonmessage_test_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/seldonio/seldon-core/executor/api/payload"
"github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1"
"io"
"net/http"
)

type SeldonMessageTestClient struct {
Expand All @@ -19,6 +18,7 @@ type SeldonMessageTestClient struct {
const (
TestClientStatusResponse = `{"status":"ok"}`
TestClientMetadataResponse = `{"metadata":{"name":"mymodel"}}`
TestContentType = "application/json"
)

func (s SeldonMessageTestClient) Status(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) {
Expand All @@ -34,7 +34,7 @@ func (s SeldonMessageTestClient) Chain(ctx context.Context, modelName string, ms
}

// Unmarshall wraps raw request bytes in a BytesPayload with a JSON
// content type, without inspecting or decoding the bytes.
//
// NOTE(review): the two reqPayload assignments below are the deleted and
// added lines of a diff hunk pasted together; only the TestContentType
// variant is the current code — as written this would not compile
// (redeclaration). Restore the single post-change line when cleaning up.
func (s SeldonMessageTestClient) Unmarshall(msg []byte) (payload.SeldonPayload, error) {
reqPayload := payload.BytesPayload{Msg: msg, ContentType: "application/json"}
reqPayload := payload.BytesPayload{Msg: msg, ContentType: TestContentType}
return &reqPayload, nil
}

Expand All @@ -44,9 +44,8 @@ func (s SeldonMessageTestClient) Marshall(out io.Writer, msg payload.SeldonPaylo
}

// CreateErrorPayload converts an error into a payload the server can
// marshal back to the caller.
//
// NOTE(review): this span interleaves the pre-change body (ProtoPayload
// wrapping a SeldonMessage with an HTTP 500 status) and the post-change
// body (BytesPayload carrying err.Error() with TestContentType) from a
// diff hunk; only the BytesPayload version is the current code — as
// written this would not compile. Keep just the last two statements
// before the closing brace when cleaning up.
func (s SeldonMessageTestClient) CreateErrorPayload(err error) payload.SeldonPayload {
respFailed := proto.SeldonMessage{Status: &proto.Status{Code: http.StatusInternalServerError, Info: err.Error()}}
res := payload.ProtoPayload{Msg: &respFailed}
return &res
respFailed := payload.BytesPayload{Msg: []byte(err.Error()), ContentType: TestContentType}
return &respFailed
}

func (s SeldonMessageTestClient) Predict(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) {
Expand Down

0 comments on commit be4c2c7

Please sign in to comment.