Skip to content

Commit

Permalink
chore: Refactor terminal handler to use auth-middleware (#12052)
Browse files Browse the repository at this point in the history
* chore: Refactor terminal handler to use auth-middleware

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* remove context key for now

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* implement unit-tests

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* remove claim valid check for now

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* remove unnecessary test

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* fix lint

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* don't too much details in http response

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* Fix error

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* Fix lint

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* trigger build

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

* builder pattern in terminal feature-flag middleware

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>

Signed-off-by: Leonardo Luz Almeida <leonardo_almeida@intuit.com>
  • Loading branch information
leoluz authored and crenshaw-dev committed Jan 27, 2023
1 parent 68f58c9 commit 8f4678e
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 37 deletions.
21 changes: 21 additions & 0 deletions server/application/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/argoproj/argo-cd/v2/util/db"
"github.com/argoproj/argo-cd/v2/util/rbac"
sessionmgr "github.com/argoproj/argo-cd/v2/util/session"
"github.com/argoproj/argo-cd/v2/util/settings"
)

type terminalHandler struct {
Expand Down Expand Up @@ -91,6 +92,26 @@ func isValidContainerName(name string) bool {
return len(validationErrors) == 0
}

type GetSettingsFunc func() (*settings.ArgoCDSettings, error)

// WithFeatureFlagMiddleware is an HTTP middleware to verify if the terminal
// feature is enabled before invoking the main handler
func (s *terminalHandler) WithFeatureFlagMiddleware(getSettings GetSettingsFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
argocdSettings, err := getSettings()
if err != nil {
log.Errorf("error executing WithFeatureFlagMiddleware: error getting settings: %s", err)
http.Error(w, "Failed to get settings", http.StatusBadRequest)
return
}
if !argocdSettings.ExecEnabled {
w.WriteHeader(http.StatusNotFound)
return
}
s.ServeHTTP(w, r)
})
}

func (s *terminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()

Expand Down
44 changes: 7 additions & 37 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
netCtx "context"
"crypto/tls"
"fmt"
goio "io"
Expand All @@ -24,8 +25,6 @@ import (
// nolint:staticcheck
golang_proto "github.com/golang/protobuf/proto"

netCtx "context"

"github.com/argoproj/pkg/sync"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v4"
Expand Down Expand Up @@ -55,6 +54,8 @@ import (
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/cache"

"github.com/pkg/errors"

"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/apiclient"
accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account"
Expand Down Expand Up @@ -112,7 +113,6 @@ import (
"github.com/argoproj/argo-cd/v2/util/swagger"
tlsutil "github.com/argoproj/argo-cd/v2/util/tls"
"github.com/argoproj/argo-cd/v2/util/webhook"
"github.com/pkg/errors"
)

const maxConcurrentLoginRequestsCountEnv = "ARGOCD_MAX_CONCURRENT_LOGIN_REQUESTS_COUNT"
Expand Down Expand Up @@ -812,40 +812,10 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl
}
mux.Handle("/api/", handler)

terminalHandler := application.NewHandler(a.appLister, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells)
mux.HandleFunc("/terminal", func(writer http.ResponseWriter, request *http.Request) {
argocdSettings, err := a.settingsMgr.GetSettings()
if err != nil {
http.Error(writer, fmt.Sprintf("Failed to get settings: %v", err), http.StatusBadRequest)
return
}
if !argocdSettings.ExecEnabled {
writer.WriteHeader(http.StatusNotFound)
return
}

if !a.DisableAuth {
ctx := request.Context()
cookies := request.Cookies()
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
if err == nil && jwtutil.IsValid(tokenString) {
claims, _, err := a.sessionMgr.VerifyToken(tokenString)
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, util_session.AuthErrorCtxKey, err)
} else if claims != nil {
// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
}
request = request.WithContext(ctx)
} else {
writer.WriteHeader(http.StatusUnauthorized)
return
}
}
terminalHandler.ServeHTTP(writer, request)
})
terminal := application.NewHandler(a.appLister, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells).
WithFeatureFlagMiddleware(a.settingsMgr.GetSettings)
th := util_session.WithAuthMiddleware(a.DisableAuth, a.sessionMgr, terminal)
mux.Handle("/terminal", th)

