From 1bd643f4ebc46e7ac5c1d35326e9b5540c569b82 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 29 Oct 2025 14:27:10 +0300 Subject: [PATCH] feat(oauthserver): add OAuth client admin update endpoint --- internal/api/api.go | 1 + internal/api/oauthserver/handlers.go | 27 +++ internal/api/oauthserver/handlers_test.go | 159 +++++++++++++++ internal/api/oauthserver/service.go | 228 +++++++++++++++++++--- openapi.yaml | 65 ++++++ 5 files changed, 448 insertions(+), 32 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index cc20808cd..85e551604 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -348,6 +348,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Route("/{client_id}", func(r *router) { r.Use(api.oauthServer.LoadOAuthServerClient) r.Get("/", api.oauthServer.OAuthServerClientGet) + r.Put("/", api.oauthServer.OAuthServerClientUpdate) r.Delete("/", api.oauthServer.OAuthServerClientDelete) r.Post("/regenerate_secret", api.oauthServer.OAuthServerClientRegenerateSecret) }) diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 476e01329..d3aed3048 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -177,6 +177,33 @@ func (s *Server) OAuthServerClientGet(w http.ResponseWriter, r *http.Request) er return shared.SendJSON(w, http.StatusOK, response) } +// OAuthServerClientUpdate handles PUT /admin/oauth/clients/{client_id} +func (s *Server) OAuthServerClientUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + client := shared.GetOAuthServerClient(ctx) + + var params OAuthServerClientUpdateParams + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Invalid JSON body") + } + + // Return early if no fields are provided for update + if params.isEmpty() { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "No fields provided for update") + } + + updatedClient, err := s.updateOAuthServerClient(ctx, client.ID, ¶ms) + if err != nil { + if httpErr, ok := err.(*apierrors.HTTPError); ok { + return httpErr + } + return apierrors.NewInternalServerError("Error updating OAuth client").WithInternalError(err) + } + + response := oauthServerClientToResponse(updatedClient) + return shared.SendJSON(w, http.StatusOK, response) +} + // OAuthServerClientDelete handles DELETE /admin/oauth/clients/{client_id} func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index 3ad402f49..69fb61e51 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -251,6 +251,165 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() { } } +func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandler() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Test updating all fields + newRedirectURIs := []string{"https://newapp.example.com/callback"} + newGrantTypes := []string{"authorization_code", "refresh_token"} + newClientName := "Updated Client Name" + newClientURI := "https://newapp.example.com" + newLogoURI := "https://newapp.example.com/logo.png" + + payload := OAuthServerClientUpdateParams{ + RedirectURIs: &newRedirectURIs, + GrantTypes: &newGrantTypes, + ClientName: &newClientName, + ClientURI: &newClientURI, + LogoURI: &newLogoURI, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "/admin/oauth/clients/"+client.ID.String(), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientUpdate(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response OAuthServerClientResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), client.ID.String(), response.ClientID) + assert.Equal(ts.T(), newClientName, response.ClientName) + assert.Equal(ts.T(), newRedirectURIs, response.RedirectURIs) + assert.Equal(ts.T(), newGrantTypes, response.GrantTypes) + assert.Equal(ts.T(), newClientURI, response.ClientURI) + assert.Equal(ts.T(), newLogoURI, response.LogoURI) + assert.Empty(ts.T(), response.ClientSecret) // Should NOT be included in update response +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandlerPartial() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Test updating only client name + newClientName := "Partially Updated Name" + payload := OAuthServerClientUpdateParams{ + ClientName: &newClientName, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "/admin/oauth/clients/"+client.ID.String(), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientUpdate(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response OAuthServerClientResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + // Verify only client name was updated + assert.Equal(ts.T(), newClientName, response.ClientName) + // Verify other fields remained unchanged + assert.Equal(ts.T(), client.GetRedirectURIs(), response.RedirectURIs) +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandlerEmptyBody() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Test with empty body + payload := OAuthServerClientUpdateParams{} + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "/admin/oauth/clients/"+client.ID.String(), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientUpdate(w, req) + require.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "No fields provided for update") +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandlerInvalidValidation() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Test with invalid redirect URI + invalidRedirectURIs := []string{"invalid-uri"} + payload := OAuthServerClientUpdateParams{ + RedirectURIs: &invalidRedirectURIs, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "/admin/oauth/clients/"+client.ID.String(), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientUpdate(w, req) + require.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "invalid redirect_uri") +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandlerSameValues() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Update with same values (should succeed) + currentName := "Test Client" + payload := OAuthServerClientUpdateParams{ + ClientName: ¤tName, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "/admin/oauth/clients/"+client.ID.String(), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientUpdate(w, req) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + func (ts *OAuthClientTestSuite) TestHandlerValidation() { // Test invalid JSON body req := httptest.NewRequest(http.MethodPost, "/admin/oauth/clients", bytes.NewReader([]byte("invalid json"))) diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index 88f1ab8f0..3af95d064 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -18,6 +18,86 @@ import ( "github.com/supabase/auth/internal/utilities" ) +// validateRedirectURIList validates a list of redirect URIs +func validateRedirectURIList(redirectURIs []string, required bool) error { + if required && len(redirectURIs) == 0 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris is required") + } + + if len(redirectURIs) == 0 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris cannot be empty") + } + + if len(redirectURIs) > 10 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris cannot exceed 10 items") + } + + for _, uri := range redirectURIs { + if err := validateRedirectURI(uri); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri '%s': %v", uri, err) + } + } + + return nil +} + +// validateGrantTypeList validates a list of grant types +func validateGrantTypeList(grantTypes []string) error { + if len(grantTypes) == 0 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "grant_types cannot be empty") + } + + for _, grantType := range grantTypes { + if grantType != "authorization_code" && grantType != "refresh_token" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "grant_types must only contain 'authorization_code' and/or 'refresh_token'") + } + } + + return nil +} + +// validateClientName validates a client name +func validateClientName(clientName string) error { + if len(clientName) > 1024 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_name cannot exceed 1024 characters") + } + return nil +} + +// validateClientURI validates a client URI +func validateClientURI(clientURI string) error { + if clientURI == "" { + return nil + } + + if len(clientURI) > 2048 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri cannot exceed 2048 characters") + } + + if _, err := url.ParseRequestURI(clientURI); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri must be a valid URL") + } + + return nil +} + +// validateLogoURI validates a logo URI +func validateLogoURI(logoURI string) error { + if logoURI == "" { + return nil + } + + if len(logoURI) > 2048 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri cannot exceed 2048 characters") + } + + if _, err := url.ParseRequestURI(logoURI); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri must be a valid URL") + } + + return nil +} + // OAuthServerClientRegisterParams contains parameters for registering a new OAuth client type OAuthServerClientRegisterParams struct { // Required fields @@ -38,47 +118,31 @@ type OAuthServerClientRegisterParams struct { // validate validates the OAuth client registration parameters func (p *OAuthServerClientRegisterParams) validate() error { - // Validate redirect URIs (required) - if len(p.RedirectURIs) == 0 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris is required") - } - - if len(p.RedirectURIs) > 10 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris cannot exceed 10 items") - } - - for _, uri := range p.RedirectURIs { - if err := validateRedirectURI(uri); err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri '%s': %v", uri, err) - } + // Validate redirect URIs (required for registration) + if err := validateRedirectURIList(p.RedirectURIs, true); err != nil { + return err } - for _, grantType := range p.GrantTypes { - if grantType != "authorization_code" && grantType != "refresh_token" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "grant_types must only contain 'authorization_code' and/or 'refresh_token'") + // Validate grant types if provided + if len(p.GrantTypes) > 0 { + if err := validateGrantTypeList(p.GrantTypes); err != nil { + return err } } - if len(p.ClientName) > 1024 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_name cannot exceed 1024 characters") + // Validate client name + if err := validateClientName(p.ClientName); err != nil { + return err } - if p.ClientURI != "" { - if len(p.ClientURI) > 2048 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri cannot exceed 2048 characters") - } - if _, err := url.ParseRequestURI(p.ClientURI); err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri must be a valid URL") - } + // Validate client URI + if err := validateClientURI(p.ClientURI); err != nil { + return err } - if p.LogoURI != "" { - if len(p.LogoURI) > 2048 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri cannot exceed 2048 characters") - } - if _, err := url.ParseRequestURI(p.LogoURI); err != nil { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri must be a valid URL") - } + // Validate logo URI + if err := validateLogoURI(p.LogoURI); err != nil { + return err } if p.RegistrationType != "dynamic" && p.RegistrationType != "manual" { @@ -278,3 +342,103 @@ func (s *Server) regenerateOAuthServerClientSecret(ctx context.Context, clientID return client, plaintextSecret, nil } + +// OAuthServerClientUpdateParams contains parameters for updating an OAuth client +type OAuthServerClientUpdateParams struct { + RedirectURIs *[]string `json:"redirect_uris,omitempty"` + GrantTypes *[]string `json:"grant_types,omitempty"` + ClientName *string `json:"client_name,omitempty"` + ClientURI *string `json:"client_uri,omitempty"` + LogoURI *string `json:"logo_uri,omitempty"` +} + +// isEmpty returns true if no fields are set for update +func (p *OAuthServerClientUpdateParams) isEmpty() bool { + return p.RedirectURIs == nil && + p.GrantTypes == nil && + p.ClientName == nil && + p.ClientURI == nil && + p.LogoURI == nil +} + +// validate validates the OAuth client update parameters +func (p *OAuthServerClientUpdateParams) validate() error { + // Validate redirect URIs if provided + if p.RedirectURIs != nil { + if err := validateRedirectURIList(*p.RedirectURIs, false); err != nil { + return err + } + } + + // Validate grant types if provided + if p.GrantTypes != nil { + if err := validateGrantTypeList(*p.GrantTypes); err != nil { + return err + } + } + + // Validate client name if provided + if p.ClientName != nil { + if err := validateClientName(*p.ClientName); err != nil { + return err + } + } + + // Validate client URI if provided + if p.ClientURI != nil { + if err := validateClientURI(*p.ClientURI); err != nil { + return err + } + } + + // Validate logo URI if provided + if p.LogoURI != nil { + if err := validateLogoURI(*p.LogoURI); err != nil { + return err + } + } + + return nil +} + +// updateOAuthServerClient updates an existing OAuth client +func (s *Server) updateOAuthServerClient(ctx context.Context, clientID uuid.UUID, params *OAuthServerClientUpdateParams) (*models.OAuthServerClient, error) { + // Validate all parameters + if err := params.validate(); err != nil { + return nil, err + } + + db := s.db.WithContext(ctx) + + client, err := models.FindOAuthServerClientByID(db, clientID) + if err != nil { + return nil, err + } + + // Update only the provided fields + if params.RedirectURIs != nil { + client.SetRedirectURIs(*params.RedirectURIs) + } + + if params.GrantTypes != nil { + client.SetGrantTypes(*params.GrantTypes) + } + + if params.ClientName != nil { + client.ClientName = utilities.StringPtr(*params.ClientName) + } + + if params.ClientURI != nil { + client.ClientURI = utilities.StringPtr(*params.ClientURI) + } + + if params.LogoURI != nil { + client.LogoURI = utilities.StringPtr(*params.LogoURI) + } + + if err := models.UpdateOAuthServerClient(db, client); err != nil { + return nil, errors.Wrap(err, "failed to update OAuth client") + } + + return client, nil +} diff --git a/openapi.yaml b/openapi.yaml index e7a87115f..390079ebb 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1995,6 +1995,71 @@ paths: $ref: "#/components/responses/UnauthorizedResponse" 403: $ref: "#/components/responses/ForbiddenResponse" + put: + summary: Update OAuth client (admin) + description: > + Updates an existing OAuth client registration. Only the provided fields will be updated. + Only available when OAuth server is enabled. + tags: + - admin + - oauth-server + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + client_name: + type: string + description: Human-readable name of the client application + client_uri: + type: string + format: uri + description: URL of the client application's homepage + logo_uri: + type: string + format: uri + description: URL of the client application's logo + redirect_uris: + type: array + items: + type: string + format: uri + description: Array of redirect URIs used by the client + grant_types: + type: array + items: + type: string + enum: + - authorization_code + - refresh_token + description: OAuth grant types the client is authorized to use + responses: + 200: + description: OAuth client updated successfully + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthClientSchema" + 400: + description: Bad request - validation failed or no fields provided for update + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 404: + description: OAuth client not found + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" delete: summary: Delete OAuth client (admin) description: >