Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/sebest/xff"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/oauthserver"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks/hookshttp"
"github.com/supabase/auth/internal/hooks/hookspgfunc"
Expand All @@ -35,8 +36,9 @@ type API struct {
config *conf.GlobalConfiguration
version string

hooksMgr *v0hooks.Manager
hibpClient *hibp.PwnedClient
hooksMgr *v0hooks.Manager
hibpClient *hibp.PwnedClient
oauthServer *oauthserver.Server

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time
Expand Down Expand Up @@ -80,7 +82,12 @@ func (a *API) deprecationNotices() {

// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API {
api := &API{config: globalConfig, db: db, version: version}
api := &API{
config: globalConfig,
db: db,
version: version,
oauthServer: oauthserver.NewServer(globalConfig, db),
}

for _, o := range opt {
o.apply(api)
Expand Down Expand Up @@ -197,7 +204,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
With(api.verifyCaptcha).Post("/otp", api.Otp)

// rate limiting applied in handler
r.With(api.verifyCaptcha).Post("/token", api.Token)
r.With(api.verifyCaptcha).With(api.oauthClientAuth).Post("/token", api.Token)

r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
r.Get("/", api.Verify)
Expand Down Expand Up @@ -293,6 +300,28 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
})
})
})

// Admin only oauth client management endpoints
r.Route("/oauth", func(r *router) {
r.Route("/clients", func(r *router) {
// Manual client registration
r.Post("/", api.oauthServer.AdminOAuthServerClientRegister)

r.Get("/", api.oauthServer.OAuthServerClientList)

r.Route("/{client_id}", func(r *router) {
r.Use(api.oauthServer.LoadOAuthServerClient)
r.Get("/", api.oauthServer.OAuthServerClientGet)
r.Delete("/", api.oauthServer.OAuthServerClientDelete)
})
})
})
})

// OAuth Dynamic Client Registration endpoint (public, rate limited)
r.Route("/oauth", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)).
Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister)
})
})

Expand Down
2 changes: 2 additions & 0 deletions internal/api/apierrors/errorcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,6 @@ const (
ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid"
ErrorCodeWeb3ProviderDisabled ErrorCode = "web3_provider_disabled"
ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain"

ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled"
)
12 changes: 2 additions & 10 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package api
import (
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/pkg/errors"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/security"
Expand All @@ -16,14 +15,7 @@ import (
)

func sendJSON(w http.ResponseWriter, status int, obj interface{}) error {
w.Header().Set("Content-Type", "application/json")
b, err := json.Marshal(obj)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj))
}
w.WriteHeader(status)
_, err = w.Write(b)
return err
return shared.SendJSON(w, status, obj)
}

func isAdmin(u *models.User, config *conf.GlobalConfiguration) bool {
Expand Down
36 changes: 36 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/oauthserver"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/security"
Expand Down Expand Up @@ -81,6 +82,41 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

// oauthClientAuth optionally authenticates an OAuth client as middleware
// This doesn't fail if no client credentials are provided, but validates them if present
func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()

clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: "+err.Error())
}

// If no client credentials provided, continue without client authentication
if clientID == "" {
return ctx, nil
}

// Validate client credentials
db := a.db.WithContext(ctx)
client, err := models.FindOAuthServerClientByClientID(db, clientID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
}
return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err)
}

// Validate client secret
if !oauthserver.ValidateClientSecret(clientSecret, client.ClientSecretHash) {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
}

// Add authenticated client to context
ctx = oauthserver.WithOAuthServerClient(ctx, client)
return ctx, nil
}

func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
t, err := a.extractBearerToken(req)
if err != nil || t == "" {
Expand Down
50 changes: 50 additions & 0 deletions internal/api/oauthserver/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package oauthserver

import (
"encoding/base64"
"errors"
"net/http"
"strings"
)

// ExtractClientCredentials extracts OAuth client credentials from the request
// Supports both Basic auth header and form body parameters
func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, err error) {
// First, try Basic auth header: Authorization: Basic base64(client_id:client_secret)
authHeader := r.Header.Get("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Basic ") {
encoded := strings.TrimPrefix(authHeader, "Basic ")
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", "", errors.New("invalid basic auth encoding")
}

credentials := string(decoded)
parts := strings.SplitN(credentials, ":", 2)
if len(parts) != 2 {
return "", "", errors.New("invalid basic auth format")
}

return parts[0], parts[1], nil
}

// Fall back to form parameters
if err := r.ParseForm(); err != nil {
return "", "", errors.New("failed to parse form")
}

clientID = r.FormValue("client_id")
clientSecret = r.FormValue("client_secret")

// Return empty credentials if both are empty (no client auth attempted)
if clientID == "" && clientSecret == "" {
return "", "", nil
}

