diff --git a/internal/api/api.go b/internal/api/api.go index 17ce68d7b..b6e71473e 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -9,6 +9,7 @@ import ( "github.com/sebest/xff" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/oauthserver" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/hookshttp" "github.com/supabase/auth/internal/hooks/hookspgfunc" @@ -35,8 +36,9 @@ type API struct { config *conf.GlobalConfiguration version string - hooksMgr *v0hooks.Manager - hibpClient *hibp.PwnedClient + hooksMgr *v0hooks.Manager + hibpClient *hibp.PwnedClient + oauthServer *oauthserver.Server // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time @@ -80,7 +82,12 @@ func (a *API) deprecationNotices() { // NewAPIWithVersion creates a new REST API using the specified version func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { - api := &API{config: globalConfig, db: db, version: version} + api := &API{ + config: globalConfig, + db: db, + version: version, + oauthServer: oauthserver.NewServer(globalConfig, db), + } for _, o := range opt { o.apply(api) @@ -197,7 +204,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne With(api.verifyCaptcha).Post("/otp", api.Otp) // rate limiting applied in handler - r.With(api.verifyCaptcha).Post("/token", api.Token) + r.With(api.verifyCaptcha).With(api.oauthClientAuth).Post("/token", api.Token) r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) { r.Get("/", api.Verify) @@ -293,6 +300,28 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) }) }) + + // Admin only oauth client management endpoints + r.Route("/oauth", func(r *router) { + r.Route("/clients", func(r *router) { + // Manual client registration + r.Post("/", api.oauthServer.AdminOAuthServerClientRegister) + + r.Get("/", api.oauthServer.OAuthServerClientList) + + r.Route("/{client_id}", func(r *router) { + r.Use(api.oauthServer.LoadOAuthServerClient) + r.Get("/", api.oauthServer.OAuthServerClientGet) + r.Delete("/", api.oauthServer.OAuthServerClientDelete) + }) + }) + }) + }) + + // OAuth Dynamic Client Registration endpoint (public, rate limited) + r.Route("/oauth", func(r *router) { + r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). + Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) }) }) diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index b764de80d..6406028bc 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -95,4 +95,6 @@ const ( ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid" ErrorCodeWeb3ProviderDisabled ErrorCode = "web3_provider_disabled" ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain" + + ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled" ) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index dd4e0450f..c63d2a8ac 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -3,11 +3,10 @@ package api import ( "context" "encoding/json" - "fmt" "net/http" - "github.com/pkg/errors" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/security" @@ -16,14 +15,7 @@ import ( ) func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { - w.Header().Set("Content-Type", "application/json") - b, err := json.Marshal(obj) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) - } - w.WriteHeader(status) - _, err = w.Write(b) - return err + return shared.SendJSON(w, status, obj) } func isAdmin(u *models.User, config *conf.GlobalConfiguration) bool { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index c974c7f91..c55373587 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -14,6 +14,7 @@ import ( chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/oauthserver" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/security" @@ -81,6 +82,41 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } } +// oauthClientAuth optionally authenticates an OAuth client as middleware +// This doesn't fail if no client credentials are provided, but validates them if present +func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + + clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r) + if err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: "+err.Error()) + } + + // If no client credentials provided, continue without client authentication + if clientID == "" { + return ctx, nil + } + + // Validate client credentials + db := a.db.WithContext(ctx) + client, err := models.FindOAuthServerClientByClientID(db, clientID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials") + } + return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err) + } + + // Validate client secret + if !oauthserver.ValidateClientSecret(clientSecret, client.ClientSecretHash) { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials") + } + + // Add authenticated client to context + ctx = oauthserver.WithOAuthServerClient(ctx, client) + return ctx, nil +} + func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { t, err := a.extractBearerToken(req) if err != nil || t == "" { diff --git a/internal/api/oauthserver/auth.go b/internal/api/oauthserver/auth.go new file mode 100644 index 000000000..46f05d5bb --- /dev/null +++ b/internal/api/oauthserver/auth.go @@ -0,0 +1,50 @@ +package oauthserver + +import ( + "encoding/base64" + "errors" + "net/http" + "strings" +) + +// ExtractClientCredentials extracts OAuth client credentials from the request +// Supports both Basic auth header and form body parameters +func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, err error) { + // First, try Basic auth header: Authorization: Basic base64(client_id:client_secret) + authHeader := r.Header.Get("Authorization") + if authHeader != "" && strings.HasPrefix(authHeader, "Basic ") { + encoded := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", "", errors.New("invalid basic auth encoding") + } + + credentials := string(decoded) + parts := strings.SplitN(credentials, ":", 2) + if len(parts) != 2 { + return "", "", errors.New("invalid basic auth format") + } + + return parts[0], parts[1], nil + } + + // Fall back to form parameters + if err := r.ParseForm(); err != nil { + return "", "", errors.New("failed to parse form") + } + + clientID = r.FormValue("client_id") + clientSecret = r.FormValue("client_secret") + + // Return empty credentials if both are empty (no client auth attempted) + if clientID == "" && clientSecret == "" { + return "", "", nil + } + + // If only one is provided, it's an error + if clientID == "" || clientSecret == "" { + return "", "", errors.New("both client_id and client_secret must be provided") + } + + return clientID, clientSecret, nil +} diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go new file mode 100644 index 000000000..27093e80e --- /dev/null +++ b/internal/api/oauthserver/handlers.go @@ -0,0 +1,184 @@ +package oauthserver + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" +) + +// OAuthServerClientResponse represents the response format for OAuth client operations +type OAuthServerClientResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` // only returned on registration + + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientName string `json:"client_name,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` + + // Metadata fields + RegistrationType string `json:"registration_type"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// OAuthServerClientListResponse represents the response for listing OAuth clients +type OAuthServerClientListResponse struct { + Clients []OAuthServerClientResponse `json:"clients"` +} + +// oauthServerClientToResponse converts a model to response format +func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret bool) *OAuthServerClientResponse { + response := &OAuthServerClientResponse{ + ClientID: client.ClientID, + + // OAuth 2.1 DCR fields + RedirectURIs: client.GetRedirectURIs(), + TokenEndpointAuthMethod: []string{"client_secret_basic", "client_secret_post"}, // Both methods are supported + GrantTypes: client.GetGrantTypes(), + ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 + ClientName: client.ClientName.String(), + ClientURI: client.ClientURI.String(), + LogoURI: client.LogoURI.String(), + + // Metadata fields + RegistrationType: client.RegistrationType, + CreatedAt: client.CreatedAt, + UpdatedAt: client.UpdatedAt, + } + + // Only include client_secret during registration + if includeSecret { + // Note: This will be filled in by the handler with the plaintext secret + response.ClientSecret = "" + } + + return response +} + +// LoadOAuthServerClient is middleware that loads an OAuth server client from the URL parameter +func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + clientID := chi.URLParam(r, "client_id") + + if clientID == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") + } + + observability.LogEntrySetField(r, "oauth_client_id", clientID) + + client, err := s.getOAuthServerClient(ctx, clientID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "OAuth client not found") + } + return nil, apierrors.NewInternalServerError("Error loading OAuth client").WithInternalError(err) + } + + ctx = WithOAuthServerClient(ctx, client) + return ctx, nil +} + +// AdminOAuthServerClientRegister handles POST /admin/oauth/clients (manual registration by admins) +func (s *Server) AdminOAuthServerClientRegister(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + var params OAuthServerClientRegisterParams + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Invalid JSON body") + } + + // Force registration type to manual for admin endpoint + params.RegistrationType = "manual" + + client, plaintextSecret, err := s.registerOAuthServerClient(ctx, ¶ms) + if err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + } + + response := oauthServerClientToResponse(client, true) + response.ClientSecret = plaintextSecret + + return shared.SendJSON(w, http.StatusCreated, response) +} + +// OAuthServerClientDynamicRegister handles POST /oauth/register (OAuth 2.1 Dynamic Client Registration) +func (s *Server) OAuthServerClientDynamicRegister(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + // Check if dynamic registration is enabled + if !s.config.OAuthServer.AllowDynamicRegistration { + return apierrors.NewForbiddenError(apierrors.ErrorCodeOAuthDynamicClientRegistrationDisabled, "Dynamic client registration is not enabled") + } + + var params OAuthServerClientRegisterParams + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Invalid JSON body") + } + + params.RegistrationType = "dynamic" + + client, plaintextSecret, err := s.registerOAuthServerClient(ctx, ¶ms) + if err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + } + + response := oauthServerClientToResponse(client, true) + response.ClientSecret = plaintextSecret + + return shared.SendJSON(w, http.StatusCreated, response) +} + +// OAuthServerClientGet handles GET /admin/oauth/clients/{client_id} +func (s *Server) OAuthServerClientGet(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + client := GetOAuthServerClient(ctx) + + response := oauthServerClientToResponse(client, false) + return shared.SendJSON(w, http.StatusOK, response) +} + +// OAuthServerClientDelete handles DELETE /admin/oauth/clients/{client_id} +func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + client := GetOAuthServerClient(ctx) + + if err := s.deleteOAuthServerClient(ctx, client.ClientID); err != nil { + return apierrors.NewInternalServerError("Error deleting OAuth client").WithInternalError(err) + } + + w.WriteHeader(http.StatusNoContent) + return nil +} + +// OAuthServerClientList handles GET /admin/oauth/clients +func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + var clients []models.OAuthServerClient + if err := db.Q().Where("deleted_at is null").Order("created_at desc").All(&clients); err != nil { + return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err) + } + + responses := make([]OAuthServerClientResponse, len(clients)) + for i, client := range clients { + responses[i] = *oauthServerClientToResponse(&client, false) + } + + response := OAuthServerClientListResponse{ + Clients: responses, + } + + return shared.SendJSON(w, http.StatusOK, response) +} diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go new file mode 100644 index 000000000..c72cdb534 --- /dev/null +++ b/internal/api/oauthserver/handlers_test.go @@ -0,0 +1,272 @@ +package oauthserver + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +const oauthServerTestConfig = "../../../hack/test.env" + +type OAuthClientTestSuite struct { + suite.Suite + Server *Server + Config *conf.GlobalConfiguration + DB *storage.Connection +} + +func TestOAuthClientHandler(t *testing.T) { + globalConfig, err := conf.LoadGlobal(oauthServerTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + // Enable OAuth dynamic client registration for tests + globalConfig.OAuthServer.AllowDynamicRegistration = true + + server := NewServer(globalConfig, conn) + + ts := &OAuthClientTestSuite{ + Server: server, + Config: globalConfig, + DB: conn, + } + defer ts.DB.Close() + + suite.Run(t, ts) +} + +func (ts *OAuthClientTestSuite) SetupTest() { + if ts.DB != nil { + models.TruncateAll(ts.DB) + } + // Enable OAuth dynamic client registration for tests + ts.Config.OAuthServer.AllowDynamicRegistration = true +} + +// Helper function to create test OAuth client +func (ts *OAuthClientTestSuite) createTestOAuthClient() (*models.OAuthServerClient, string) { + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback", "http://localhost:3000/callback"}, + RegistrationType: "dynamic", + } + + ctx := context.Background() + client, secret, err := ts.Server.registerOAuthServerClient(ctx, params) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), client) + require.NotEmpty(ts.T(), secret) + + return client, secret +} + +// HTTP Handler Tests +func (ts *OAuthClientTestSuite) TestAdminOAuthServerClientRegisterHandler() { + // Create request payload + payload := OAuthServerClientRegisterParams{ + ClientName: "Test Admin Client", + RedirectURIs: []string{"https://example.com/callback"}, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + // Create HTTP request + req := httptest.NewRequest(http.MethodPost, "/admin/oauth/clients", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // Create response recorder + w := httptest.NewRecorder() + + // Call handler + err = ts.Server.AdminOAuthServerClientRegister(w, req) + require.NoError(ts.T(), err) + + // Check response + assert.Equal(ts.T(), http.StatusCreated, w.Code) + + var response OAuthServerClientResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + assert.NotEmpty(ts.T(), response.ClientID) + assert.NotEmpty(ts.T(), response.ClientSecret) // Should be included in registration response + assert.Equal(ts.T(), "Test Admin Client", response.ClientName) + assert.Equal(ts.T(), []string{"https://example.com/callback"}, response.RedirectURIs) + assert.Equal(ts.T(), "manual", response.RegistrationType) // Admin registration is manual +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientDynamicRegisterHandler() { + payload := OAuthServerClientRegisterParams{ + ClientName: "Test Dynamic Client", + RedirectURIs: []string{"https://app.example.com/callback"}, + ClientURI: "https://app.example.com", + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "/oauth/clients/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + err = ts.Server.OAuthServerClientDynamicRegister(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusCreated, w.Code) + + var response OAuthServerClientResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + assert.NotEmpty(ts.T(), response.ClientID) + assert.NotEmpty(ts.T(), response.ClientSecret) // Should be included in registration response + assert.Equal(ts.T(), "Test Dynamic Client", response.ClientName) + assert.Equal(ts.T(), "https://app.example.com", response.ClientURI) + assert.Equal(ts.T(), "dynamic", response.RegistrationType) // Dynamic registration +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientDynamicRegisterDisabled() { + // Disable dynamic registration + ts.Config.OAuthServer.AllowDynamicRegistration = false + + payload := OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "/oauth/clients/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + // Call handler - should return error + err = ts.Server.OAuthServerClientDynamicRegister(w, req) + require.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "Dynamic client registration is not enabled") +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() { + client, _ := ts.createTestOAuthClient() + + req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ClientID, nil) + + ctx := WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerClientGet(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response OAuthServerClientResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), client.ClientID, response.ClientID) + assert.Empty(ts.T(), response.ClientSecret) // Should NOT be included in get response + assert.Equal(ts.T(), "Test Client", response.ClientName) +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Create HTTP request with client in context + req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ClientID, nil) + + // Add client to context (normally done by LoadOAuthServerClient middleware) + ctx := WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerClientDelete(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusNoContent, w.Code) + assert.Empty(ts.T(), w.Body.String()) + + // Verify client was soft-deleted + deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ClientID) + assert.Error(ts.T(), err) // it was soft-deleted + assert.Nil(ts.T(), deletedClient) +} + +func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() { + // Create a couple test clients first + client1, _ := ts.createTestOAuthClient() + client2, _ := ts.createTestOAuthClient() + + req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients", nil) + + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerClientList(w, req) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response OAuthServerClientListResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(ts.T(), err) + + assert.Len(ts.T(), response.Clients, 2) + + // Check that both clients are in the response (order might vary) + clientIDs := []string{response.Clients[0].ClientID, response.Clients[1].ClientID} + assert.Contains(ts.T(), clientIDs, client1.ClientID) + assert.Contains(ts.T(), clientIDs, client2.ClientID) + + // Verify client secrets are not included in list response + for _, client := range response.Clients { + assert.Empty(ts.T(), client.ClientSecret) + } +} + +func (ts *OAuthClientTestSuite) TestHandlerValidation() { + // Test invalid JSON body + req := httptest.NewRequest(http.MethodPost, "/admin/oauth/clients", bytes.NewReader([]byte("invalid json"))) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + err := ts.Server.AdminOAuthServerClientRegister(w, req) + require.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "Invalid JSON body") + + // Test validation failure + payload := OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"invalid-uri"}, // Invalid URI + } + + body, err := json.Marshal(payload) + require.NoError(ts.T(), err) + + req = httptest.NewRequest(http.MethodPost, "/admin/oauth/clients", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + err = ts.Server.AdminOAuthServerClientRegister(w, req) + require.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "invalid redirect_uri") +} diff --git a/internal/api/oauthserver/server.go b/internal/api/oauthserver/server.go new file mode 100644 index 000000000..fb89510db --- /dev/null +++ b/internal/api/oauthserver/server.go @@ -0,0 +1,20 @@ +package oauthserver + +import ( + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +// Server represents the OAuth 2.1 server functionality +type Server struct { + config *conf.GlobalConfiguration + db *storage.Connection +} + +// NewServer creates a new OAuth server instance +func NewServer(config *conf.GlobalConfiguration, db *storage.Connection) *Server { + return &Server{ + config: config, + db: db, + } +} diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go new file mode 100644 index 000000000..2537f5e1d --- /dev/null +++ b/internal/api/oauthserver/service.go @@ -0,0 +1,216 @@ +package oauthserver + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "golang.org/x/crypto/bcrypt" +) + +// OAuthServerClientRegisterParams contains parameters for registering a new OAuth client +type OAuthServerClientRegisterParams struct { + // Required fields + RedirectURIs []string `json:"redirect_uris"` + + GrantTypes []string `json:"grant_types,omitempty"` + ClientName string `json:"client_name,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` + + // Internal field + RegistrationType string `json:"-"` +} + +// validate validates the OAuth client registration parameters +func (p *OAuthServerClientRegisterParams) validate() error { + // Validate redirect URIs (required) + if len(p.RedirectURIs) == 0 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris is required") + } + + if len(p.RedirectURIs) > 10 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris cannot exceed 10 items") + } + + for _, uri := range p.RedirectURIs { + if err := validateRedirectURI(uri); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri '%s': %v", uri, err) + } + } + + for _, grantType := range p.GrantTypes { + if grantType != "authorization_code" && grantType != "refresh_token" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "grant_types must only contain 'authorization_code' and/or 'refresh_token'") + } + } + + if len(p.ClientName) > 1024 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_name cannot exceed 1024 characters") + } + + if p.ClientURI != "" { + if len(p.ClientURI) > 2048 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri cannot exceed 2048 characters") + } + if _, err := url.ParseRequestURI(p.ClientURI); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_uri must be a valid URL") + } + } + + if p.LogoURI != "" { + if len(p.LogoURI) > 2048 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri cannot exceed 2048 characters") + } + if _, err := url.ParseRequestURI(p.LogoURI); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "logo_uri must be a valid URL") + } + } + + if p.RegistrationType != "dynamic" && p.RegistrationType != "manual" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "registration_type must be 'dynamic' or 'manual'") + } + + return nil +} + +// validateRedirectURI validates OAuth 2.1 redirect URIs +func validateRedirectURI(uri string) error { + if uri == "" { + return fmt.Errorf("redirect URI cannot be empty") + } + + parsedURL, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("invalid URL format") + } + + // Must have scheme and host + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return fmt.Errorf("must have scheme and host") + } + + // Check scheme requirements + if parsedURL.Scheme == "http" { + // HTTP only allowed for localhost + host := parsedURL.Hostname() + if host != "localhost" && host != "127.0.0.1" { + return fmt.Errorf("HTTP scheme only allowed for localhost") + } + } else if parsedURL.Scheme != "https" { + return fmt.Errorf("scheme must be HTTPS or HTTP (localhost only)") + } + + // Must not have fragment + if parsedURL.Fragment != "" { + return fmt.Errorf("fragment not allowed in redirect URI") + } + + return nil +} + +// generateClientID generates a URL-safe random client ID +func generateClientID() string { + // Generate a 32-character alphanumeric client ID + return crypto.SecureAlphanumeric(32) +} + +// generateClientSecret generates a secure random client secret +func generateClientSecret() string { + // Generate a 64-character secure random secret + return crypto.SecureAlphanumeric(64) +} + +// hashClientSecret hashes a client secret using bcrypt +func hashClientSecret(secret string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + if err != nil { + return "", errors.Wrap(err, "failed to hash client secret") + } + return string(hash), nil +} + +// ValidateClientSecret validates a client secret against its hash +func ValidateClientSecret(secret, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) + return err == nil +} + +// registerOAuthServerClient creates a new OAuth server client with generated credentials +func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthServerClientRegisterParams) (*models.OAuthServerClient, string, error) { + // Validate all parameters + if err := params.validate(); err != nil { + return nil, "", err + } + + // Set defaults + grantTypes := params.GrantTypes + if len(grantTypes) == 0 { + grantTypes = []string{"authorization_code", "refresh_token"} + } + + db := s.db.WithContext(ctx) + + client := &models.OAuthServerClient{ + ClientID: generateClientID(), + RegistrationType: params.RegistrationType, + ClientName: storage.NullString(params.ClientName), + ClientURI: storage.NullString(params.ClientURI), + LogoURI: storage.NullString(params.LogoURI), + } + + client.SetRedirectURIs(params.RedirectURIs) + client.SetGrantTypes(grantTypes) + + // Generate client secret for all clients + plaintextSecret := generateClientSecret() + hash, err := hashClientSecret(plaintextSecret) + if err != nil { + return nil, "", errors.Wrap(err, "failed to hash client secret") + } + client.ClientSecretHash = hash + + if err := models.CreateOAuthServerClient(db, client); err != nil { + return nil, "", errors.Wrap(err, "failed to create OAuth client") + } + + return client, plaintextSecret, nil +} + +// getOAuthServerClient retrieves an OAuth client by client_id +func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*models.OAuthServerClient, error) { + db := s.db.WithContext(ctx) + + client, err := models.FindOAuthServerClientByClientID(db, clientID) + if err != nil { + return nil, err + } + + return client, nil +} + +// deleteOAuthServerClient soft-deletes an OAuth client +func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID string) error { + db := s.db.WithContext(ctx) + + client, err := models.FindOAuthServerClientByClientID(db, clientID) + if err != nil { + return err + } + + // Soft delete by setting deleted_at + now := time.Now() + client.DeletedAt = &now + + if err := models.UpdateOAuthServerClient(db, client); err != nil { + return errors.Wrap(err, "failed to delete OAuth client") + } + + return nil +} diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go new file mode 100644 index 000000000..3c7f1f365 --- /dev/null +++ b/internal/api/oauthserver/service_test.go @@ -0,0 +1,330 @@ +package oauthserver + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +const serviceTestConfig = "../../../hack/test.env" + +// OAuthServiceTestSuite runs tests for OAuth service layer functionality +type OAuthServiceTestSuite struct { + suite.Suite + Server *Server + Config *conf.GlobalConfiguration + DB *storage.Connection +} + +func TestOAuthService(t *testing.T) { + globalConfig, err := conf.LoadGlobal(serviceTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + // Enable OAuth dynamic client registration for tests + globalConfig.OAuthServer.AllowDynamicRegistration = true + + server := NewServer(globalConfig, conn) + + ts := &OAuthServiceTestSuite{ + Server: server, + Config: globalConfig, + DB: conn, + } + defer ts.DB.Close() + + suite.Run(t, ts) +} + +func (ts *OAuthServiceTestSuite) SetupTest() { + if ts.DB != nil { + models.TruncateAll(ts.DB) + } + // Enable OAuth dynamic client registration for tests + ts.Config.OAuthServer.AllowDynamicRegistration = true +} + +// Helper function to create test OAuth client +func (ts *OAuthServiceTestSuite) createTestOAuthClient() (*models.OAuthServerClient, string) { + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback", "http://localhost:3000/callback"}, + RegistrationType: "dynamic", + } + + ctx := context.Background() + client, secret, err := ts.Server.registerOAuthServerClient(ctx, params) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), client) + require.NotEmpty(ts.T(), secret) + + return client, secret +} + +// Test the OAuth client service functions directly +func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() { + + // Test registerOAuthServerClient + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + RegistrationType: "dynamic", + } + + ctx := context.Background() + client, secret, err := ts.Server.registerOAuthServerClient(ctx, params) + + require.NoError(ts.T(), err) + require.NotNil(ts.T(), client) + require.NotEmpty(ts.T(), secret) + assert.Equal(ts.T(), "Test Client", client.ClientName.String()) + assert.Equal(ts.T(), "dynamic", client.RegistrationType) + + // Test getOAuthServerClient + retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), client.ClientID, retrievedClient.ClientID) + +} + +func (ts *OAuthServiceTestSuite) TestHashClientSecret() { + secret := "test-secret-123" + + hash, err := hashClientSecret(secret) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), hash) + assert.NotEqual(ts.T(), secret, hash) // Should be hashed, not plaintext + + // Test validation + isValid := ValidateClientSecret(secret, hash) + assert.True(ts.T(), isValid) + + isInvalid := ValidateClientSecret("wrong-secret", hash) + assert.False(ts.T(), isInvalid) +} + +func (ts *OAuthServiceTestSuite) TestClientAuthentication() { + + client, secret := ts.createTestOAuthClient() + + // Test valid client credentials + valid := ValidateClientSecret(secret, client.ClientSecretHash) + assert.True(ts.T(), valid) + + // Test invalid client credentials + invalid := ValidateClientSecret("wrong-secret", client.ClientSecretHash) + assert.False(ts.T(), invalid) +} + +func (ts *OAuthServiceTestSuite) TestDeleteOAuthServerClient() { + // Create a test client first + client, _ := ts.createTestOAuthClient() + + // Delete the client + ctx := context.Background() + err := ts.Server.deleteOAuthServerClient(ctx, client.ClientID) + require.NoError(ts.T(), err) + + // Verify client was soft-deleted + deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID) + assert.Error(ts.T(), err) // it was soft-deleted + assert.Nil(ts.T(), deletedClient) +} + +func (ts *OAuthServiceTestSuite) TestValidationEdgeCases() { + // Test empty redirect URIs + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{}, // Empty array + RegistrationType: "dynamic", + } + + ctx := context.Background() + _, _, err := ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "redirect_uris is required") + + // Test invalid redirect URI + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"invalid-uri"}, // Invalid URI + RegistrationType: "dynamic", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "invalid redirect_uri") + + // Test too many redirect URIs + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: make([]string, 11), // Too many URIs + RegistrationType: "dynamic", + } + + // Fill with valid URIs + for i := 0; i < 11; i++ { + params.RedirectURIs[i] = "https://example.com/callback" + string(rune('0'+i)) + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "redirect_uris cannot exceed 10 items") + + // Test invalid grant type + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + GrantTypes: []string{"invalid_grant_type"}, + RegistrationType: "dynamic", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "grant_types must only contain 'authorization_code' and/or 'refresh_token'") + + // Test client name too long + params = &OAuthServerClientRegisterParams{ + ClientName: string(make([]byte, 1025)), // Too long + RedirectURIs: []string{"https://example.com/callback"}, + RegistrationType: "dynamic", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "client_name cannot exceed 1024 characters") + + // Test invalid client URI + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + ClientURI: "not-a-valid-url", + RegistrationType: "dynamic", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "client_uri must be a valid URL") + + // Test invalid logo URI + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + LogoURI: "not-a-valid-url", + RegistrationType: "dynamic", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "logo_uri must be a valid URL") + + // Test invalid registration type + params = &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + RegistrationType: "invalid", + } + + _, _, err = ts.Server.registerOAuthServerClient(ctx, params) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "registration_type must be 'dynamic' or 'manual'") +} + +func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() { + testCases := []struct { + name string + uri string + shouldError bool + errorMsg string + }{ + { + name: "Valid HTTPS URI", + uri: "https://example.com/callback", + shouldError: false, + }, + { + name: "Valid localhost HTTP URI", + uri: "http://localhost:3000/callback", + shouldError: false, + }, + { + name: "Valid 127.0.0.1 HTTP URI", + uri: "http://127.0.0.1:8080/callback", + shouldError: false, + }, + { + name: "Invalid empty URI", + uri: "", + shouldError: true, + errorMsg: "redirect URI cannot be empty", + }, + { + name: "Invalid scheme", + uri: "ftp://example.com/callback", + shouldError: true, + errorMsg: "scheme must be HTTPS or HTTP (localhost only)", + }, + { + name: "Invalid HTTP non-localhost", + uri: "http://example.com/callback", + shouldError: true, + errorMsg: "HTTP scheme only allowed for localhost", + }, + { + name: "Invalid URI with fragment", + uri: "https://example.com/callback#fragment", + shouldError: true, + errorMsg: "fragment not allowed in redirect URI", + }, + { + name: "Invalid URI format", + uri: "not-a-uri", + shouldError: true, + errorMsg: "must have scheme and host", + }, + } + + for _, tc := range testCases { + ts.T().Run(tc.name, func(t *testing.T) { + err := validateRedirectURI(tc.uri) + if tc.shouldError { + assert.Error(t, err) + if tc.errorMsg != "" { + assert.Contains(t, err.Error(), tc.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func (ts *OAuthServiceTestSuite) TestGrantTypeDefaults() { + + // Test that default grant types are set when none provided + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Client", + RedirectURIs: []string{"https://example.com/callback"}, + RegistrationType: "dynamic", + // GrantTypes not specified + } + + ctx := context.Background() + client, _, err := ts.Server.registerOAuthServerClient(ctx, params) + require.NoError(ts.T(), err) + + // Should have default grant types + grantTypes := client.GetGrantTypes() + assert.Contains(ts.T(), grantTypes, "authorization_code") + assert.Contains(ts.T(), grantTypes, "refresh_token") + assert.Len(ts.T(), grantTypes, 2) +} diff --git a/internal/api/oauthserver/utils.go b/internal/api/oauthserver/utils.go new file mode 100644 index 000000000..9c1833d66 --- /dev/null +++ b/internal/api/oauthserver/utils.go @@ -0,0 +1,28 @@ +package oauthserver + +import ( + "context" + + "github.com/supabase/auth/internal/models" +) + +// Context keys for OAuth server functionality +type contextKey string + +const ( + oauthServerClientKey contextKey = "oauth_server_client" +) + +// WithOAuthServerClient adds an OAuth server client to the context +func WithOAuthServerClient(ctx context.Context, client *models.OAuthServerClient) context.Context { + return context.WithValue(ctx, oauthServerClientKey, client) +} + +// GetOAuthServerClient retrieves an OAuth server client from the context +func GetOAuthServerClient(ctx context.Context) *models.OAuthServerClient { + obj := ctx.Value(oauthServerClientKey) + if obj == nil { + return nil + } + return obj.(*models.OAuthServerClient) +} diff --git a/internal/api/options.go b/internal/api/options.go index 13663152f..d2e281934 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -17,20 +17,21 @@ type LimiterOptions struct { Email ratelimit.Limiter Phone ratelimit.Limiter - Signups *limiter.Limiter - AnonymousSignIns *limiter.Limiter - Recover *limiter.Limiter - Resend *limiter.Limiter - MagicLink *limiter.Limiter - Otp *limiter.Limiter - Token *limiter.Limiter - Verify *limiter.Limiter - User *limiter.Limiter - FactorVerify *limiter.Limiter - FactorChallenge *limiter.Limiter - SSO *limiter.Limiter - SAMLAssertion *limiter.Limiter - Web3 *limiter.Limiter + Signups *limiter.Limiter + AnonymousSignIns *limiter.Limiter + Recover *limiter.Limiter + Resend *limiter.Limiter + MagicLink *limiter.Limiter + Otp *limiter.Limiter + Token *limiter.Limiter + Verify *limiter.Limiter + User *limiter.Limiter + FactorVerify *limiter.Limiter + FactorChallenge *limiter.Limiter + SSO *limiter.Limiter + SAMLAssertion *limiter.Limiter + Web3 *limiter.Limiter + OAuthClientRegister *limiter.Limiter } func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } @@ -96,6 +97,9 @@ func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp) o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp) o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp) + + o.OAuthClientRegister = newLimiterPer5mOver1h(gc.RateLimitOAuthDynamicClientRegister) + return o } diff --git a/internal/api/shared/http.go b/internal/api/shared/http.go new file mode 100644 index 000000000..4eec7cc82 --- /dev/null +++ b/internal/api/shared/http.go @@ -0,0 +1,21 @@ +package shared + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" +) + +// SendJSON sends a JSON response with proper error handling +func SendJSON(w http.ResponseWriter, status int, obj interface{}) error { + w.Header().Set("Content-Type", "application/json") + b, err := json.Marshal(obj) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) + } + w.WriteHeader(status) + _, err = w.Write(b) + return err +} diff --git a/internal/api/token.go b/internal/api/token.go index eb33b6ad8..aa54142d6 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -32,6 +32,8 @@ type AccessTokenClaims struct { AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` SessionId string `json:"session_id,omitempty"` IsAnonymous bool `json:"is_anonymous"` + // TODO(cemalkilic) : client_id claim will be added later + // ClientId string `json:"client_id,omitempty"` } // AccessTokenResponse represents an OAuth2 success response diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 14d2d5d69..8aff15f91 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -70,6 +70,11 @@ type OAuthProviderConfiguration struct { SkipNonceCheck bool `json:"skip_nonce_check" split_words:"true"` } +// OAuthServerConfiguration holds OAuth server configuration +type OAuthServerConfiguration struct { + AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` +} + type AnonymousProviderConfiguration struct { Enabled bool `json:"enabled" default:"false"` } @@ -253,23 +258,25 @@ type GlobalConfiguration struct { API APIConfiguration DB DBConfiguration External ProviderConfiguration - Logging LoggingConfig `envconfig:"LOG"` - Profiler ProfilerConfig `envconfig:"PROFILER"` - OperatorToken string `split_words:"true" required:"false"` + OAuthServer OAuthServerConfiguration `envconfig:"OAUTH_SERVER"` + Logging LoggingConfig `envconfig:"LOG"` + Profiler ProfilerConfig `envconfig:"PROFILER"` + OperatorToken string `split_words:"true" required:"false"` Tracing TracingConfig Metrics MetricsConfig SMTP SMTPConfiguration AuditLog AuditLogConfiguration `split_words:"true"` - RateLimitHeader string `split_words:"true"` - RateLimitEmailSent Rate `split_words:"true" default:"30"` - RateLimitSmsSent Rate `split_words:"true" default:"30"` - RateLimitVerify float64 `split_words:"true" default:"30"` - RateLimitTokenRefresh float64 `split_words:"true" default:"150"` - RateLimitSso float64 `split_words:"true" default:"30"` - RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` - RateLimitOtp float64 `split_words:"true" default:"30"` - RateLimitWeb3 float64 `split_words:"true" default:"30"` + RateLimitHeader string `split_words:"true"` + RateLimitEmailSent Rate `split_words:"true" default:"30"` + RateLimitSmsSent Rate `split_words:"true" default:"30"` + RateLimitVerify float64 `split_words:"true" default:"30"` + RateLimitTokenRefresh float64 `split_words:"true" default:"150"` + RateLimitSso float64 `split_words:"true" default:"30"` + RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` + RateLimitOtp float64 `split_words:"true" default:"30"` + RateLimitWeb3 float64 `split_words:"true" default:"30"` + RateLimitOAuthDynamicClientRegister float64 `split_words:"true" default:"10"` SiteURL string `json:"site_url" split_words:"true" required:"true"` URIAllowList []string `json:"uri_allow_list" split_words:"true"` diff --git a/internal/models/connection.go b/internal/models/connection.go index 80acccc57..82a5e8775 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -49,6 +49,7 @@ func TruncateAll(conn *storage.Connection) error { (&pop.Model{Value: SAMLRelayState{}}).TableName(), (&pop.Model{Value: FlowState{}}).TableName(), (&pop.Model{Value: OneTimeToken{}}).TableName(), + (&pop.Model{Value: OAuthServerClient{}}).TableName(), } for _, tableName := range tables { diff --git a/internal/models/errors.go b/internal/models/errors.go index 96f831969..025779c4e 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -27,6 +27,8 @@ func IsNotFoundError(err error) bool { return true case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: return true + case OAuthServerClientNotFoundError, *OAuthServerClientNotFoundError: + return true } return false } diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go new file mode 100644 index 000000000..7c8776377 --- /dev/null +++ b/internal/models/oauth_client.go @@ -0,0 +1,184 @@ +package models + +import ( + "database/sql" + "fmt" + "net/url" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +// OAuthServerClient represents an OAuth client application registered with this OAuth server +type OAuthServerClient struct { + ID uuid.UUID `json:"-" db:"id"` + ClientID string `json:"client_id" db:"client_id"` + ClientSecretHash string `json:"-" db:"client_secret_hash"` + RegistrationType string `json:"registration_type" db:"registration_type"` + + RedirectURIs string `json:"-" db:"redirect_uris"` + GrantTypes string `json:"grant_types" db:"grant_types"` + ClientName storage.NullString `json:"client_name" db:"client_name"` + ClientURI storage.NullString `json:"client_uri" db:"client_uri"` + LogoURI storage.NullString `json:"logo_uri" db:"logo_uri"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` +} + +// TableName returns the table name for the OAuthServerClient model +func (OAuthServerClient) TableName() string { + return "oauth_clients" +} + +// BeforeSave is invoked before the OAuth client is saved to the database +func (c *OAuthServerClient) BeforeSave(tx *pop.Connection) error { + c.UpdatedAt = time.Now() + return nil +} + +// Validate performs basic validation on the OAuth client +func (c *OAuthServerClient) Validate() error { + if c.ClientID == "" { + return fmt.Errorf("client_id is required") + } + + if c.RegistrationType != "dynamic" && c.RegistrationType != "manual" { + return fmt.Errorf("registration_type must be 'dynamic' or 'manual'") + } + + if c.RedirectURIs == "" { + return fmt.Errorf("at least one redirect_uri is required") + } + + return nil +} + +// GetRedirectURIs returns the redirect URIs as a slice +func (c *OAuthServerClient) GetRedirectURIs() []string { + if c.RedirectURIs == "" { + return []string{} + } + return strings.Split(c.RedirectURIs, ",") +} + +// SetRedirectURIs sets the redirect URIs from a slice +func (c *OAuthServerClient) SetRedirectURIs(uris []string) { + c.RedirectURIs = strings.Join(uris, ",") +} + +// GetGrantTypes returns the grant types as a slice +func (c *OAuthServerClient) GetGrantTypes() []string { + if c.GrantTypes == "" { + return []string{} + } + return strings.Split(c.GrantTypes, ",") +} + +// SetGrantTypes sets the grant types from a slice +func (c *OAuthServerClient) SetGrantTypes(types []string) { + c.GrantTypes = strings.Join(types, ",") +} + +// validateRedirectURI validates a single redirect URI according to OAuth 2.1 spec +func validateRedirectURI(uri string) error { + if uri == "" { + return fmt.Errorf("redirect URI cannot be empty") + } + + parsedURL, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + if parsedURL.Scheme == "" { + return fmt.Errorf("redirect URI must be absolute (include scheme)") + } + + if parsedURL.Fragment != "" { + return fmt.Errorf("redirect URI must not contain fragment") + } + + // Allow localhost for development, otherwise require HTTPS + if parsedURL.Scheme == "http" { + if parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1" { + return fmt.Errorf("redirect URI must use HTTPS except for localhost") + } + } else if parsedURL.Scheme != "https" { + return fmt.Errorf("redirect URI must use HTTPS or HTTP for localhost") + } + + return nil +} + +// Error types for OAuth client operations +type OAuthServerClientNotFoundError struct{} + +func (e OAuthServerClientNotFoundError) Error() string { + return "OAuth client not found" +} + +type InvalidRedirectURIError struct { + URI string +} + +func (e InvalidRedirectURIError) Error() string { + return fmt.Sprintf("invalid redirect URI: %s", e.URI) +} + +// Query functions for OAuth clients + +// FindOAuthServerClientByID finds an OAuth client by ID +func FindOAuthServerClientByID(tx *storage.Connection, id uuid.UUID) (*OAuthServerClient, error) { + client := &OAuthServerClient{} + if err := tx.Q().Where("id = ? AND deleted_at IS NULL", id).First(client); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerClientNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth client") + } + return client, nil +} + +// FindOAuthServerClientByClientID finds an OAuth client by client_id +func FindOAuthServerClientByClientID(tx *storage.Connection, clientID string) (*OAuthServerClient, error) { + client := &OAuthServerClient{} + if err := tx.Q().Where("client_id = ? AND deleted_at IS NULL", clientID).First(client); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerClientNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth client") + } + return client, nil +} + +// CreateOAuthServerClient creates a new OAuth client in the database +func CreateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) error { + if err := client.Validate(); err != nil { + return err + } + + if client.ID == uuid.Nil { + client.ID = uuid.Must(uuid.NewV4()) + } + + now := time.Now() + client.CreatedAt = now + client.UpdatedAt = now + + return tx.Create(client) +} + +// UpdateOAuthServerClient updates an existing OAuth client in the database +func UpdateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) error { + if err := client.Validate(); err != nil { + return err + } + + client.UpdatedAt = time.Now() + return tx.Update(client) +} diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go new file mode 100644 index 000000000..1a2607f6c --- /dev/null +++ b/internal/models/oauth_client_test.go @@ -0,0 +1,292 @@ +package models + +import ( + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" + "golang.org/x/crypto/bcrypt" +) + +type OAuthServerClientTestSuite struct { + suite.Suite + db *storage.Connection +} + +func (ts *OAuthServerClientTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestOAuthServerClient(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &OAuthServerClientTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { + validClient := &OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientID: "test_client_id", + ClientName: storage.NullString("Test Client"), + RegistrationType: "dynamic", + RedirectURIs: "https://example.com/callback", + GrantTypes: "authorization_code,refresh_token", + } + + // Test valid client + err := validClient.Validate() + assert.NoError(ts.T(), err) + + // Test missing client_id + invalidClient := *validClient + invalidClient.ClientID = "" + err = invalidClient.Validate() + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "client_id is required") + + // Test missing client_id + invalidClient = *validClient + invalidClient.ClientID = "" + err = invalidClient.Validate() + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "client_id is required") + + // Test invalid registration type + invalidClient = *validClient + invalidClient.RegistrationType = "invalid" + err = invalidClient.Validate() + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "registration_type must be 'dynamic' or 'manual'") + + // Test missing redirect URIs + invalidClient = *validClient + invalidClient.RedirectURIs = "" + err = invalidClient.Validate() + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "at least one redirect_uri is required") +} + +func (ts *OAuthServerClientTestSuite) TestRedirectURIValidation() { + validURIs := []string{ + "https://example.com/callback", + "https://app.example.com/auth/callback", + "http://localhost:3000/callback", + "http://127.0.0.1:8080/auth", + } + + invalidURIs := []string{ + "", // empty + "not-a-url", // not a URL + "example.com/callback", // missing scheme + "ftp://example.com/callback", // invalid scheme + "https://example.com/callback#hash", // has fragment + "http://example.com/callback", // HTTP for non-localhost + } + + // Test valid URIs + for _, uri := range validURIs { + err := validateRedirectURI(uri) + assert.NoError(ts.T(), err, "URI should be valid: %s", uri) + } + + // Test invalid URIs + for _, uri := range invalidURIs { + err := validateRedirectURI(uri) + assert.Error(ts.T(), err, "URI should be invalid: %s", uri) + } +} + +func (ts *OAuthServerClientTestSuite) TestRedirectURIHelpers() { + client := &OAuthServerClient{} + + // Test setting and getting redirect URIs + uris := []string{ + "https://example.com/callback", + "https://app.example.com/auth", + "http://localhost:3000/callback", + } + + client.SetRedirectURIs(uris) + assert.Equal(ts.T(), "https://example.com/callback,https://app.example.com/auth,http://localhost:3000/callback", client.RedirectURIs) + + retrievedURIs := client.GetRedirectURIs() + assert.Equal(ts.T(), uris, retrievedURIs) + + // Test empty URIs + client.SetRedirectURIs([]string{}) + assert.Equal(ts.T(), "", client.RedirectURIs) + + retrievedURIs = client.GetRedirectURIs() + assert.Equal(ts.T(), []string{}, retrievedURIs) + + // Test getting URIs from empty string + client.RedirectURIs = "" + retrievedURIs = client.GetRedirectURIs() + assert.Equal(ts.T(), []string{}, retrievedURIs) +} + +func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { + client := &OAuthServerClient{ + ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], + ClientName: storage.NullString("Test Application"), + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + RedirectURIs: "https://example.com/callback", + } + + err := CreateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Verify client was created with generated ID and timestamps + assert.NotEqual(ts.T(), uuid.Nil, client.ID) + assert.NotZero(ts.T(), client.CreatedAt) + assert.NotZero(ts.T(), client.UpdatedAt) +} + +func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { + invalidClient := &OAuthServerClient{ + ClientID: "", // Missing required field + } + + err := CreateOAuthServerClient(ts.db, invalidClient) + assert.Error(ts.T(), err) + assert.Contains(ts.T(), err.Error(), "client_id is required") +} + +func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { + // Create a test client + client := &OAuthServerClient{ + ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], + ClientName: storage.NullString("Find By ID Test"), + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + RedirectURIs: "https://example.com/callback", + } + + err := CreateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Find by ID + foundClient, err := FindOAuthServerClientByID(ts.db, client.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), client.ClientID, foundClient.ClientID) + assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + + // Test not found + _, err = FindOAuthServerClientByID(ts.db, uuid.Must(uuid.NewV4())) + assert.Error(ts.T(), err) + assert.True(ts.T(), IsNotFoundError(err)) +} + +func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { + // Create a test client + client := &OAuthServerClient{ + ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], + ClientName: storage.NullString("Find By Client ID Test"), + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "manual", + RedirectURIs: "https://example.com/callback", + } + + err := CreateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Find by client_id + foundClient, err := FindOAuthServerClientByClientID(ts.db, client.ClientID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), client.ID, foundClient.ID) + assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + + // Test not found + _, err = FindOAuthServerClientByClientID(ts.db, "nonexistent_client_id") + assert.Error(ts.T(), err) + assert.True(ts.T(), IsNotFoundError(err)) +} + +func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { + // Create a test client + client := &OAuthServerClient{ + ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], + ClientName: storage.NullString("Original Name"), + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + RedirectURIs: "https://example.com/callback", + } + + err := CreateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + originalUpdatedAt := client.UpdatedAt + + // Update the client + client.ClientName = storage.NullString("Updated Name") + // client.Description removed - no longer exists + client.SetRedirectURIs([]string{"https://updated.example.com/callback"}) + + err = UpdateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Verify updates + updatedClient, err := FindOAuthServerClientByID(ts.db, client.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), "Updated Name", updatedClient.ClientName.String()) + // assert.Equal(ts.T(), "Updated description", updatedClient.Description.String()) // Description field removed + assert.Equal(ts.T(), []string{"https://updated.example.com/callback"}, updatedClient.GetRedirectURIs()) + assert.True(ts.T(), updatedClient.UpdatedAt.After(originalUpdatedAt)) +} + +func (ts *OAuthServerClientTestSuite) TestClientSecretHashing() { + // Test that secrets can be properly hashed and validated + secret := "super_secret_client_secret" + + hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + require.NoError(ts.T(), err) + + // Test correct secret validates + err = bcrypt.CompareHashAndPassword(hash, []byte(secret)) + assert.NoError(ts.T(), err) + + // Test incorrect secret fails + err = bcrypt.CompareHashAndPassword(hash, []byte("wrong_secret")) + assert.Error(ts.T(), err) +} + +func (ts *OAuthServerClientTestSuite) TestSoftDelete() { + // Create a test client + client := &OAuthServerClient{ + ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], + ClientName: storage.NullString("Soft Delete Test"), + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + RedirectURIs: "https://example.com/callback", + } + + err := CreateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Soft delete by setting deleted_at + now := time.Now() + client.DeletedAt = &now + + err = UpdateOAuthServerClient(ts.db, client) + require.NoError(ts.T(), err) + + // Verify client is not found in normal queries (which filter out deleted) + _, err = FindOAuthServerClientByClientID(ts.db, client.ClientID) + assert.Error(ts.T(), err) + assert.True(ts.T(), IsNotFoundError(err)) +} diff --git a/migrations/20250731150234_add_oauth_clients_table.up.sql b/migrations/20250731150234_add_oauth_clients_table.up.sql new file mode 100644 index 000000000..9abf69a49 --- /dev/null +++ b/migrations/20250731150234_add_oauth_clients_table.up.sql @@ -0,0 +1,34 @@ +-- Create enums for OAuth client fields +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_registration_type as enum('dynamic', 'manual'); +exception + when duplicate_object then null; +end $$; + +-- Create oauth_clients table for OAuth client management +create table if not exists {{ index .Options "Namespace" }}.oauth_clients ( + id uuid not null, + client_id text not null, + client_secret_hash text not null, + registration_type {{ index .Options "Namespace" }}.oauth_registration_type not null, + redirect_uris text not null, + grant_types text not null, + client_name text null, + client_uri text null, + logo_uri text null, + created_at timestamptz not null default now(), + updated_at timestamptz not null default now(), + deleted_at timestamptz null, + constraint oauth_clients_pkey primary key (id), + constraint oauth_clients_client_id_key unique (client_id), + constraint oauth_clients_client_name_length check (char_length(client_name) <= 1024), + constraint oauth_clients_client_uri_length check (char_length(client_uri) <= 2048), + constraint oauth_clients_logo_uri_length check (char_length(logo_uri) <= 2048) +); + +-- Create indexes +create index if not exists oauth_clients_client_id_idx + on {{ index .Options "Namespace" }}.oauth_clients (client_id); + +create index if not exists oauth_clients_deleted_at_idx + on {{ index .Options "Namespace" }}.oauth_clients (deleted_at);