Skip to content

[OIDC] Check issuer URL for reachability #16331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions components/public-api-server/pkg/apiv1/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ package apiv1

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"time"

connect "github.com/bufbuild/connect-go"
goidc "github.com/coreos/go-oidc/v3/oidc"
Expand Down Expand Up @@ -50,6 +53,11 @@ func (s *OIDCService) CreateClientConfig(ctx context.Context, req *connect.Reque
return nil, err
}

err = assertIssuerIsReachable(req.Msg.GetConfig().GetOidcConfig().GetIssuer())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to check the URL actually parses as a URL, before we make this request?

How do we make sure this doesn't get used as a way to DDOS third party systems? If I supply a https://my-victim.org, we'd hit those endpoints on behalf of the attacker.

One way to guard would be to rate-limit the create itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for rate-limiting.

if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}

logger := log.WithField("organization_id", organizationID.String())

conn, err := s.getConnection(ctx)
Expand Down Expand Up @@ -318,3 +326,28 @@ func toDbOIDCSpec(oauth2Config *v1.OAuth2Config, oidcConfig *v1.OIDCConfig) db.O
Scopes: append([]string{goidc.ScopeOpenID, "profile", "email"}, oauth2Config.GetScopes()...),
}
}

func assertIssuerIsReachable(host string) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should accept the context and construct the request with context (http.NewRequest().WithContext()). This ensures that if the caller cancels, we also cancel.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, anything crossing a network boundary should accept context, and propagate it.

tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we disabling TLS? If necessary, a comment would help.

Proxy: http.ProxyFromEnvironment,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? The environment this runs is in controlled by us, and we don't have a system level proxy configured.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took this from a k8s reachability probe.

we don't have a system level proxy configured.

Since #12726 we have an option to set it on the installation level.

The environment this runs is in controlled by us

I'm not sure this is the question. If the Org only provides a proxy to reach internal Git, which was several times the case with self-hosted, it might still be relevant.

}
client := &http.Client{
Transport: tr,
Timeout: 2 * time.Second,
// never follow redirects
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
}

resp, err := client.Get(host)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use the HEAD request here. It better communicates that you only want to know if it's reachable, rather than actually needing the payload.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change to Head, I've no preferences.