mustRegisterGWHandler(versionpkg.RegisterVersionServiceHandler, ctx, gwmux, conn)
mustRegisterGWHandler(clusterpkg.RegisterClusterServiceHandler, ctx, gwmux, conn)
Expand Down
41 changes: 41 additions & 0 deletions util/session/sessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,47 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri
return nil
}

// AuthMiddlewareFunc returns a function that can be used as an
// authentication middleware for HTTP requests.
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return WithAuthMiddleware(disabled, mgr, h)
}
}

// TokenVerifier defines the contract to invoke token
// verification logic
type TokenVerifier interface {
VerifyToken(token string) (jwt.Claims, string, error)
}

// WithAuthMiddleware is an HTTP middleware used to ensure incoming
// requests are authenticated before invoking the target handler. If
// disabled is true, it will just invoke the next handler in the chain.
func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !disabled {
cookies := r.Cookies()
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
if err != nil {
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
return
}
claims, _, err := authn.VerifyToken(tokenString)
if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
ctx := r.Context()
// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}

// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
// We choose how to verify based on the issuer.
func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) {
Expand Down
134 changes: 134 additions & 0 deletions util/session/sessionmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package session
import (
"context"
"encoding/pem"
stderrors "errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -221,6 +225,136 @@ func TestSessionManager_ProjectToken(t *testing.T) {
})
}

type claimsMock struct {
err error
}

func (cm *claimsMock) Valid() error {
return cm.err
}

type tokenVerifierMock struct {
claims *claimsMock
err error
}

func (tm *tokenVerifierMock) VerifyToken(token string) (jwt.Claims, string, error) {
if tm.claims == nil {
return nil, "", tm.err
}
return tm.claims, "", tm.err
}

func strPointer(str string) *string {
return &str
}

func TestSessionManager_WithAuthMiddleware(t *testing.T) {
handlerFunc := func() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
t.Helper()
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/text")
_, err := w.Write([]byte("Ok"))
if err != nil {
t.Fatalf("error writing response: %s", err)
}
}
}
type testCase struct {
name string
authDisabled bool
cookieHeader bool
verifiedClaims *claimsMock
verifyTokenErr error
expectedStatusCode int
expectedResponseBody *string
}

cases := []testCase{
{
name: "will authenticate successfully",
authDisabled: false,
cookieHeader: true,
verifiedClaims: &claimsMock{},
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
{
name: "will be noop if auth is disabled",
authDisabled: true,
cookieHeader: false,
verifiedClaims: nil,
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
{
name: "will return 400 if no cookie header",
authDisabled: false,
cookieHeader: false,
verifiedClaims: &claimsMock{},
verifyTokenErr: nil,
expectedStatusCode: 400,
expectedResponseBody: nil,
},
{
name: "will return 401 verify token fails",
authDisabled: false,
cookieHeader: true,
verifiedClaims: &claimsMock{},
verifyTokenErr: stderrors.New("token error"),
expectedStatusCode: 401,
expectedResponseBody: nil,
},
{
name: "will return 200 if claims are nil",
authDisabled: false,
cookieHeader: true,
verifiedClaims: nil,
verifyTokenErr: nil,
expectedStatusCode: 200,
expectedResponseBody: strPointer("Ok"),
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// given
mux := http.NewServeMux()
mux.HandleFunc("/", handlerFunc())
tm := &tokenVerifierMock{
claims: tc.verifiedClaims,
err: tc.verifyTokenErr,
}
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux))
defer ts.Close()
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("error creating request: %s", err)
}
if tc.cookieHeader {
req.Header.Add("Cookie", "argocd.token=123456")
}

// when
resp, err := http.DefaultClient.Do(req)

// then
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
if tc.expectedResponseBody != nil {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
actual := strings.TrimSuffix(string(body), "\n")
assert.Contains(t, actual, *tc.expectedResponseBody)
}
})
}
}

var loggedOutContext = context.Background()

// nolint:staticcheck
Expand Down

0 comments on commit 8f4678e

Please sign in to comment.