From 9f70c7da4387ec0c16d9daa4748ea9f285cfb3a1 Mon Sep 17 00:00:00 2001 From: nobuyo Date: Mon, 30 Jan 2023 22:57:02 +0900 Subject: [PATCH] Add test for skipping approval Signed-off-by: nobuyo --- server/handlers_test.go | 180 ++++++++++++++++++++++++++++++++++++++++ server/server_test.go | 2 +- 2 files changed, 181 insertions(+), 1 deletion(-) diff --git a/server/handlers_test.go b/server/handlers_test.go index fb1a05064f..fdda922b83 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -5,10 +5,12 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "path" + "strings" "testing" "time" @@ -310,3 +312,181 @@ func TestPasswordConnectorDataNotEmpty(t *testing.T) { require.NoError(t, err) require.Equal(t, `{"test": "true"}`, string(newSess.ConnectorData)) } + +func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connID := "mockPw" + authReqID := "test" + expiry := time.Now().Add(100 * time.Second) + resTypes := []string{"code"} + + tests := []struct { + name string + skipApproval bool + authReq storage.AuthRequest + }{ + { + name: "Force approval", + skipApproval: false, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: true, + }, + }, + { + name: "Skip approval by server config", + skipApproval: true, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: true, + }, + }, + { + name: "Skip approval by auth request", + skipApproval: false, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: false, + }, + }, + } + + for _, tc := range tests { + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.SkipApprovalScreen = tc.skipApproval + c.Now = time.Now + }) + defer httpServer.Close() + + sc := storage.Connector{ + ID: connID, + Type: "mockPassword", + Name: "MockPassword", + ResourceVersion: "1", + Config: []byte("{\"username\": \"foo\", \"password\": \"password\"}"), + } + if err := s.storage.CreateConnector(sc); err != nil { + t.Fatalf("create connector: %v", err) + } + if _, err := s.OpenConnector(sc); err != nil { + t.Fatalf("open connector: %v", err) + } + if err := s.storage.CreateAuthRequest(tc.authReq); err != nil { + t.Fatalf("failed to create AuthRequest: %v", err) + } + + rr := httptest.NewRecorder() + + path := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID) + s.handlePasswordLogin(rr, httptest.NewRequest("POST", path, nil)) + + require.Equal(t, 303, rr.Code) + + resp := rr.Result() + cbPath := strings.Split(resp.Header.Get("Location"), "?")[0] + + if tc.skipApproval || !tc.authReq.ForceApprovalPrompt { + require.Equal(t, "/auth/mockPw/cb", cbPath) + } else { + require.Equal(t, "/approval", cbPath) + } + + resp.Body.Close() + } +} + +func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connID := "mock" + authReqID := "test" + expiry := time.Now().Add(100 * time.Second) + resTypes := []string{"code"} + + tests := []struct { + name string + skipApproval bool + authReq storage.AuthRequest + }{ + { + name: "Force approval", + skipApproval: false, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: true, + }, + }, + { + name: "Skip approval by server config", + skipApproval: true, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: true, + }, + }, + { + name: "Skip approval by auth request", + skipApproval: false, + authReq: storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: resTypes, + ForceApprovalPrompt: false, + }, + }, + } + + for _, tc := range tests { + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.SkipApprovalScreen = tc.skipApproval + c.Now = time.Now + }) + defer httpServer.Close() + + if err := s.storage.CreateAuthRequest(tc.authReq); err != nil { + t.Fatalf("failed to create AuthRequest: %v", err) + } + rr := httptest.NewRecorder() + + path := fmt.Sprintf("/callback/%s?state=%s", connID, authReqID) + s.handleConnectorCallback(rr, httptest.NewRequest("GET", path, nil)) + + require.Equal(t, 303, rr.Code) + + resp := rr.Result() + cbPath := strings.Split(resp.Header.Get("Location"), "?")[0] + + if tc.skipApproval || !tc.authReq.ForceApprovalPrompt { + require.Equal(t, "/callback/cb", cbPath) + } else { + require.Equal(t, "/approval", cbPath) + } + + resp.Body.Close() + } +} diff --git a/server/server_test.go b/server/server_test.go index e54e80af56..aa34be8c27 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -98,6 +98,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), HealthChecker: gosundheit.New(), + SkipApprovalScreen: true, // Don't prompt for approval, just immediately redirect with code. } if updateConfig != nil { updateConfig(&config) @@ -118,7 +119,6 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { t.Fatal(err) } - server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. // Default rotation policy if server.refreshTokenPolicy == nil {