// If only one is provided, it's an error
if clientID == "" || clientSecret == "" {
return "", "", errors.New("both client_id and client_secret must be provided")
}

return clientID, clientSecret, nil
}
184 changes: 184 additions & 0 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package oauthserver

import (
"context"
"encoding/json"
"net/http"
"time"

"github.com/go-chi/chi/v5"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
)

// OAuthServerClientResponse represents the response format for OAuth client operations
type OAuthServerClientResponse struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret,omitempty"` // only returned on registration

RedirectURIs []string `json:"redirect_uris"`
TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
ClientName string `json:"client_name,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`

// Metadata fields
RegistrationType string `json:"registration_type"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

// OAuthServerClientListResponse represents the response for listing OAuth clients
type OAuthServerClientListResponse struct {
Clients []OAuthServerClientResponse `json:"clients"`
}

// oauthServerClientToResponse converts a model to response format
func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret bool) *OAuthServerClientResponse {
response := &OAuthServerClientResponse{
ClientID: client.ClientID,

// OAuth 2.1 DCR fields
RedirectURIs: client.GetRedirectURIs(),
TokenEndpointAuthMethod: []string{"client_secret_basic", "client_secret_post"}, // Both methods are supported
GrantTypes: client.GetGrantTypes(),
ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1
ClientName: client.ClientName.String(),
ClientURI: client.ClientURI.String(),
LogoURI: client.LogoURI.String(),

// Metadata fields
RegistrationType: client.RegistrationType,
CreatedAt: client.CreatedAt,
UpdatedAt: client.UpdatedAt,
}

// Only include client_secret during registration
if includeSecret {
// Note: This will be filled in by the handler with the plaintext secret
response.ClientSecret = ""
}

return response
}

// LoadOAuthServerClient is middleware that loads an OAuth server client from the URL parameter
func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
clientID := chi.URLParam(r, "client_id")

if clientID == "" {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required")
}

observability.LogEntrySetField(r, "oauth_client_id", clientID)

client, err := s.getOAuthServerClient(ctx, clientID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "OAuth client not found")
}
return nil, apierrors.NewInternalServerError("Error loading OAuth client").WithInternalError(err)
}

ctx = WithOAuthServerClient(ctx, client)
return ctx, nil
}

// AdminOAuthServerClientRegister handles POST /admin/oauth/clients (manual registration by admins)
func (s *Server) AdminOAuthServerClientRegister(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

var params OAuthServerClientRegisterParams
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Invalid JSON body")
}

// Force registration type to manual for admin endpoint
params.RegistrationType = "manual"

client, plaintextSecret, err := s.registerOAuthServerClient(ctx, &params)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
}

response := oauthServerClientToResponse(client, true)
response.ClientSecret = plaintextSecret

return shared.SendJSON(w, http.StatusCreated, response)
}

// OAuthServerClientDynamicRegister handles POST /oauth/register (OAuth 2.1 Dynamic Client Registration)
func (s *Server) OAuthServerClientDynamicRegister(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

// Check if dynamic registration is enabled
if !s.config.OAuthServer.AllowDynamicRegistration {
return apierrors.NewForbiddenError(apierrors.ErrorCodeOAuthDynamicClientRegistrationDisabled, "Dynamic client registration is not enabled")
}

var params OAuthServerClientRegisterParams
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "Invalid JSON body")
}

params.RegistrationType = "dynamic"

client, plaintextSecret, err := s.registerOAuthServerClient(ctx, &params)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
}

response := oauthServerClientToResponse(client, true)
response.ClientSecret = plaintextSecret

return shared.SendJSON(w, http.StatusCreated, response)
}

// OAuthServerClientGet handles GET /admin/oauth/clients/{client_id}
func (s *Server) OAuthServerClientGet(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
client := GetOAuthServerClient(ctx)

response := oauthServerClientToResponse(client, false)
return shared.SendJSON(w, http.StatusOK, response)
}

// OAuthServerClientDelete handles DELETE /admin/oauth/clients/{client_id}
func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
client := GetOAuthServerClient(ctx)

if err := s.deleteOAuthServerClient(ctx, client.ClientID); err != nil {
return apierrors.NewInternalServerError("Error deleting OAuth client").WithInternalError(err)
}

w.WriteHeader(http.StatusNoContent)
return nil
}

// OAuthServerClientList handles GET /admin/oauth/clients
func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := s.db.WithContext(ctx)

var clients []models.OAuthServerClient
if err := db.Q().Where("deleted_at is null").Order("created_at desc").All(&clients); err != nil {
return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err)
}

responses := make([]OAuthServerClientResponse, len(clients))
for i, client := range clients {
responses[i] = *oauthServerClientToResponse(&client, false)
}

response := OAuthServerClientListResponse{
Clients: responses,
}

return shared.SendJSON(w, http.StatusOK, response)
}
Loading