diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 3723318f..9aa0c8d7 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -22,6 +22,10 @@ import ( // Not supported: // - signed metadata // +// Note: URL fields in this struct are validated by validateAuthServerMetaURLs to +// prevent XSS attacks. If you add a new URL field, you must also add it to that +// function. +// // [RFC 8414]: https://tools.ietf.org/html/rfc8414) type AuthServerMeta struct { // GENERATED BY GEMINI 2.5. @@ -144,9 +148,40 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) } + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err + } + return asm, nil } errs = append(errs, err) } return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) } + +// validateAuthServerMetaURLs validates all URL fields in AuthServerMeta +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateAuthServerMetaURLs(asm *AuthServerMeta) error { + urls := []struct { + name string + value string + }{ + {"authorization_endpoint", asm.AuthorizationEndpoint}, + {"token_endpoint", asm.TokenEndpoint}, + {"jwks_uri", asm.JWKSURI}, + {"registration_endpoint", asm.RegistrationEndpoint}, + {"service_documentation", asm.ServiceDocumentation}, + {"op_policy_uri", asm.OpPolicyURI}, + {"op_tos_uri", asm.OpTOSURI}, + {"revocation_endpoint", asm.RevocationEndpoint}, + {"introspection_endpoint", asm.IntrospectionEndpoint}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/oauthex/dcr.go b/oauthex/dcr.go index 75ce2961..c64cb8cd 100644 --- a/oauthex/dcr.go +++ b/oauthex/dcr.go @@ -20,6 +20,10 @@ import ( ) // ClientRegistrationMetadata represents the client metadata fields for the DCR POST request (RFC 7591). +// +// Note: URL fields in this struct are validated by validateClientRegistrationURLs +// to prevent XSS attacks. If you add a new URL field, you must also add it to +// that function. type ClientRegistrationMetadata struct { // RedirectURIs is a REQUIRED JSON array of redirection URI strings for use in // redirect-based flows (such as the authorization code grant). @@ -208,6 +212,10 @@ func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta if regResponse.ClientID == "" { return nil, fmt.Errorf("registration response is missing required 'client_id' field") } + // Validate URL fields to prevent XSS attacks (see #526). + if err := validateClientRegistrationURLs(®Response.ClientRegistrationMetadata); err != nil { + return nil, err + } return ®Response, nil } @@ -221,3 +229,33 @@ func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta return nil, fmt.Errorf("registration failed with status %s: %s", resp.Status, string(body)) } + +// validateClientRegistrationURLs validates all URL fields in ClientRegistrationMetadata +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { + // Validate redirect URIs + for i, uri := range meta.RedirectURIs { + if err := checkURLScheme(uri); err != nil { + return fmt.Errorf("redirect_uris[%d]: %w", i, err) + } + } + + // Validate other URL fields + urls := []struct { + name string + value string + }{ + {"client_uri", meta.ClientURI}, + {"logo_uri", meta.LogoURI}, + {"tos_uri", meta.TOSURI}, + {"policy_uri", meta.PolicyURI}, + {"jwks_uri", meta.JWKSURI}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go new file mode 100644 index 00000000..531a1f9c --- /dev/null +++ b/oauthex/url_scheme_test.go @@ -0,0 +1,314 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestCheckURLScheme tests the checkURLScheme function directly. +func TestCheckURLScheme(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + }{ + // Valid schemes + {"empty string", "", false}, + {"https url", "https://example.com/path", false}, + {"http url", "http://example.com/path", false}, + {"custom scheme", "myapp://callback", false}, + + // Dangerous schemes that should be blocked + {"javascript scheme", "javascript:alert('xss')", true}, + {"javascript uppercase", "JAVASCRIPT:alert('xss')", true}, + {"javascript mixed case", "JaVaScRiPt:alert('xss')", true}, + {"data scheme", "data:text/html,", true}, + {"data uppercase", "DATA:text/html,test", true}, + {"vbscript scheme", "vbscript:msgbox('xss')", true}, + {"vbscript uppercase", "VBSCRIPT:test", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checkURLScheme(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("checkURLScheme(%q): got err %v, want err %v", tt.url, err != nil, tt.wantErr) + } + }) + } +} + +// TestValidateAuthServerMetaURLs tests validation of AuthServerMeta URL fields. +func TestValidateAuthServerMetaURLs(t *testing.T) { + validMeta := &AuthServerMeta{ + Issuer: "https://auth.example.com", + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "https://auth.example.com/token", + JWKSURI: "https://auth.example.com/.well-known/jwks.json", + RegistrationEndpoint: "https://auth.example.com/register", + ServiceDocumentation: "https://docs.example.com", + OpPolicyURI: "https://example.com/policy", + OpTOSURI: "https://example.com/tos", + RevocationEndpoint: "https://auth.example.com/revoke", + IntrospectionEndpoint: "https://auth.example.com/introspect", + } + + t.Run("valid metadata", func(t *testing.T) { + if err := validateAuthServerMetaURLs(validMeta); err != nil { + t.Errorf("validateAuthServerMetaURLs(): got err %v, want nil", err) + } + }) + + // Test each URL field with a dangerous scheme + dangerousFields := []struct { + name string + setField func(*AuthServerMeta) + }{ + {"authorization_endpoint", func(m *AuthServerMeta) { m.AuthorizationEndpoint = "javascript:alert(1)" }}, + {"token_endpoint", func(m *AuthServerMeta) { m.TokenEndpoint = "javascript:alert(1)" }}, + {"jwks_uri", func(m *AuthServerMeta) { m.JWKSURI = "data:text/html,test" }}, + {"registration_endpoint", func(m *AuthServerMeta) { m.RegistrationEndpoint = "vbscript:test" }}, + {"service_documentation", func(m *AuthServerMeta) { m.ServiceDocumentation = "javascript:void(0)" }}, + {"op_policy_uri", func(m *AuthServerMeta) { m.OpPolicyURI = "javascript:x" }}, + {"op_tos_uri", func(m *AuthServerMeta) { m.OpTOSURI = "data:,test" }}, + {"revocation_endpoint", func(m *AuthServerMeta) { m.RevocationEndpoint = "javascript:1" }}, + {"introspection_endpoint", func(m *AuthServerMeta) { m.IntrospectionEndpoint = "javascript:2" }}, + } + + for _, tt := range dangerousFields { + t.Run("dangerous "+tt.name, func(t *testing.T) { + // Copy valid metadata + meta := *validMeta + // Set one field to a dangerous value + tt.setField(&meta) + + err := validateAuthServerMetaURLs(&meta) + if err == nil { + t.Errorf("validateAuthServerMetaURLs(): got nil error, want error for dangerous %s", tt.name) + } else if !strings.Contains(err.Error(), tt.name) { + t.Errorf("validateAuthServerMetaURLs(): got error %v, want error containing %q", err, tt.name) + } + }) + } + + t.Run("empty optional fields are valid", func(t *testing.T) { + meta := &AuthServerMeta{ + Issuer: "https://auth.example.com", + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "https://auth.example.com/token", + JWKSURI: "https://auth.example.com/.well-known/jwks.json", + // All optional fields left empty + } + if err := validateAuthServerMetaURLs(meta); err != nil { + t.Errorf("validateAuthServerMetaURLs(): got err %v, want nil", err) + } + }) +} + +// TestValidateClientRegistrationURLs tests validation of ClientRegistrationMetadata URL fields. +func TestValidateClientRegistrationURLs(t *testing.T) { + validMeta := &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://app.example.com/callback", "myapp://callback"}, + ClientURI: "https://example.com", + LogoURI: "https://example.com/logo.png", + TOSURI: "https://example.com/tos", + PolicyURI: "https://example.com/policy", + JWKSURI: "https://example.com/.well-known/jwks.json", + } + + t.Run("valid metadata", func(t *testing.T) { + if err := validateClientRegistrationURLs(validMeta); err != nil { + t.Errorf("validateClientRegistrationURLs(): got err %v, want nil", err) + } + }) + + t.Run("dangerous redirect_uri", func(t *testing.T) { + meta := *validMeta + meta.RedirectURIs = []string{"https://safe.com/cb", "javascript:alert(1)"} + + err := validateClientRegistrationURLs(&meta) + if err == nil { + t.Error("validateClientRegistrationURLs(): got nil error, want error for dangerous redirect_uri") + } else if !strings.Contains(err.Error(), "redirect_uris[1]") { + t.Errorf("validateClientRegistrationURLs(): got error %v, want error containing \"redirect_uris[1]\"", err) + } + }) + + // Test each URL field with a dangerous scheme + dangerousFields := []struct { + name string + setField func(*ClientRegistrationMetadata) + }{ + {"client_uri", func(m *ClientRegistrationMetadata) { m.ClientURI = "javascript:alert(1)" }}, + {"logo_uri", func(m *ClientRegistrationMetadata) { m.LogoURI = "data:image/svg," }}, + {"tos_uri", func(m *ClientRegistrationMetadata) { m.TOSURI = "vbscript:test" }}, + {"policy_uri", func(m *ClientRegistrationMetadata) { m.PolicyURI = "javascript:void(0)" }}, + {"jwks_uri", func(m *ClientRegistrationMetadata) { m.JWKSURI = "data:application/json,{}" }}, + } + + for _, tt := range dangerousFields { + t.Run("dangerous "+tt.name, func(t *testing.T) { + meta := *validMeta + tt.setField(&meta) + + err := validateClientRegistrationURLs(&meta) + if err == nil { + t.Errorf("validateClientRegistrationURLs(): got nil error, want error for dangerous %s", tt.name) + } else if !strings.Contains(err.Error(), tt.name) { + t.Errorf("validateClientRegistrationURLs(): got error %v, want error containing %q", err, tt.name) + } + }) + } + + t.Run("empty optional fields are valid", func(t *testing.T) { + meta := &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://app.example.com/callback"}, + // All optional URL fields left empty + } + if err := validateClientRegistrationURLs(meta); err != nil { + t.Errorf("validateClientRegistrationURLs(): got err %v, want nil", err) + } + }) +} + +// TestGetAuthServerMetaRejectsDangerousURLs tests that GetAuthServerMeta rejects +// metadata containing dangerous URL schemes. +func TestGetAuthServerMetaRejectsDangerousURLs(t *testing.T) { + tests := []struct { + name string + metadata AuthServerMeta + wantErrText string + }{ + { + name: "javascript authorization_endpoint", + metadata: AuthServerMeta{ + Issuer: "", // Will be set dynamically + AuthorizationEndpoint: "javascript:alert('xss')", + TokenEndpoint: "https://auth.example.com/token", + JWKSURI: "https://auth.example.com/.well-known/jwks.json", + ResponseTypesSupported: []string{"code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + }, + wantErrText: "authorization_endpoint", + }, + { + name: "data token_endpoint", + metadata: AuthServerMeta{ + Issuer: "", + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "data:text/html,", + JWKSURI: "https://auth.example.com/.well-known/jwks.json", + ResponseTypesSupported: []string{"code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + }, + wantErrText: "token_endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + meta := tt.metadata + meta.Issuer = "https://" + r.Host + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(meta) + })) + defer server.Close() + + ctx := context.Background() + _, err := GetAuthServerMeta(ctx, server.URL, server.Client()) + if err == nil { + t.Fatal("GetAuthServerMeta(): got nil error, want error") + } + if !strings.Contains(err.Error(), tt.wantErrText) { + t.Errorf("GetAuthServerMeta(): got error %v, want error containing %q", err, tt.wantErrText) + } + }) + } +} + +// TestGetProtectedResourceMetadataRejectsDangerousURLs tests that +// GetProtectedResourceMetadataFromID rejects metadata with dangerous authorization server URLs. +func TestGetProtectedResourceMetadataRejectsDangerousURLs(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverURL := "https://" + r.Host + meta := ProtectedResourceMetadata{ + Resource: serverURL, + AuthorizationServers: []string{"javascript:alert('xss')"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(meta) + })) + defer server.Close() + + ctx := context.Background() + _, err := GetProtectedResourceMetadataFromID(ctx, server.URL, server.Client()) + if err == nil { + t.Fatal("GetProtectedResourceMetadataFromID(): got nil error, want error") + } + if !strings.Contains(err.Error(), "disallowed scheme") { + t.Errorf("GetProtectedResourceMetadataFromID(): got error %v, want error containing \"disallowed scheme\"", err) + } +} + +// TestRegisterClientRejectsDangerousURLs tests that RegisterClient rejects +// responses containing dangerous URL schemes. +func TestRegisterClientRejectsDangerousURLs(t *testing.T) { + tests := []struct { + name string + responseJSON string + wantErrText string + }{ + { + name: "javascript redirect_uri in response", + responseJSON: `{ + "client_id": "test-client", + "redirect_uris": ["javascript:alert(1)"] + }`, + wantErrText: "redirect_uris[0]", + }, + { + name: "data client_uri", + responseJSON: `{ + "client_id": "test-client", + "redirect_uris": ["https://app.example.com/callback"], + "client_uri": "data:text/html," + }`, + wantErrText: "client_uri", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(tt.responseJSON)) + })) + defer server.Close() + + ctx := context.Background() + clientMeta := &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://app.example.com/callback"}, + } + + _, err := RegisterClient(ctx, server.URL+"/register", clientMeta, server.Client()) + if err == nil { + t.Fatal("RegisterClient(): got nil error, want error") + } + if !strings.Contains(err.Error(), tt.wantErrText) { + t.Errorf("RegisterClient(): got error %v, want error containing %q", err, tt.wantErrText) + } + }) + } +}