From 2bb3e24db93a3a7a5204ac89dcc75ea280529f12 Mon Sep 17 00:00:00 2001 From: Alex Tugarev Date: Thu, 9 Feb 2023 16:42:38 +0000 Subject: [PATCH] [oidc] assertIssuerIsReachable --- .../public-api-server/pkg/apiv1/oidc.go | 33 +++++++++++ .../public-api-server/pkg/apiv1/oidc_test.go | 58 +++++++++++-------- 2 files changed, 66 insertions(+), 25 deletions(-) diff --git a/components/public-api-server/pkg/apiv1/oidc.go b/components/public-api-server/pkg/apiv1/oidc.go index 9ef052d766d9af..9435e02e37251f 100644 --- a/components/public-api-server/pkg/apiv1/oidc.go +++ b/components/public-api-server/pkg/apiv1/oidc.go @@ -6,8 +6,11 @@ package apiv1 import ( "context" + "crypto/tls" "errors" "fmt" + "net/http" + "time" connect "github.com/bufbuild/connect-go" goidc "github.com/coreos/go-oidc/v3/oidc" @@ -50,6 +53,11 @@ func (s *OIDCService) CreateClientConfig(ctx context.Context, req *connect.Reque return nil, err } + err = assertIssuerIsReachable(req.Msg.GetConfig().GetOidcConfig().GetIssuer()) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + logger := log.WithField("organization_id", organizationID.String()) conn, err := s.getConnection(ctx) @@ -318,3 +326,28 @@ func toDbOIDCSpec(oauth2Config *v1.OAuth2Config, oidcConfig *v1.OIDCConfig) db.O Scopes: append([]string{goidc.ScopeOpenID, "profile", "email"}, oauth2Config.GetScopes()...), } } + +func assertIssuerIsReachable(host string) error { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + Proxy: http.ProxyFromEnvironment, + } + client := &http.Client{ + Transport: tr, + Timeout: 2 * time.Second, + // never follow redirects + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(host) + if err != nil { + return err + } + resp.Body.Close() + if resp.StatusCode > 499 { + return fmt.Errorf("returned status %d", resp.StatusCode) + } + return nil +} diff --git a/components/public-api-server/pkg/apiv1/oidc_test.go b/components/public-api-server/pkg/apiv1/oidc_test.go index 0e46241144a321..65c4c39c5b6983 100644 --- a/components/public-api-server/pkg/apiv1/oidc_test.go +++ b/components/public-api-server/pkg/apiv1/oidc_test.go @@ -48,7 +48,7 @@ func TestOIDCService_CreateClientConfig_FeatureFlagDisabled(t *testing.T) { organizationID := uuid.New() t.Run("returns unauthorized", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled) + serverMock, client, _, issuer := setupOIDCService(t, withOIDCFeatureDisabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil) @@ -56,6 +56,9 @@ func TestOIDCService_CreateClientConfig_FeatureFlagDisabled(t *testing.T) { _, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{ Config: &v1.OIDCClientConfig{ OrganizationId: organizationID.String(), + OidcConfig: &v1.OIDCConfig{ + Issuer: issuer, + }, }, })) require.Error(t, err) @@ -67,10 +70,10 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) { organizationID := uuid.New() t.Run("returns invalid argument when no organisation specified", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, issuer := setupOIDCService(t, withOIDCFeatureEnabled) config := &v1.OIDCClientConfig{ - OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"}, + OidcConfig: &v1.OIDCConfig{Issuer: issuer}, Oauth2Config: &v1.OAuth2Config{ClientId: "test-id", ClientSecret: "test-secret"}, } _, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{ @@ -81,11 +84,11 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) { }) t.Run("returns invalid argument when organisation id is not a uuid", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, issuer := setupOIDCService(t, withOIDCFeatureEnabled) config := &v1.OIDCClientConfig{ OrganizationId: "some-random-id", - OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"}, + OidcConfig: &v1.OIDCConfig{Issuer: issuer}, Oauth2Config: &v1.OAuth2Config{ClientId: "test-id", ClientSecret: "test-secret"}, } _, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{ @@ -96,13 +99,13 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) { }) t.Run("creates oidc client config", func(t *testing.T) { - serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) config := &v1.OIDCClientConfig{ OrganizationId: organizationID.String(), - OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"}, + OidcConfig: &v1.OIDCConfig{Issuer: issuer}, Oauth2Config: &v1.OAuth2Config{ ClientId: "test-id", ClientSecret: "test-secret", @@ -140,7 +143,7 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) { } func TestOIDCService_GetClientConfig_WithFeatureFlagDisabled(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil) @@ -156,7 +159,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagDisabled(t *testing.T) { func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { t.Run("invalid argument when config id missing", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{})) require.Error(t, err) @@ -164,7 +167,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("invalid argument when organization id missing", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{ Id: uuid.NewString(), @@ -174,7 +177,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("not found when record does not exist", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) @@ -187,7 +190,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("retrieves record when it exists", func(t *testing.T) { - serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) @@ -195,6 +198,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { created := dbtest.CreateOIDCClientConfigs(t, dbConn, db.OIDCClientConfig{ OrganizationID: &orgID, + Issuer: issuer, })[0] resp, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{ @@ -213,7 +217,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) { } func TestOIDCService_ListClientConfigs_WithFeatureFlagDisabled(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil) @@ -228,7 +232,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagDisabled(t *testing.T) { func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) { t.Run("invalid argument when organization id missing", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.ListClientConfigs(context.Background(), connect.NewRequest(&v1.ListClientConfigsRequest{})) require.Error(t, err) @@ -236,7 +240,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("invalid argument when organization id is invalid", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.ListClientConfigs(context.Background(), connect.NewRequest(&v1.ListClientConfigsRequest{ OrganizationId: "some-invalid-id", @@ -246,7 +250,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("retrieves configs by organization id", func(t *testing.T) { - serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled) orgA, orgB := uuid.New(), uuid.New() @@ -255,12 +259,15 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) { configs := dbtest.CreateOIDCClientConfigs(t, dbConn, dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{ OrganizationID: &orgA, + Issuer: issuer, }), dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{ OrganizationID: &orgA, + Issuer: issuer, }), dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{ OrganizationID: &orgB, + Issuer: issuer, }), ) @@ -290,7 +297,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) { func TestOIDCService_UpdateClientConfig(t *testing.T) { t.Run("feature flag disabled returns unauthorized", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil) @@ -301,7 +308,7 @@ func TestOIDCService_UpdateClientConfig(t *testing.T) { }) t.Run("feature flag enabled returns unimplemented", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) @@ -313,7 +320,7 @@ func TestOIDCService_UpdateClientConfig(t *testing.T) { func TestOIDCService_DeleteClientConfig_WithFeatureFlagDisabled(t *testing.T) { t.Run("feature flag disabled returns unauthorized", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil) @@ -330,7 +337,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagDisabled(t *testing.T) { func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) { t.Run("invalid argument when ID not specified", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.DeleteClientConfig(context.Background(), connect.NewRequest(&v1.DeleteClientConfigRequest{})) require.Error(t, err) @@ -338,7 +345,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("invalid argument when Organization ID not specified", func(t *testing.T) { - _, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + _, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) _, err := client.DeleteClientConfig(context.Background(), connect.NewRequest(&v1.DeleteClientConfigRequest{ Id: uuid.NewString(), @@ -348,7 +355,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("not found when record does not exist", func(t *testing.T) { - serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled) serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) @@ -361,12 +368,13 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) t.Run("deletes record", func(t *testing.T) { - serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled) + serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled) orgID := uuid.New() created := dbtest.CreateOIDCClientConfigs(t, dbConn, db.OIDCClientConfig{ OrganizationID: &orgID, + Issuer: issuer, })[0] serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil) @@ -380,7 +388,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) { }) } -func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.MockAPIInterface, v1connect.OIDCServiceClient, *gorm.DB) { +func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.MockAPIInterface, v1connect.OIDCServiceClient, *gorm.DB, string) { t.Helper() dbConn := dbtest.ConnectForTests(t) @@ -401,5 +409,5 @@ func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.Moc auth.NewClientInterceptor("auth-token"), )) - return serverMock, client, dbConn + return serverMock, client, dbConn, srv.URL }