Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Refactor terminal handler to use auth-middleware #12052

Merged
merged 11 commits into from
Jan 24, 2023
21 changes: 21 additions & 0 deletions server/application/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/argoproj/argo-cd/v2/util/rbac"
"github.com/argoproj/argo-cd/v2/util/security"
sessionmgr "github.com/argoproj/argo-cd/v2/util/session"
"github.com/argoproj/argo-cd/v2/util/settings"
)

type terminalHandler struct {
Expand Down Expand Up @@ -95,6 +96,26 @@ func isValidContainerName(name string) bool {
return len(validationErrors) == 0
}

type GetSettingsFunc func() (*settings.ArgoCDSettings, error)

// WithFeatureFlagMiddleware is an HTTP middleware to verify if the terminal
// feature is enabled before invoking the main handler
func (s *terminalHandler) WithFeatureFlagMiddleware(getSettings GetSettingsFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
argocdSettings, err := getSettings()
if err != nil {
log.Errorf("error executing WithFeatureFlagMiddleware: error getting settings: %s", err)
http.Error(w, "Failed to get settings", http.StatusBadRequest)
return
}
if !argocdSettings.ExecEnabled {
w.WriteHeader(http.StatusNotFound)
return
}
s.ServeHTTP(w, r)
})
}

func (s *terminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()

Expand Down
38 changes: 4 additions & 34 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,40 +905,10 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl
}
mux.Handle("/api/", handler)

terminalHandler := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells)
mux.HandleFunc("/terminal", func(writer http.ResponseWriter, request *http.Request) {
argocdSettings, err := a.settingsMgr.GetSettings()
if err != nil {
http.Error(writer, fmt.Sprintf("Failed to get settings: %v", err), http.StatusBadRequest)
return
}
if !argocdSettings.ExecEnabled {
writer.WriteHeader(http.StatusNotFound)
return
}

if !a.DisableAuth {
ctx := request.Context()
cookies := request.Cookies()
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
if err == nil && jwtutil.IsValid(tokenString) {
claims, _, err := a.sessionMgr.VerifyToken(tokenString)
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, util_session.AuthErrorCtxKey, err)
} else if claims != nil {
// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
}
request = request.WithContext(ctx)
} else {
writer.WriteHeader(http.StatusUnauthorized)
return
}
}
terminalHandler.ServeHTTP(writer, request)
})
terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells).
WithFeatureFlagMiddleware(a.settingsMgr.GetSettings)
th := util_session.WithAuthMiddleware(a.DisableAuth, a.sessionMgr, terminal)
mux.Handle("/terminal", th)

// Dead code for now
// Proxy extension is currently an experimental feature and is disabled
Expand Down
41 changes: 41 additions & 0 deletions util/session/sessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,47 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri
return nil
}

// AuthMiddlewareFunc returns a function that can be used as an
// authentication middleware for HTTP requests.
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return WithAuthMiddleware(disabled, mgr, h)
}
}

// TokenVerifier defines the contract to invoke token
// verification logic
type TokenVerifier interface {
VerifyToken(token string) (jwt.Claims, string, error)
}

// WithAuthMiddleware is an HTTP middleware used to ensure incoming
// requests are authenticated before invoking the target handler. If
// disabled is true, it will just invoke the next handler in the chain.
func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !disabled {
cookies := r.Cookies()
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
if err != nil {
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
return
}
claims, _, err := authn.VerifyToken(tokenString)
if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
ctx := r.Context()
// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}

// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
// We choose how to verify based on the issuer.
func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) {
Expand Down
134 changes: 134 additions & 0 deletions util/session/sessionmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package session
import (
"context"
"encoding/pem"
stderrors "errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -221,6 +225,136 @@ func TestSessionManager_ProjectToken(t *testing.T) {
})
}

type claimsMock struct {
err error
}

func (cm *claimsMock) Valid() error {
return cm.err
}
leoluz marked this conversation as resolved.
Show resolved Hide resolved

type tokenVerifierMock struct {
claims *claimsMock
err error
}

func (tm *tokenVerifierMock) VerifyToken(token string) (jwt.Claims, string, error) {
if tm.claims == nil {
return nil, "", tm.err
}
leoluz marked this conversation as resolved.
Show resolved Hide resolved
return tm.claims, "", tm.err
}

func strPointer(str string) *string {
return &str
}

func TestSessionManager_WithAuthMiddleware(t *testing.T) {
handlerFunc := func() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
t.Helper()
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/text")
_, err := w.Write([]byte("Ok"))
if err != nil {
t.Fatalf("error writing response: %s", err)
}
}
}
type testCase struct {
name string
authDisabled bool
cookieHeader bool
verifiedClaims *claimsMock
verifyTokenErr error
expectedStatusCode int
expectedResponseBody *string
}

cases := []testCase{
{
name: "will authenticate successfully",
authDisabled: false,
cookieHeader: true,
verifiedClaims: &claimsMock{},
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
{
name: "will be noop if auth is disabled",
authDisabled: true,
cookieHeader: false,
verifiedClaims: nil,
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
{
name: "will return 400 if no cookie header",
authDisabled: false,
cookieHeader: false,
verifiedClaims: &claimsMock{},
verifyTokenErr: nil,
expectedStatusCode: 400,
expectedResponseBody: nil,
},
{
name: "will return 401 verify token fails",
authDisabled: false,
cookieHeader: true,
verifiedClaims: &claimsMock{},
verifyTokenErr: stderrors.New("token error"),
expectedStatusCode: 401,
expectedResponseBody: nil,
},
{
name: "will return 200 if claims are nil",
authDisabled: false,
cookieHeader: true,
verifiedClaims: nil,
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// given
mux := http.NewServeMux()
mux.HandleFunc("/", handlerFunc())
tm := &tokenVerifierMock{
claims: tc.verifiedClaims,
err: tc.verifyTokenErr,
}
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux))
defer ts.Close()
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("error creating request: %s", err)
}
if tc.cookieHeader {
req.Header.Add("Cookie", "argocd.token=123456")
}

// when
resp, err := http.DefaultClient.Do(req)

// then
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
if tc.expectedResponseBody != nil {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
actual := strings.TrimSuffix(string(body), "\n")
assert.Contains(t, actual, *tc.expectedResponseBody)
}
})
}
}

var loggedOutContext = context.Background()

// nolint:staticcheck
Expand Down