Skip to content

Commit

Permalink
Environment variable setting enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Meadows committed Jun 16, 2020
1 parent 4140a11 commit 65dfd7e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 20 deletions.
17 changes: 9 additions & 8 deletions executor/api/rest/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

guuid "github.com/google/uuid"
"github.com/seldonio/seldon-core/executor/api/payload"
"github.com/seldonio/seldon-core/executor/api/util"
)

const (
Expand All @@ -18,12 +19,12 @@ const (
contentTypeOptsHeader = "X-Content-Type-Options"
contentTypeOptsValue = "nosniff"

corsAllowOriginHeader = "Access-Control-Allow-Origin"
corsAllowOriginValue = "*"
corsAllowMethodsHeader = "Access-Control-Allow-Methods"
corsAllowMethodsValue = "GET, OPTIONS, POST"
corsAllowHeadersHeader = "Access-Control-Allow-Headers"
corsAllowHeadersValue = "Accept, Accept-Encoding, Authorization, Content-Length, Content-Type, X-CSRF-Token"
corsAllowOriginEnvVar = "CORS_ALLOWED_ORIGINS"
corsAllowOriginHeader = "Access-Control-Allow-Origin"
corsAllowOriginValueAll = "*"
corsAllowOriginHeadersVar = "CORS_ALLOWED_HEADERS"
corsAllowHeadersHeader = "Access-Control-Allow-Headers"
corsAllowHeadersValueDefault = "Accept, Accept-Encoding, Authorization, Content-Length, Content-Type, X-CSRF-Token"
)

type CloudeventHeaderMiddleware struct {
Expand Down Expand Up @@ -52,12 +53,12 @@ func (h *CloudeventHeaderMiddleware) Middleware(next http.Handler) http.Handler
// http.StatusOK
func handleCORSRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
corsAllowOriginValue := util.GetEnv(corsAllowOriginEnvVar, corsAllowOriginValueAll)
corsAllowHeadersValue := util.GetEnv(corsAllowOriginHeadersVar, corsAllowHeadersValueDefault)
w.Header().Set(corsAllowOriginHeader, corsAllowOriginValue)
w.Header().Set(corsAllowMethodsHeader, corsAllowMethodsValue)
w.Header().Set(corsAllowHeadersHeader, corsAllowHeadersValue)
// Don't pass along OPTIONS (CORS Prefetch) Requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
Expand Down
44 changes: 32 additions & 12 deletions executor/api/rest/middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@ package rest
import (
"net/http"
"net/http/httptest"
"os"
"testing"

. "github.com/onsi/gomega"
)

func TestCORSHeadersGetRequest(t *testing.T) {
func TestEnvVars(t *testing.T) {
g := NewGomegaWithT(t)

os.Setenv(corsAllowOriginEnvVar, "http://www.google.com")
os.Setenv(corsAllowOriginHeadersVar, "Accept")
defer os.Unsetenv(corsAllowOriginEnvVar)
defer os.Unsetenv(corsAllowOriginHeadersVar)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := corsHeaders(m)
wrapped := handleCORSRequests(m)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
Expand All @@ -22,20 +28,37 @@ func TestCORSHeadersGetRequest(t *testing.T) {
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValue))
g.Expect(headerValAllowOrigin).To(Equal("http://www.google.com"))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValue))
g.Expect(headerValAllowHeaders).To(Equal("Accept"))
}

func TestCORSHeadersGetRequest(t *testing.T) {
g := NewGomegaWithT(t)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := handleCORSRequests(m)

headerValAllowMethods := res.Header.Get(corsAllowMethodsHeader)
g.Expect(headerValAllowMethods).To(Equal(corsAllowMethodsValue))
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
wrapped.ServeHTTP(w, req)

res := w.Result()
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault))
}

func TestCORSHeadersOptionsRequest(t *testing.T) {
g := NewGomegaWithT(t)

m := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
wrapped := corsHeaders(m)
wrapped := handleCORSRequests(m)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
w := httptest.NewRecorder()
Expand All @@ -45,13 +68,10 @@ func TestCORSHeadersOptionsRequest(t *testing.T) {
defer res.Body.Close()

headerValAllowOrigin := res.Header.Get(corsAllowOriginHeader)
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValue))
g.Expect(headerValAllowOrigin).To(Equal(corsAllowOriginValueAll))

headerValAllowHeaders := res.Header.Get(corsAllowHeadersHeader)
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValue))

headerValAllowMethods := res.Header.Get(corsAllowMethodsHeader)
g.Expect(headerValAllowMethods).To(Equal(corsAllowMethodsValue))
g.Expect(headerValAllowHeaders).To(Equal(corsAllowHeadersValueDefault))

statusCode := res.StatusCode
g.Expect(statusCode).To(Equal(http.StatusOK))
Expand Down
1 change: 1 addition & 0 deletions executor/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func (r *SeldonRestApi) Initialise() {
r.Router.Use(puidHeader)
r.Router.Use(cloudeventHeaderMiddleware.Middleware)
r.Router.Use(xssMiddleware)
r.Router.Use(mux.CORSMethodMiddleware(r.Router))
r.Router.Use(handleCORSRequests)

switch r.Protocol {
Expand Down
10 changes: 10 additions & 0 deletions executor/api/util/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package util

import (
"os"

"github.com/seldonio/seldon-core/executor/api/grpc/seldon/proto"
)

Expand Down Expand Up @@ -34,3 +36,11 @@ func ExtractRouteFromSeldonMessage(msg *proto.SeldonMessage) []int {
}
return []int{-1}
}

// Get an environment variable given by key or return the fallback.
func GetEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}
1 change: 1 addition & 0 deletions executor/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/uber/jaeger-client-go v2.21.1+incompatible
github.com/uber/jaeger-lib v2.2.0+incompatible // indirect
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
google.golang.org/appengine v1.6.5
google.golang.org/grpc v1.28.0
gotest.tools v2.2.0+incompatible
k8s.io/api v0.17.2
Expand Down

0 comments on commit 65dfd7e

Please sign in to comment.