Skip to content

Commit

Permalink
fix: add a hardcoded cookie whitelist for internal cookie names (#1605)
Browse files Browse the repository at this point in the history
  • Loading branch information
endigma authored Feb 19, 2025
1 parent 9d8b2a1 commit 3ddb078
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
2 changes: 1 addition & 1 deletion router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
})

if s.headerRules != nil {
cr.Use(rmiddleware.CookieWhitelist(s.headerRules.CookieWhitelist))
cr.Use(rmiddleware.CookieWhitelist(s.headerRules.CookieWhitelist, []string{featureFlagCookie}))
}

// Mount the feature flag handler. It calls the base mux if no feature flag is set.
Expand Down
5 changes: 3 additions & 2 deletions router/internal/middleware/cookie_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"slices"
)

func CookieWhitelist(cookieWhitelist []string) func(next http.Handler) http.Handler {
func CookieWhitelist(cookieWhitelist []string, cookieSafelist []string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, rr *http.Request) {
if len(cookieWhitelist) == 0 {
Expand All @@ -17,7 +17,8 @@ func CookieWhitelist(cookieWhitelist []string) func(next http.Handler) http.Hand
rr.Header.Del("Cookie")

for _, cookie := range cookies {
if slices.Contains(cookieWhitelist, cookie.Name) {
if slices.Contains(cookieWhitelist, cookie.Name) ||
slices.Contains(cookieSafelist, cookie.Name) {
rr.AddCookie(cookie)
}
}
Expand Down
55 changes: 53 additions & 2 deletions router/internal/middleware/cookie_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestCookieWhitelist(t *testing.T) {
req.AddCookie(cookie)
}

CookieWhitelist(cookieWhitelist)(next).ServeHTTP(recorder, req)
CookieWhitelist(cookieWhitelist, []string{})(next).ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, cookies, filteredCookies)
Expand Down Expand Up @@ -81,7 +81,58 @@ func TestCookieWhitelist(t *testing.T) {
req.AddCookie(cookie)
}

CookieWhitelist(cookieWhitelist)(next).ServeHTTP(recorder, req)
CookieWhitelist(cookieWhitelist, []string{})(next).ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, expectedFilteredCookies, filteredCookies)
})

t.Run("never filter safe listed cookie", func(t *testing.T) {
t.Parallel()

cookieWhitelist := []string{"allowed"}
cookies := []*http.Cookie{
{
Name: "allowed",
Value: "allowed",
},
{
Name: "disallowed",
Value: "disallowed",
},
{
Name: "safelisted",
Value: "safelisted",
},
}

expectedFilteredCookies := []*http.Cookie{
{
Name: "allowed",
Value: "allowed",
},
{
Name: "safelisted",
Value: "safelisted",
},
}

recorder := httptest.NewRecorder()

filteredCookies := []*http.Cookie{}

next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
filteredCookies = r.Cookies()
})

req, err := http.NewRequest(http.MethodGet, "/", strings.NewReader("test"))
require.NoError(t, err)

for _, cookie := range cookies {
req.AddCookie(cookie)
}

CookieWhitelist(cookieWhitelist, []string{"safelisted"})(next).ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, expectedFilteredCookies, filteredCookies)
Expand Down

0 comments on commit 3ddb078

Please sign in to comment.