diff --git a/server/application/terminal.go b/server/application/terminal.go index 85468a3e92f5f..38368f486b665 100644 --- a/server/application/terminal.go +++ b/server/application/terminal.go @@ -25,6 +25,7 @@ import ( "github.com/argoproj/argo-cd/v2/util/rbac" "github.com/argoproj/argo-cd/v2/util/security" sessionmgr "github.com/argoproj/argo-cd/v2/util/session" + "github.com/argoproj/argo-cd/v2/util/settings" ) type terminalHandler struct { @@ -95,6 +96,26 @@ func isValidContainerName(name string) bool { return len(validationErrors) == 0 } +type GetSettingsFunc func() (*settings.ArgoCDSettings, error) + +// WithFeatureFlagMiddleware is an HTTP middleware to verify if the terminal +// feature is enabled before invoking the main handler +func (s *terminalHandler) WithFeatureFlagMiddleware(getSettings GetSettingsFunc) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + argocdSettings, err := getSettings() + if err != nil { + log.Errorf("error executing WithFeatureFlagMiddleware: error getting settings: %s", err) + http.Error(w, "Failed to get settings", http.StatusBadRequest) + return + } + if !argocdSettings.ExecEnabled { + w.WriteHeader(http.StatusNotFound) + return + } + s.ServeHTTP(w, r) + }) +} + func (s *terminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() diff --git a/server/server.go b/server/server.go index 750e8c04369eb..f1302a0f12f45 100644 --- a/server/server.go +++ b/server/server.go @@ -905,40 +905,10 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl } mux.Handle("/api/", handler) - terminalHandler := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells) - mux.HandleFunc("/terminal", func(writer http.ResponseWriter, request *http.Request) { - argocdSettings, err := a.settingsMgr.GetSettings() - if err != nil { - http.Error(writer, fmt.Sprintf("Failed to get settings: %v", err), http.StatusBadRequest) - return - } - if !argocdSettings.ExecEnabled { - writer.WriteHeader(http.StatusNotFound) - return - } - - if !a.DisableAuth { - ctx := request.Context() - cookies := request.Cookies() - tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies) - if err == nil && jwtutil.IsValid(tokenString) { - claims, _, err := a.sessionMgr.VerifyToken(tokenString) - if err != nil { - // nolint:staticcheck - ctx = context.WithValue(ctx, util_session.AuthErrorCtxKey, err) - } else if claims != nil { - // Add claims to the context to inspect for RBAC - // nolint:staticcheck - ctx = context.WithValue(ctx, "claims", claims) - } - request = request.WithContext(ctx) - } else { - writer.WriteHeader(http.StatusUnauthorized) - return - } - } - terminalHandler.ServeHTTP(writer, request) - }) + terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells). + WithFeatureFlagMiddleware(a.settingsMgr.GetSettings) + th := util_session.WithAuthMiddleware(a.DisableAuth, a.sessionMgr, terminal) + mux.Handle("/terminal", th) // Dead code for now // Proxy extension is currently an experimental feature and is disabled diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index 3571e2b39e114..35cbfd66fdd70 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -463,6 +463,47 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri return nil } +// AuthMiddlewareFunc returns a function that can be used as an +// authentication middleware for HTTP requests. +func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return WithAuthMiddleware(disabled, mgr, h) + } +} + +// TokenVerifier defines the contract to invoke token +// verification logic +type TokenVerifier interface { + VerifyToken(token string) (jwt.Claims, string, error) +} + +// WithAuthMiddleware is an HTTP middleware used to ensure incoming +// requests are authenticated before invoking the target handler. If +// disabled is true, it will just invoke the next handler in the chain. +func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !disabled { + cookies := r.Cookies() + tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies) + if err != nil { + http.Error(w, "Auth cookie not found", http.StatusBadRequest) + return + } + claims, _, err := authn.VerifyToken(tokenString) + if err != nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + ctx := r.Context() + // Add claims to the context to inspect for RBAC + // nolint:staticcheck + ctx = context.WithValue(ctx, "claims", claims) + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) + }) +} + // VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP. // We choose how to verify based on the issuer. func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) { diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index e646e5016fb0f..57869a8e60381 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -3,8 +3,12 @@ package session import ( "context" "encoding/pem" + stderrors "errors" "fmt" + "io" "math" + "net/http" + "net/http/httptest" "os" "strconv" "strings" @@ -221,6 +225,136 @@ func TestSessionManager_ProjectToken(t *testing.T) { }) } +type claimsMock struct { + err error +} + +func (cm *claimsMock) Valid() error { + return cm.err +} + +type tokenVerifierMock struct { + claims *claimsMock + err error +} + +func (tm *tokenVerifierMock) VerifyToken(token string) (jwt.Claims, string, error) { + if tm.claims == nil { + return nil, "", tm.err + } + return tm.claims, "", tm.err +} + +func strPointer(str string) *string { + return &str +} + +func TestSessionManager_WithAuthMiddleware(t *testing.T) { + handlerFunc := func() func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + t.Helper() + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/text") + _, err := w.Write([]byte("Ok")) + if err != nil { + t.Fatalf("error writing response: %s", err) + } + } + } + type testCase struct { + name string + authDisabled bool + cookieHeader bool + verifiedClaims *claimsMock + verifyTokenErr error + expectedStatusCode int + expectedResponseBody *string + } + + cases := []testCase{ + { + name: "will authenticate successfully", + authDisabled: false, + cookieHeader: true, + verifiedClaims: &claimsMock{}, + verifyTokenErr: nil, + expectedStatusCode: 200, + expectedResponseBody: strPointer("Ok"), + }, + { + name: "will be noop if auth is disabled", + authDisabled: true, + cookieHeader: false, + verifiedClaims: nil, + verifyTokenErr: nil, + expectedStatusCode: 200, + expectedResponseBody: strPointer("Ok"), + }, + { + name: "will return 400 if no cookie header", + authDisabled: false, + cookieHeader: false, + verifiedClaims: &claimsMock{}, + verifyTokenErr: nil, + expectedStatusCode: 400, + expectedResponseBody: nil, + }, + { + name: "will return 401 verify token fails", + authDisabled: false, + cookieHeader: true, + verifiedClaims: &claimsMock{}, + verifyTokenErr: stderrors.New("token error"), + expectedStatusCode: 401, + expectedResponseBody: nil, + }, + { + name: "will return 200 if claims are nil", + authDisabled: false, + cookieHeader: true, + verifiedClaims: nil, + verifyTokenErr: nil, + expectedStatusCode: 200, + expectedResponseBody: strPointer("Ok"), + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // given + mux := http.NewServeMux() + mux.HandleFunc("/", handlerFunc()) + tm := &tokenVerifierMock{ + claims: tc.verifiedClaims, + err: tc.verifyTokenErr, + } + ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux)) + defer ts.Close() + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("error creating request: %s", err) + } + if tc.cookieHeader { + req.Header.Add("Cookie", "argocd.token=123456") + } + + // when + resp, err := http.DefaultClient.Do(req) + + // then + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode) + if tc.expectedResponseBody != nil { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + actual := strings.TrimSuffix(string(body), "\n") + assert.Contains(t, actual, *tc.expectedResponseBody) + } + }) + } +} + var loggedOutContext = context.Background() // nolint:staticcheck