if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode > 499 {
return fmt.Errorf("returned status %d", resp.StatusCode)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return fmt.Errorf("returned status %d", resp.StatusCode)
return fmt.Errorf("OIDC reachability check returned status %d", resp.StatusCode)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ends up being the public facing error message, so we should communicate as much detail as possible.

}
return nil
}
58 changes: 33 additions & 25 deletions components/public-api-server/pkg/apiv1/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,17 @@ func TestOIDCService_CreateClientConfig_FeatureFlagDisabled(t *testing.T) {
organizationID := uuid.New()

t.Run("returns unauthorized", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled)
serverMock, client, _, issuer := setupOIDCService(t, withOIDCFeatureDisabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil)

_, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{
Config: &v1.OIDCClientConfig{
OrganizationId: organizationID.String(),
OidcConfig: &v1.OIDCConfig{
Issuer: issuer,
},
},
}))
require.Error(t, err)
Expand All @@ -67,10 +70,10 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) {
organizationID := uuid.New()

t.Run("returns invalid argument when no organisation specified", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, issuer := setupOIDCService(t, withOIDCFeatureEnabled)

config := &v1.OIDCClientConfig{
OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"},
OidcConfig: &v1.OIDCConfig{Issuer: issuer},
Oauth2Config: &v1.OAuth2Config{ClientId: "test-id", ClientSecret: "test-secret"},
}
_, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{
Expand All @@ -81,11 +84,11 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) {
})

t.Run("returns invalid argument when organisation id is not a uuid", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, issuer := setupOIDCService(t, withOIDCFeatureEnabled)

config := &v1.OIDCClientConfig{
OrganizationId: "some-random-id",
OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"},
OidcConfig: &v1.OIDCConfig{Issuer: issuer},
Oauth2Config: &v1.OAuth2Config{ClientId: "test-id", ClientSecret: "test-secret"},
}
_, err := client.CreateClientConfig(context.Background(), connect.NewRequest(&v1.CreateClientConfigRequest{
Expand All @@ -96,13 +99,13 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) {
})

t.Run("creates oidc client config", func(t *testing.T) {
serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me quite a while to wrap my head around the fact that we're using our server, which has our API handlers, as the issuer, for the validation check during create.

Perhaps it would make sense to start a second server, and return it as part of setupOIDCService which is the issuer server. That would also give us the ability to test for the case when the check fails, which we're currently not covering AFAIK.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙈 indeed, this was a shortcut to have a known to be reachable host.

💯 will touch on this in follow-up to extract it. As mentioned a while ago, the mocked OIDC services should become proper mocks to get full coverage of the handlers of ours.


serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)

config := &v1.OIDCClientConfig{
OrganizationId: organizationID.String(),
OidcConfig: &v1.OIDCConfig{Issuer: "test-issuer"},
OidcConfig: &v1.OIDCConfig{Issuer: issuer},
Oauth2Config: &v1.OAuth2Config{
ClientId: "test-id",
ClientSecret: "test-secret",
Expand Down Expand Up @@ -140,7 +143,7 @@ func TestOIDCService_CreateClientConfig_FeatureFlagEnabled(t *testing.T) {
}

func TestOIDCService_GetClientConfig_WithFeatureFlagDisabled(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil)
Expand All @@ -156,15 +159,15 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagDisabled(t *testing.T) {
func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) {

t.Run("invalid argument when config id missing", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{}))
require.Error(t, err)
require.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
})

t.Run("invalid argument when organization id missing", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{
Id: uuid.NewString(),
Expand All @@ -174,7 +177,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) {
})

t.Run("not found when record does not exist", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)

Expand All @@ -187,14 +190,15 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) {
})

t.Run("retrieves record when it exists", func(t *testing.T) {
serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)

orgID := uuid.New()

created := dbtest.CreateOIDCClientConfigs(t, dbConn, db.OIDCClientConfig{
OrganizationID: &orgID,
Issuer: issuer,
})[0]

resp, err := client.GetClientConfig(context.Background(), connect.NewRequest(&v1.GetClientConfigRequest{
Expand All @@ -213,7 +217,7 @@ func TestOIDCService_GetClientConfig_WithFeatureFlagEnabled(t *testing.T) {
}

func TestOIDCService_ListClientConfigs_WithFeatureFlagDisabled(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil)
Expand All @@ -228,15 +232,15 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagDisabled(t *testing.T) {
func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) {

t.Run("invalid argument when organization id missing", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.ListClientConfigs(context.Background(), connect.NewRequest(&v1.ListClientConfigsRequest{}))
require.Error(t, err)
require.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
})

t.Run("invalid argument when organization id is invalid", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.ListClientConfigs(context.Background(), connect.NewRequest(&v1.ListClientConfigsRequest{
OrganizationId: "some-invalid-id",
Expand All @@ -246,7 +250,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) {
})

t.Run("retrieves configs by organization id", func(t *testing.T) {
serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled)

orgA, orgB := uuid.New(), uuid.New()

Expand All @@ -255,12 +259,15 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) {
configs := dbtest.CreateOIDCClientConfigs(t, dbConn,
dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{
OrganizationID: &orgA,
Issuer: issuer,
}),
dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{
OrganizationID: &orgA,
Issuer: issuer,
}),
dbtest.NewOIDCClientConfig(t, db.OIDCClientConfig{
OrganizationID: &orgB,
Issuer: issuer,
}),
)

Expand Down Expand Up @@ -290,7 +297,7 @@ func TestOIDCService_ListClientConfigs_WithFeatureFlagEnabled(t *testing.T) {

func TestOIDCService_UpdateClientConfig(t *testing.T) {
t.Run("feature flag disabled returns unauthorized", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil)
Expand All @@ -301,7 +308,7 @@ func TestOIDCService_UpdateClientConfig(t *testing.T) {
})

t.Run("feature flag enabled returns unimplemented", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)

Expand All @@ -313,7 +320,7 @@ func TestOIDCService_UpdateClientConfig(t *testing.T) {

func TestOIDCService_DeleteClientConfig_WithFeatureFlagDisabled(t *testing.T) {
t.Run("feature flag disabled returns unauthorized", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureDisabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureDisabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
serverMock.EXPECT().GetTeams(gomock.Any()).Return(teams, nil)
Expand All @@ -330,15 +337,15 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagDisabled(t *testing.T) {

func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) {
t.Run("invalid argument when ID not specified", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.DeleteClientConfig(context.Background(), connect.NewRequest(&v1.DeleteClientConfigRequest{}))
require.Error(t, err)
require.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(err))
})

t.Run("invalid argument when Organization ID not specified", func(t *testing.T) {
_, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
_, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

_, err := client.DeleteClientConfig(context.Background(), connect.NewRequest(&v1.DeleteClientConfigRequest{
Id: uuid.NewString(),
Expand All @@ -348,7 +355,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) {
})

t.Run("not found when record does not exist", func(t *testing.T) {
serverMock, client, _ := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, _, _ := setupOIDCService(t, withOIDCFeatureEnabled)

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)

Expand All @@ -361,12 +368,13 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) {
})

t.Run("deletes record", func(t *testing.T) {
serverMock, client, dbConn := setupOIDCService(t, withOIDCFeatureEnabled)
serverMock, client, dbConn, issuer := setupOIDCService(t, withOIDCFeatureEnabled)

orgID := uuid.New()

created := dbtest.CreateOIDCClientConfigs(t, dbConn, db.OIDCClientConfig{
OrganizationID: &orgID,
Issuer: issuer,
})[0]

serverMock.EXPECT().GetLoggedInUser(gomock.Any()).Return(user, nil)
Expand All @@ -380,7 +388,7 @@ func TestOIDCService_DeleteClientConfig_WithFeatureFlagEnabled(t *testing.T) {
})
}

func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.MockAPIInterface, v1connect.OIDCServiceClient, *gorm.DB) {
func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.MockAPIInterface, v1connect.OIDCServiceClient, *gorm.DB, string) {
t.Helper()

dbConn := dbtest.ConnectForTests(t)
Expand All @@ -401,5 +409,5 @@ func setupOIDCService(t *testing.T, expClient experiments.Client) (*protocol.Moc
auth.NewClientInterceptor("auth-token"),
))

return serverMock, client, dbConn
return serverMock, client, dbConn, srv.URL
}