diff --git a/executor/api/rest/middlewares.go b/executor/api/rest/middlewares.go index 9017a8e276..9e123600dc 100644 --- a/executor/api/rest/middlewares.go +++ b/executor/api/rest/middlewares.go @@ -5,6 +5,7 @@ import ( guuid "github.com/google/uuid" "github.com/seldonio/seldon-core/executor/api/payload" + "github.com/seldonio/seldon-core/executor/api/util" ) const ( @@ -18,12 +19,12 @@ const ( contentTypeOptsHeader = "X-Content-Type-Options" contentTypeOptsValue = "nosniff" - corsAllowOriginHeader = "Access-Control-Allow-Origin" - corsAllowOriginValue = "*" - corsAllowMethodsHeader = "Access-Control-Allow-Methods" - corsAllowMethodsValue = "GET, OPTIONS, POST" - corsAllowHeadersHeader = "Access-Control-Allow-Headers" - corsAllowHeadersValue = "Accept, Accept-Encoding, Authorization, Content-Length, Content-Type, X-CSRF-Token" + corsAllowOriginEnvVar = "CORS_ALLOWED_ORIGINS" + corsAllowOriginHeader = "Access-Control-Allow-Origin" + corsAllowOriginValueAll = "*" + corsAllowOriginHeadersVar = "CORS_ALLOWED_HEADERS" + corsAllowHeadersHeader = "Access-Control-Allow-Headers" + corsAllowHeadersValueDefault = "Accept, Accept-Encoding, Authorization, Content-Length, Content-Type, X-CSRF-Token" ) type CloudeventHeaderMiddleware struct { @@ -52,12 +53,12 @@ func (h *CloudeventHeaderMiddleware) Middleware(next http.Handler) http.Handler // http.StatusOK func handleCORSRequests(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + corsAllowOriginValue := util.GetEnv(corsAllowOriginEnvVar, corsAllowOriginValueAll) + corsAllowHeadersValue := util.GetEnv(corsAllowOriginHeadersVar, corsAllowHeadersValueDefault) w.Header().Set(corsAllowOriginHeader, corsAllowOriginValue) - w.Header().Set(corsAllowMethodsHeader, corsAllowMethodsValue) w.Header().Set(corsAllowHeadersHeader, corsAllowHeadersValue) // Don't pass along OPTIONS (CORS Prefetch) Requests if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) diff --git a/executor/api/rest/middlewares_test.go b/executor/api/rest/middlewares_test.go index adea623787..ed412d1485 100644 --- a/executor/api/rest/middlewares_test.go +++ b/executor/api/rest/middlewares_test.go @@ -3,16 +3,22 @@ package rest import ( "net/http" "net/http/httptest" + "os" "testing" . "github.com/onsi/gomega" ) -func TestCORSHeadersGetRequest(t *testing.T) { +func TestEnvVars(t *testing.T) { g := NewGomegaWithT(t) + os.Setenv(corsAllowOriginEnvVar, "http://www.google.com") + os.Setenv(corsAllowOriginHeadersVar, "Accept") + defer os.Unsetenv(corsAllowOriginEnvVar) + defer os.Unsetenv(corsAllowOriginHeadersVar) + m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - wrapped := corsHeaders(m) + wrapped := handleCORSRequests(m) req := httptest.NewRequest("GET", "http://example.com/foo", nil) w := httptest.NewRecorder() @@ -22,20 +28,37 @@ func TestCORSHeadersGetRequest(t *testing.T) { defer res.Body.Close() headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader) - g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValue)) + g.Expect(headerValAllowOrigin).To(Equal("http://www.google.com")) headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader) - g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValue)) + g.Expect(headerValAllowHeaders).To(Equal("Accept")) +} + +func TestCORSHeadersGetRequest(t *testing.T) { + g := NewGomegaWithT(t) + + m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + wrapped := handleCORSRequests(m) - headerValAllowMethods := res.Header.Get(corsAllowMethodsHeader) - g.Expect(headerValAllowMethods).To(Equal(corsAllowMethodsValue)) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, req) + + res := w.Result() + defer res.Body.Close() + + headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader) + g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll)) + + headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader) + g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault)) } func TestCORSHeadersOptionsRequest(t *testing.T) { g := NewGomegaWithT(t) m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - wrapped := corsHeaders(m) + wrapped := handleCORSRequests(m) req := httptest.NewRequest("GET", "http://example.com/foo", nil) w := httptest.NewRecorder() @@ -45,13 +68,10 @@ func TestCORSHeadersOptionsRequest(t *testing.T) { defer res.Body.Close() headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader) - g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValue)) + g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll)) headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader) - g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValue)) - - headerValAllowMethods := res.Header.Get(corsAllowMethodsHeader) - g.Expect(headerValAllowMethods).To(Equal(corsAllowMethodsValue)) + g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault)) statusCode := res.StatusCode g.Expect(statusCode).To(Equal(http.StatusOK)) diff --git a/executor/api/rest/server.go b/executor/api/rest/server.go index 2de35257fe..9e0845f09d 100644 --- a/executor/api/rest/server.go +++ b/executor/api/rest/server.go @@ -127,6 +127,7 @@ func (r *SeldonRestApi) Initialise() { r.Router.Use(puidHeader) r.Router.Use(cloudeventHeaderMiddleware.Middleware) r.Router.Use(xssMiddleware) + r.Router.Use(mux.CORSMethodMiddleware(r.Router)) r.Router.Use(handleCORSRequests) switch r.Protocol { diff --git a/executor/api/util/utils.go b/executor/api/util/utils.go index 3a72a06f5e..03f38d7e32 100644 --- a/executor/api/util/utils.go +++ b/executor/api/util/utils.go @@ -1,6 +1,8 @@ package util import ( + "os" + "github.com/seldonio/seldon-core/executor/api/grpc/seldon/proto" ) @@ -34,3 +36,11 @@ func ExtractRouteFromSeldonMessage(msg *proto.SeldonMessage) []int { } return []int{-1} } + +// Get an environment variable given by key or return the fallback. +func GetEnv(key, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + return fallback +} diff --git a/executor/go.mod b/executor/go.mod index 5a1ca0fbff..8f90b223f4 100644 --- a/executor/go.mod +++ b/executor/go.mod @@ -24,6 +24,7 @@ require ( github.com/uber/jaeger-client-go v2.21.1+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible // indirect golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 + google.golang.org/appengine v1.6.5 google.golang.org/grpc v1.28.0 gotest.tools v2.2.0+incompatible k8s.io/api v0.17.2