diff --git a/doc/source/graph/protocols.md b/doc/source/graph/protocols.md
index 6f82a3b8bd..f7687aaa72 100644
--- a/doc/source/graph/protocols.md
+++ b/doc/source/graph/protocols.md
@@ -27,11 +27,15 @@ For Seldon graphs the protocol will work as expected for single model graphs for
 
  * Sending the response from the first as a request to the second. This will be done automatically when you defined a chain of models as a Seldon graph. It is up to the user to ensure the response of each changed model can be fed a request to the next in the chain.
  * Only Predict calls can be handled in multiple model chaining.
 
+General considerations:
 
  * Seldon components marked as MODELS, INPUT_TRANSFORMER and OUTPUT_TRANSFORMERS will allow a PredictionService Predict method to be called.
  * GetModelStatus for any model in the graph is available.
  * GetModelMetadata for any model in the graph is available.
  * Combining and Routing with the Tensorflow protocol is not presently supported.
+ * `status` and `metadata` calls can be made for any model in the graph.
+ * A non-standard Seldon extension is available to call predict on the graph as a whole: `/v1/models/:predict`.
+ * The name of the model in the `graph` section of the SeldonDeployment spec must match the name of the model loaded onto the Tensorflow Server.
 
 
diff --git a/executor/api/rest/server.go b/executor/api/rest/server.go
index 53e681ac54..4f0315673f 100644
--- a/executor/api/rest/server.go
+++ b/executor/api/rest/server.go
@@ -91,23 +91,26 @@ func (r *SeldonRestApi) respondWithSuccess(w http.ResponseWriter, code int, payload payload.SeldonPayload) {
 }
 
 func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, payload payload.SeldonPayload, err error) {
-	w.Header().Set("Content-Type", payload.GetContentType())
+	// Resolve the payload that will be written first, so that its
+	// Content-Type can be set before WriteHeader is called: net/http
+	// silently ignores headers set after the status line is written.
+	isErrPayload := payload == nil || payload.GetPayload() == nil
+	if isErrPayload {
+		payload = r.Client.CreateErrorPayload(err)
+	}
+	w.Header().Set("Content-Type", payload.GetContentType())
+
 	if serr, ok := err.(*httpStatusError); ok {
 		w.WriteHeader(serr.StatusCode)
 	} else {
 		w.WriteHeader(http.StatusInternalServerError)
 	}
 
-	if payload != nil && payload.GetPayload() != nil {
-		err := r.Client.Marshall(w, payload)
-		if err != nil {
-			r.Log.Error(err, "Failed to write response")
-		}
-	} else {
-		errPayload := r.Client.CreateErrorPayload(err)
-		err = r.Client.Marshall(w, errPayload)
-		if err != nil {
+	if err := r.Client.Marshall(w, payload); err != nil {
+		if isErrPayload {
 			r.Log.Error(err, "Failed to write error payload")
+		} else {
+			r.Log.Error(err, "Failed to write response")
 		}
 	}
 }
@@ -149,6 +152,7 @@ func (r *SeldonRestApi) Initialise() {
 
 	case api.ProtocolTensorflow:
 		r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
+		r.Router.NewRoute().Path("/v1/models/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions)) // Nonstandard path - Seldon extension
 		r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
 		r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/metadata").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
 	}
@@ -163,8 +167,6 @@ type CloudeventHeaderMiddleware struct {
 func (h *CloudeventHeaderMiddleware) Middleware(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		// Checking if request is cloudevent based on specname being present
-		fmt.Println(r.Header)
-		fmt.Println(w.Header())
 		if _, ok := r.Header[CLOUDEVENTS_HEADER_SPECVERSION_NAME]; ok {
 			puid := r.Header.Get(payload.SeldonPUIDHeader)
 			w.Header().Set(CLOUDEVENTS_HEADER_ID_NAME, puid)
@@ -207,16 +209,6 @@ func (r *SeldonRestApi) alive(w http.ResponseWriter, req *http.Request) {
 	w.WriteHeader(http.StatusOK)
 }
 
-func getGraphNodeForModelName(req *http.Request, graph *v1.PredictiveUnit) (*v1.PredictiveUnit, error) {
-	vars := mux.Vars(req)
-	modelName := vars[ModelHttpPathVariable]
-	if graphNode := v1.GetPredictiveUnit(graph, modelName); graphNode == nil {
-		return nil, fmt.Errorf("Failed to find model %s", modelName)
-	} else {
-		return graphNode, nil
-	}
-}
-
 func setupTracing(ctx context.Context, req *http.Request, spanName string) (context.Context, opentracing.Span) {
 	tracer := opentracing.GlobalTracer()
 	spanCtx, _ := tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
@@ -299,10 +291,15 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) {
 
 	var graphNode *v1.PredictiveUnit
 	if r.Protocol == api.ProtocolTensorflow {
-		graphNode, err = getGraphNodeForModelName(req, r.predictor.Graph)
-		if err != nil {
-			r.respondWithError(w, nil, err)
-			return
+		vars := mux.Vars(req)
+		modelName := vars[ModelHttpPathVariable]
+		if modelName != "" {
+			if graphNode = v1.GetPredictiveUnit(r.predictor.Graph, modelName); graphNode == nil {
+				r.respondWithError(w, nil, fmt.Errorf("Failed to find model %s", modelName))
+				return
+			}
+		} else {
+			graphNode = r.predictor.Graph
 		}
 	} else {
 		graphNode = r.predictor.Graph
diff --git a/executor/api/rest/server_test.go b/executor/api/rest/server_test.go
index 0484db7c5c..9e364814b4 100644
--- a/executor/api/rest/server_test.go
+++ b/executor/api/rest/server_test.go
@@ -476,3 +476,63 @@ func TestPredictErrorWithServer(t *testing.T) {
 	g.Expect(err).Should(BeNil())
 	g.Expect(string(b)).To(Equal(errorPredictResponse))
 }
+
+func TestTensorflowModel(t *testing.T) {
+	t.Logf("Started")
+	g := NewGomegaWithT(t)
+
+	model := v1.MODEL
+	p := v1.PredictorSpec{
+		Name: "p",
+		Graph: &v1.PredictiveUnit{
+			Name: "mymodel",
+			Type: &model,
+			Endpoint: &v1.Endpoint{
+				ServiceHost: "foo",
+				ServicePort: 9000,
+				Type:        v1.REST,
+			},
+		},
+	}
+
+	url, _ := url.Parse("http://localhost")
+	r := NewServerRestApi(&p, &test.SeldonMessageTestClient{}, false, url, "default", api.ProtocolTensorflow, "test", "/metrics")
+	r.Initialise()
+
+	var data = ` {"instances":[[1,2,3]]}`
+	req, _ := http.NewRequest("POST", "/v1/models/:predict", strings.NewReader(data))
+	res := httptest.NewRecorder()
+	r.Router.ServeHTTP(res, req)
+	g.Expect(res.Code).To(Equal(200))
+	g.Expect(res.Body.String()).To(Equal(data))
+}
+
+func TestTensorflowModelBadModelName(t *testing.T) {
+	t.Logf("Started")
+	g := NewGomegaWithT(t)
+
+	model := v1.MODEL
+	p := v1.PredictorSpec{
+		Name: "p",
+		Graph: &v1.PredictiveUnit{
+			Name: "mymodel",
+			Type: &model,
+			Endpoint: &v1.Endpoint{
+				ServiceHost: "foo",
+				ServicePort: 9000,
+				Type:        v1.REST,
+			},
+		},
+	}
+
+	url, _ := url.Parse("http://localhost")
+	r := NewServerRestApi(&p, &test.SeldonMessageTestClient{}, false, url, "default", api.ProtocolTensorflow, "test", "/metrics")
+	r.Initialise()
+
+	var data = ` {"instances":[[1,2,3]]}`
+	req, _ := http.NewRequest("POST", "/v1/models/xyz/:predict", strings.NewReader(data))
+	res := httptest.NewRecorder()
+	r.Router.ServeHTTP(res, req)
+	g.Expect(res.Code).To(Equal(500))
+	g.Expect(res.Header().Get("Content-Type")).To(Equal(test.TestContentType))
+}
diff --git a/executor/api/test/seldonmessage_test_client.go b/executor/api/test/seldonmessage_test_client.go
index 204810efdb..6abd55fbf8 100644
--- a/executor/api/test/seldonmessage_test_client.go
+++ b/executor/api/test/seldonmessage_test_client.go
@@ -6,7 +6,6 @@ import (
 	"github.com/seldonio/seldon-core/executor/api/payload"
 	"github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1"
 	"io"
-	"net/http"
 )
 
 type SeldonMessageTestClient struct {
@@ -19,6 +18,7 @@ type SeldonMessageTestClient struct {
 const (
 	TestClientStatusResponse   = `{"status":"ok"}`
 	TestClientMetadataResponse = `{"metadata":{"name":"mymodel"}}`
+	TestContentType            = "application/json"
 )
 
 func (s SeldonMessageTestClient) Status(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) {
@@ -34,7 +34,7 @@ func (s SeldonMessageTestClient) Chain(ctx context.Context, modelName string, ms
 }
 
 func (s SeldonMessageTestClient) Unmarshall(msg []byte) (payload.SeldonPayload, error) {
-	reqPayload := payload.BytesPayload{Msg: msg, ContentType: "application/json"}
+	reqPayload := payload.BytesPayload{Msg: msg, ContentType: TestContentType}
 	return &reqPayload, nil
 }
 
@@ -44,9 +44,8 @@ func (s SeldonMessageTestClient) Marshall(out io.Writer, msg payload.SeldonPayload) error {
 }
 
 func (s SeldonMessageTestClient) CreateErrorPayload(err error) payload.SeldonPayload {
-	respFailed := proto.SeldonMessage{Status: &proto.Status{Code: http.StatusInternalServerError, Info: err.Error()}}
-	res := payload.ProtoPayload{Msg: &respFailed}
-	return &res
+	respFailed := payload.BytesPayload{Msg: []byte(err.Error()), ContentType: TestContentType}
+	return &respFailed
}
 
 func (s SeldonMessageTestClient) Predict(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) {