Skip to content

Commit

Permalink
Merge pull request #1952 from ericandrewmeadows/ericandrewmeadows/fix…
Browse files Browse the repository at this point in the history
…-cors

Added CORS headers to enable Front-End inputs on Go Engine
  • Loading branch information
axsaucedo authored Jun 19, 2020
2 parents 0c82a86 + 5d2a533 commit 79ca088
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 12 deletions.
28 changes: 27 additions & 1 deletion executor/api/rest/middlewares.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package rest

import (
"net/http"

guuid "github.com/google/uuid"
"github.com/seldonio/seldon-core/executor/api/payload"
"net/http"
"github.com/seldonio/seldon-core/executor/api/util"
)

const (
Expand All @@ -16,6 +18,13 @@ const (

contentTypeOptsHeader = "X-Content-Type-Options"
contentTypeOptsValue = "nosniff"

corsAllowOriginEnvVar = "CORS_ALLOWED_ORIGINS"
corsAllowOriginHeader = "Access-Control-Allow-Origin"
corsAllowOriginValueAll = "*"
corsAllowOriginHeadersVar = "CORS_ALLOWED_HEADERS"
corsAllowHeadersHeader = "Access-Control-Allow-Headers"
corsAllowHeadersValueDefault = "Accept, Accept-Encoding, Authorization, Content-Length, Content-Type, X-CSRF-Token"
)

type CloudeventHeaderMiddleware struct {
Expand All @@ -39,6 +48,23 @@ func (h *CloudeventHeaderMiddleware) Middleware(next http.Handler) http.Handler
})
}

// handleCORSRequests adds CORS-required headers, and during CORS Preflight
// requests, it will exit the request and the request status will be
// http.StatusOK
func handleCORSRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
corsAllowOriginValue := util.GetEnv(corsAllowOriginEnvVar, corsAllowOriginValueAll)
corsAllowHeadersValue := util.GetEnv(corsAllowOriginHeadersVar, corsAllowHeadersValueDefault)
w.Header().Set(corsAllowOriginHeader, corsAllowOriginValue)
w.Header().Set(corsAllowHeadersHeader, corsAllowHeadersValue)
// Don't pass along OPTIONS (CORS Prefetch) Requests
if r.Method == "OPTIONS" {
return
}
next.ServeHTTP(w, r)
})
}

func puidHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
puid := r.Header.Get(payload.SeldonPUIDHeader)
Expand Down
69 changes: 69 additions & 0 deletions executor/api/rest/middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,80 @@ package rest
import (
"net/http"
"net/http/httptest"
"os"
"testing"

. "github.com/onsi/gomega"
)

func TestEnvVars(t *testing.T) {
g := NewGomegaWithT(t)

os.Setenv(corsAllowOriginEnvVar, "http://www.google.com")
os.Setenv(corsAllowOriginHeadersVar, "Accept")
defer os.Unsetenv(corsAllowOriginEnvVar)
defer os.Unsetenv(corsAllowOriginHeadersVar)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := handleCORSRequests(m)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
wrapped.ServeHTTP(w, req)

res := w.Result()
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal("http://www.google.com"))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal("Accept"))
}

func TestCORSHeadersGetRequest(t *testing.T) {
g := NewGomegaWithT(t)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := handleCORSRequests(m)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
wrapped.ServeHTTP(w, req)

res := w.Result()
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault))
}

func TestCORSHeadersOptionsRequest(t *testing.T) {
g := NewGomegaWithT(t)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := handleCORSRequests(m)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
wrapped.ServeHTTP(w, req)

res := w.Result()
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault))

statusCode := res.StatusCode
g.Expect(statusCode).To(Equal(http.StatusOK))
}

func TestXSSMiddleware(t *testing.T) {
g := NewGomegaWithT(t)

Expand Down
24 changes: 13 additions & 11 deletions executor/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,30 +128,32 @@ func (r *SeldonRestApi) Initialise() {
r.Router.Use(puidHeader)
r.Router.Use(cloudeventHeaderMiddleware.Middleware)
r.Router.Use(xssMiddleware)
r.Router.Use(mux.CORSMethodMiddleware(r.Router))
r.Router.Use(handleCORSRequests)

switch r.Protocol {
case api.ProtocolSeldon:
//v0.1 API
api01 := r.Router.PathPrefix("/api/v0.1").Methods("POST").Subrouter()
api01 := r.Router.PathPrefix("/api/v0.1").Methods("OPTIONS", "POST").Subrouter()
api01.Handle("/predictions", r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
api01.Handle("/feedback", r.wrapMetrics(metric.FeedbackHttpServiceName, r.feedback))
r.Router.NewRoute().Path("/api/v0.1/status/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/api/v0.1/metadata/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
r.Router.NewRoute().Path("/api/v0.1/status/{"+ModelHttpPathVariable+"}").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/api/v0.1/metadata/{"+ModelHttpPathVariable+"}").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
r.Router.NewRoute().PathPrefix("/api/v0.1/doc/").Handler(http.StripPrefix("/api/v0.1/doc/", http.FileServer(http.Dir("./openapi/"))))
//v1.0 API
api10 := r.Router.PathPrefix("/api/v1.0").Methods("POST").Subrouter()
api10 := r.Router.PathPrefix("/api/v1.0").Methods("OPTIONS", "POST").Subrouter()
api10.Handle("/predictions", r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
api10.Handle("/feedback", r.wrapMetrics(metric.FeedbackHttpServiceName, r.feedback))
r.Router.NewRoute().Path("/api/v1.0/status/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/api/v1.0/metadata").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.graphMetadata))
r.Router.NewRoute().Path("/api/v1.0/metadata/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
r.Router.NewRoute().Path("/api/v1.0/status/{"+ModelHttpPathVariable+"}").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/api/v1.0/metadata").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.graphMetadata))
r.Router.NewRoute().Path("/api/v1.0/metadata/{"+ModelHttpPathVariable+"}").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
r.Router.NewRoute().PathPrefix("/api/v1.0/doc/").Handler(http.StripPrefix("/api/v1.0/doc/", http.FileServer(http.Dir("./openapi/"))))

case api.ProtocolTensorflow:
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
r.Router.NewRoute().Path("/v1/models/:predict").Methods("POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions)) // Nonstandard path - Seldon extension
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}").Methods("GET").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/v1/models/{" + ModelHttpPathVariable + "}/metadata").Methods("GET").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
r.Router.NewRoute().Path("/v1/models/{"+ModelHttpPathVariable+"}/:predict").Methods("OPTIONS", "POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions))
r.Router.NewRoute().Path("/v1/models/:predict").Methods("OPTIONS", "POST").HandlerFunc(r.wrapMetrics(metric.PredictionHttpServiceName, r.predictions)) // Nonstandard path - Seldon extension
r.Router.NewRoute().Path("/v1/models/{"+ModelHttpPathVariable+"}").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.StatusHttpServiceName, r.status))
r.Router.NewRoute().Path("/v1/models/{"+ModelHttpPathVariable+"}/metadata").Methods("GET", "OPTIONS").HandlerFunc(r.wrapMetrics(metric.MetadataHttpServiceName, r.metadata))
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions executor/api/util/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package util

import (
"os"

"github.com/seldonio/seldon-core/executor/api/grpc/seldon/proto"
)

Expand Down Expand Up @@ -34,3 +36,11 @@ func ExtractRouteFromSeldonMessage(msg *proto.SeldonMessage) []int {
}
return []int{-1}
}

// Get an environment variable given by key or return the fallback.
func GetEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}

0 comments on commit 79ca088

Please sign in to comment.