diff --git a/docs/Configuring.md b/docs/Configuring.md index 0242f4ee2..a700919a0 100644 --- a/docs/Configuring.md +++ b/docs/Configuring.md @@ -10,12 +10,27 @@ The platform leverages [viper](https://github.com/spf13/viper) to help load conf - [SDK Configuration](#sdk-configuration) - [Logger Configuration](#logger-configuration) - [Server Configuration](#server-configuration) + - [CORS Configuration](#cors-configuration) + - [Additive Configuration](#additive-configuration) + - [Programmatic Configuration](#programmatic-configuration) - [Crypto Provider](#crypto-provider) - [Database Configuration](#database-configuration) - [Tracing Configuration](#tracing-configuration) + - [Security Configuration](#security-configuration) - [Services Configuration](#services-configuration) - [Key Access Server (KAS)](#key-access-server-kas) - [Authorization](#authorization) + - [Shared Keys (v1 \& v2)](#shared-keys-v1--v2) + - [Authorization v1 Only](#authorization-v1-only) + - [Authorization v2 Only](#authorization-v2-only) + - [Example: Authorization v1](#example-authorization-v1) + - [Example: Authorization v2](#example-authorization-v2) + - [Entity Resolution](#entity-resolution) + - [Shared Keys (v1 \& v2)](#shared-keys-v1--v2-1) + - [Entity Resolution v1 Only](#entity-resolution-v1-only) + - [Entity Resolution v2 Only](#entity-resolution-v2-only) + - [Example: Entity Resolution v1](#example-entity-resolution-v1) + - [Example: Entity Resolution v2](#example-entity-resolution-v2) - [Policy](#policy) - [Casbin Endpoint Authorization](#casbin-endpoint-authorization) - [Key Aspects of Authorization Configuration](#key-aspects-of-authorization-configuration) @@ -52,8 +67,8 @@ mode: core,-policy mode: all,-kas,-entityresolution ``` -| Field | Description | Default | Environment Variable | -| ------ | ----------------------------------------------------------------------------- | ------- | -------------------- | +| Field | Description | Default | Environment Variable | +| ------ | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | -------------------- | | `mode` | Drives which services to run. Supported modes: `all`, `core`, `kas`. Use `-servicename` to exclude specific services (e.g., `all,-entityresolution`) | `all` | OPENTDF_MODE | ## SDK Configuration @@ -148,6 +163,74 @@ server: cert: kas-ec-cert.pem ``` +### CORS Configuration + +Root level key `server.cors` + +| Field | Description | Default | Environment Variable | +| -------------------------- | ------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------- | +| `enabled` | Enable CORS for the server | `true` | OPENTDF_SERVER_CORS_ENABLED | +| `allowedorigins` | List of allowed origins (`*` for any) | `[]` | OPENTDF_SERVER_CORS_ALLOWEDORIGINS | +| `allowedmethods` | List of allowed HTTP methods | `["GET","POST","PATCH","DELETE","OPTIONS"]` | OPENTDF_SERVER_CORS_ALLOWEDMETHODS | +| `allowedheaders` | List of allowed request headers | `["Accept","Accept-Encoding","Authorization","Connect-Protocol-Version","Content-Length","Content-Type","Dpop","X-CSRF-Token","X-Requested-With","X-Rewrap-Additional-Context"]` | OPENTDF_SERVER_CORS_ALLOWEDHEADERS | +| `exposedheaders` | List of response headers browsers can access | `[]` | OPENTDF_SERVER_CORS_EXPOSEDHEADERS | +| `allowcredentials` | Whether credentials are included in CORS requests | `true` | OPENTDF_SERVER_CORS_ALLOWCREDENTIALS | +| `maxage` | Maximum age (seconds) of preflight cache | `3600` | OPENTDF_SERVER_CORS_MAXAGE | +| `additionalmethods` | Additional methods to append to defaults | `[]` | OPENTDF_SERVER_CORS_ADDITIONALMETHODS | +| `additionalheaders` | Additional headers to append to defaults | `[]` | OPENTDF_SERVER_CORS_ADDITIONALHEADERS | +| `additionalexposedheaders` | Additional exposed headers to append | `[]` | OPENTDF_SERVER_CORS_ADDITIONALEXPOSEDHEADERS | + +#### Additive Configuration + +The `additional*` fields allow operators to extend the default lists without replacing them entirely: + +```yaml +server: + cors: + enabled: true + # Add custom headers without copying all defaults + additionalheaders: + - X-Custom-Header + - X-Another-Header +``` + +To completely replace defaults, use the base fields directly: + +```yaml +server: + cors: + allowedheaders: + - Authorization + - Content-Type + # Only these headers will be allowed +``` + +#### Programmatic Configuration + +For applications embedding the OpenTDF platform, CORS can also be configured programmatically using functional options. These are applied after YAML/environment configuration and follow the same additive semantics: + +```go +import "github.com/opentdf/platform/service/pkg/server" + +err := server.Start( + server.WithConfigFile("opentdf.yaml"), + // Add custom headers for your application + server.WithAdditionalCORSHeaders("X-Custom-Header", "X-App-Version"), + // Add custom methods if needed + server.WithAdditionalCORSMethods("CUSTOM"), + // Expose additional response headers to browsers + server.WithAdditionalCORSExposedHeaders("X-Request-Id", "X-Trace-Id"), +) +``` + +**Configuration Precedence:** + +1. **Defaults** - Built-in default values +2. **YAML/Environment** - Operator configuration via `server.cors.*` fields +3. **Programmatic Options** - Developer overlays via `WithAdditionalCORS*` functions + +All layers are additive. Deduplication is handled automatically (case-insensitive for headers per RFC 7230, case-sensitive for methods per RFC 7231). + ### Crypto Provider To configure the Key Access Server, @@ -244,9 +327,9 @@ For OTLP provider: Root level key `security` -| Field | Description | Default | -|-----------------------------|-------------------------------------------------------------------------------------------------|---------| -| `unsafe.clock_skew` | Platform-wide maximum tolerated clock skew for token verification (Go duration, use cautiously) | `1m` | +| Field | Description | Default | +| ------------------- | ----------------------------------------------------------------------------------------------- | ------- | +| `unsafe.clock_skew` | Platform-wide maximum tolerated clock skew for token verification (Go duration, use cautiously) | `1m` | > **Warning:** Increasing `unsafe.clock_skew` weakens token freshness guarantees. Only raise this value temporarily while you correct clock drift. @@ -300,23 +383,23 @@ Root level key `authorization` #### Shared Keys (v1 & v2) -| Field | Description | Default | Environment Variables | -|-------|-------------|---------|----------------------| -| *(none currently; all keys are version-specific)* | | | | +| Field | Description | Default | Environment Variables | +| ------------------------------------------------- | ----------- | ------- | --------------------- | +| *(none currently; all keys are version-specific)* | | | | #### Authorization v1 Only -| Field | Description | Default | Environment Variables | -|--------------|------------------------------|--------------------------------|---------------------------------------------| -| `rego.path` | Path to rego policy file | Leverages embedded rego policy | OPENTDF_SERVICES_AUTHORIZATION_REGO_PATH | -| `rego.query` | Rego query to execute | `data.opentdf.entitlements.attributes` | OPENTDF_SERVICES_AUTHORIZATION_REGO_QUERY | +| Field | Description | Default | Environment Variables | +| ------------ | ------------------------ | -------------------------------------- | ----------------------------------------- | +| `rego.path` | Path to rego policy file | Leverages embedded rego policy | OPENTDF_SERVICES_AUTHORIZATION_REGO_PATH | +| `rego.query` | Rego query to execute | `data.opentdf.entitlements.attributes` | OPENTDF_SERVICES_AUTHORIZATION_REGO_QUERY | #### Authorization v2 Only -| Field | Description | Default | Environment Variables | -|-----------------------------------------|--------------------------------------------------|---------|----------------------| -| `entitlement_policy_cache.enabled` | Enable the entitlement policy cache | `false` | | -| `entitlement_policy_cache.refresh_interval` | How often to refresh the entitlement policy cache (e.g. `30s`) | | | +| Field | Description | Default | Environment Variables | +| ------------------------------------------- | -------------------------------------------------------------- | ------- | --------------------- | +| `entitlement_policy_cache.enabled` | Enable the entitlement policy cache | `false` | | +| `entitlement_policy_cache.refresh_interval` | How often to refresh the entitlement policy cache (e.g. `30s`) | | | #### Example: Authorization v1 @@ -346,29 +429,29 @@ Root level key `entityresolution` #### Shared Keys (v1 & v2) -| Field | Description | Default | Environment Variable | -|-------------------------|----------------------------------------------------------------------------------------------|-----------|------------------------------------------------------| -| `mode` | The mode in which to run ERS (`keycloak` or `claims`) | `keycloak` | OPENTDF_SERVICES_ENTITYRESOLUTION_MODE | -| `url` | Endpoint URL for the entity resolution service (specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_URL | -| `clientid` | Keycloak client ID for authentication (specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_CLIENTID | -| `clientsecret` | Keycloak client secret for authentication(specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_CLIENTSECRET | -| `realm` | Keycloak realm for authentication (specific to `keycloak` mode) | | OPENTDF_SERVICES_ENTITYRESOLUTION_REALM | -| `legacykeycloak` | Enables legacy Keycloak compatibility (`/auth` as base endpoint) (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_LEGACYKEYCLOAK | -| `inferid.from.email` | Infer entity IDs from email addresses (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_EMAIL | -| `inferid.from.username` | Infer entity IDs from usernames (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_USERNAME | -| `inferid.from.clientid` | Infer entity IDs from client IDs (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_CLIENTID | +| Field | Description | Default | Environment Variable | +| ----------------------- | ---------------------------------------------------------------------------------------------- | ---------- | ------------------------------------------------------- | +| `mode` | The mode in which to run ERS (`keycloak` or `claims`) | `keycloak` | OPENTDF_SERVICES_ENTITYRESOLUTION_MODE | +| `url` | Endpoint URL for the entity resolution service (specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_URL | +| `clientid` | Keycloak client ID for authentication (specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_CLIENTID | +| `clientsecret` | Keycloak client secret for authentication(specific to `keycloak` mode) | `""` | OPENTDF_SERVICES_ENTITYRESOLUTION_CLIENTSECRET | +| `realm` | Keycloak realm for authentication (specific to `keycloak` mode) | | OPENTDF_SERVICES_ENTITYRESOLUTION_REALM | +| `legacykeycloak` | Enables legacy Keycloak compatibility (`/auth` as base endpoint) (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_LEGACYKEYCLOAK | +| `inferid.from.email` | Infer entity IDs from email addresses (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_EMAIL | +| `inferid.from.username` | Infer entity IDs from usernames (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_USERNAME | +| `inferid.from.clientid` | Infer entity IDs from client IDs (specific to `keycloak` mode) | `false` | OPENTDF_SERVICES_ENTITYRESOLUTION_INFERID_FROM_CLIENTID | #### Entity Resolution v1 Only -| Field | Description | Default | Environment Variables | -|-------|-------------|---------|----------------------| -| *(none currently)* | | | | +| Field | Description | Default | Environment Variables | +| ------------------ | ----------- | ------- | --------------------- | +| *(none currently)* | | | | #### Entity Resolution v2 Only -| Field | Description | Default | Environment Variable | -|--------------------|--------------------------------------------------------------------|----------|---------------------| -| `cache_expiration` | Cache duration for entity resolution results (e.g., `30s`). Disabled if not set or zero. (specific to `keycloak` mode) | disabled | | +| Field | Description | Default | Environment Variable | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------- | -------- | -------------------- | +| `cache_expiration` | Cache duration for entity resolution results (e.g., `30s`). Disabled if not set or zero. (specific to `keycloak` mode) | disabled | | #### Example: Entity Resolution v1 @@ -524,9 +607,9 @@ The platform supports a cache manager to improve performance for frequently acce Root level key `cache` -| Field | Description | Default | -|--------------------------|------------------------------------------------------------------|--------------| -| `ristretto.max_cost` | Maximum cost for the cache (e.g. 100mb, 1gb) | `1gb` | +| Field | Description | Default | +| -------------------- | -------------------------------------------- | ------- | +| `ristretto.max_cost` | Maximum cost for the cache (e.g. 100mb, 1gb) | `1gb` | Example: diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index 4712457bb..6f2288821 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -151,6 +151,12 @@ server: allowcredentials: true # Sets the maximum age (in seconds) of a specific CORS preflight request maxage: 3600 + # Additive fields - append to base lists without replacing defaults + # Use these to add custom values without having to copy all defaults + # additionalmethods: [] + # additionalheaders: + # - X-Custom-Header + # additionalexposedheaders: [] grpc: reflectionEnabled: true # Default is false # http: diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 5a82391da..907f50f4b 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -129,6 +129,89 @@ type CORSConfig struct { AllowCredentials bool `mapstructure:"allowcredentials" json:"allowcredentials" default:"true"` MaxAge int `mapstructure:"maxage" json:"maxage" default:"3600"` Debug bool `mapstructure:"debug" json:"debug"` + + // Additive fields - appended to base lists at runtime without replacing defaults + AdditionalMethods []string `mapstructure:"additionalmethods" json:"additionalmethods"` + AdditionalHeaders []string `mapstructure:"additionalheaders" json:"additionalheaders"` + AdditionalExposedHeaders []string `mapstructure:"additionalexposedheaders" json:"additionalexposedheaders"` +} + +// mergeStringSlices combines base and additional slices, removing duplicates. +// The order is: base items first, then additional items (preserving order within each). +// Comparison is case-sensitive. +func mergeStringSlices(base, additional []string) []string { + if len(additional) == 0 { + return base + } + if len(base) == 0 { + return additional + } + + seen := make(map[string]struct{}, len(base)+len(additional)) + result := make([]string, 0, len(base)+len(additional)) + + for _, v := range base { + if _, exists := seen[v]; !exists { + seen[v] = struct{}{} + result = append(result, v) + } + } + for _, v := range additional { + if _, exists := seen[v]; !exists { + seen[v] = struct{}{} + result = append(result, v) + } + } + return result +} + +// mergeHeaderSlices combines base and additional HTTP header slices with case-insensitive +// deduplication. HTTP headers are case-insensitive per RFC 7230, so "Authorization" and +// "authorization" are treated as duplicates. The first occurrence's original casing is preserved. +func mergeHeaderSlices(base, additional []string) []string { + if len(additional) == 0 { + return base + } + if len(base) == 0 { + return additional + } + + // Use canonical header keys for case-insensitive comparison + seen := make(map[string]struct{}, len(base)+len(additional)) + result := make([]string, 0, len(base)+len(additional)) + + for _, v := range base { + canonical := textproto.CanonicalMIMEHeaderKey(v) + if _, exists := seen[canonical]; !exists { + seen[canonical] = struct{}{} + result = append(result, v) // Preserve original casing + } + } + for _, v := range additional { + canonical := textproto.CanonicalMIMEHeaderKey(v) + if _, exists := seen[canonical]; !exists { + seen[canonical] = struct{}{} + result = append(result, v) // Preserve original casing + } + } + return result +} + +// EffectiveMethods returns AllowedMethods merged with AdditionalMethods. +func (c CORSConfig) EffectiveMethods() []string { + return mergeStringSlices(c.AllowedMethods, c.AdditionalMethods) +} + +// EffectiveHeaders returns AllowedHeaders merged with AdditionalHeaders. +// Uses case-insensitive deduplication since HTTP headers are case-insensitive per RFC 7230. +func (c CORSConfig) EffectiveHeaders() []string { + return mergeHeaderSlices(c.AllowedHeaders, c.AdditionalHeaders) +} + +// EffectiveExposedHeaders returns ExposedHeaders merged with AdditionalExposedHeaders. +// Uses case-insensitive deduplication since HTTP headers are case-insensitive per RFC 7230. +func (c CORSConfig) EffectiveExposedHeaders() []string { + return mergeHeaderSlices(c.ExposedHeaders, c.AdditionalExposedHeaders) } type ConnectRPC struct { @@ -314,6 +397,19 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H // Note: The grpc-gateway handlers are getting chained together in reverse. So the last handler is the first to be called. // CORS if c.CORS.Enabled { + // Compute effective values by merging base and additional lists + effectiveMethods := c.CORS.EffectiveMethods() + effectiveHeaders := c.CORS.EffectiveHeaders() + effectiveExposed := c.CORS.EffectiveExposedHeaders() + + // Log effective CORS config for operator visibility + l.Info("CORS middleware enabled", + slog.Any("allowed_origins", c.CORS.AllowedOrigins), + slog.Any("effective_methods", effectiveMethods), + slog.Any("effective_headers", effectiveHeaders), + slog.Any("effective_exposed_headers", effectiveExposed), + ) + corsHandler := cors.New(cors.Options{ AllowOriginFunc: func(_ *http.Request, origin string) bool { for _, allowedOrigin := range c.CORS.AllowedOrigins { @@ -326,9 +422,9 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H } return false }, - AllowedMethods: c.CORS.AllowedMethods, - AllowedHeaders: c.CORS.AllowedHeaders, - ExposedHeaders: c.CORS.ExposedHeaders, + AllowedMethods: effectiveMethods, + AllowedHeaders: effectiveHeaders, + ExposedHeaders: effectiveExposed, AllowCredentials: c.CORS.AllowCredentials, MaxAge: c.CORS.MaxAge, Debug: c.CORS.Debug, diff --git a/service/internal/server/server_test.go b/service/internal/server/server_test.go new file mode 100644 index 000000000..d18bd6066 --- /dev/null +++ b/service/internal/server/server_test.go @@ -0,0 +1,524 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/cors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMergeStringSlices(t *testing.T) { + tests := []struct { + name string + base []string + additional []string + want []string + }{ + { + name: "empty additional returns base", + base: []string{"A", "B"}, + additional: nil, + want: []string{"A", "B"}, + }, + { + name: "empty base returns additional", + base: nil, + additional: []string{"A", "B"}, + want: []string{"A", "B"}, + }, + { + name: "merge without duplicates", + base: []string{"A", "B"}, + additional: []string{"C", "D"}, + want: []string{"A", "B", "C", "D"}, + }, + { + name: "merge with duplicates removed", + base: []string{"A", "B", "C"}, + additional: []string{"B", "D"}, + want: []string{"A", "B", "C", "D"}, + }, + { + name: "both empty returns nil", + base: nil, + additional: nil, + want: nil, + }, + { + name: "preserves order - base first then additional", + base: []string{"Z", "A"}, + additional: []string{"M", "B"}, + want: []string{"Z", "A", "M", "B"}, + }, + { + name: "deduplicates within base", + base: []string{"A", "B", "A"}, + additional: []string{"C"}, + want: []string{"A", "B", "C"}, + }, + { + name: "case sensitive comparison", + base: []string{"Accept"}, + additional: []string{"accept", "ACCEPT"}, + want: []string{"Accept", "accept", "ACCEPT"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeStringSlices(tt.base, tt.additional) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCORSConfig_EffectiveMethods(t *testing.T) { + tests := []struct { + name string + cfg CORSConfig + want []string + }{ + { + name: "no additional methods", + cfg: CORSConfig{ + AllowedMethods: []string{"GET", "POST"}, + AdditionalMethods: nil, + }, + want: []string{"GET", "POST"}, + }, + { + name: "with additional methods", + cfg: CORSConfig{ + AllowedMethods: []string{"GET", "POST"}, + AdditionalMethods: []string{"PUT", "DELETE"}, + }, + want: []string{"GET", "POST", "PUT", "DELETE"}, + }, + { + name: "additional methods with duplicate", + cfg: CORSConfig{ + AllowedMethods: []string{"GET", "POST", "PUT"}, + AdditionalMethods: []string{"PUT", "DELETE"}, + }, + want: []string{"GET", "POST", "PUT", "DELETE"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.EffectiveMethods() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMergeHeaderSlices(t *testing.T) { + tests := []struct { + name string + base []string + additional []string + want []string + }{ + { + name: "empty additional returns base", + base: []string{"Authorization", "Content-Type"}, + additional: nil, + want: []string{"Authorization", "Content-Type"}, + }, + { + name: "empty base returns additional", + base: nil, + additional: []string{"Authorization", "Content-Type"}, + want: []string{"Authorization", "Content-Type"}, + }, + { + name: "merge without duplicates", + base: []string{"Authorization"}, + additional: []string{"Content-Type"}, + want: []string{"Authorization", "Content-Type"}, + }, + { + name: "case-insensitive deduplication preserves first occurrence", + base: []string{"Authorization"}, + additional: []string{"authorization", "AUTHORIZATION"}, + want: []string{"Authorization"}, + }, + { + name: "mixed case duplicates - base wins", + base: []string{"Content-Type", "Accept"}, + additional: []string{"content-type", "X-Custom"}, + want: []string{"Content-Type", "Accept", "X-Custom"}, + }, + { + name: "both empty returns nil", + base: nil, + additional: nil, + want: nil, + }, + { + name: "preserves order - base first then additional", + base: []string{"X-First", "X-Second"}, + additional: []string{"X-Third", "X-Fourth"}, + want: []string{"X-First", "X-Second", "X-Third", "X-Fourth"}, + }, + { + name: "deduplicates within base (case-insensitive)", + base: []string{"Accept", "accept", "ACCEPT"}, + additional: []string{"Content-Type"}, + want: []string{"Accept", "Content-Type"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeHeaderSlices(tt.base, tt.additional) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCORSConfig_EffectiveHeaders(t *testing.T) { + tests := []struct { + name string + cfg CORSConfig + want []string + }{ + { + name: "no additional headers", + cfg: CORSConfig{ + AllowedHeaders: []string{"Authorization", "Content-Type"}, + AdditionalHeaders: nil, + }, + want: []string{"Authorization", "Content-Type"}, + }, + { + name: "with additional headers", + cfg: CORSConfig{ + AllowedHeaders: []string{"Authorization", "Content-Type"}, + AdditionalHeaders: []string{"X-Custom-Header"}, + }, + want: []string{"Authorization", "Content-Type", "X-Custom-Header"}, + }, + { + name: "additional header already in base is deduplicated", + cfg: CORSConfig{ + AllowedHeaders: []string{"Authorization", "Content-Type"}, + AdditionalHeaders: []string{"X-Custom", "Content-Type"}, + }, + want: []string{"Authorization", "Content-Type", "X-Custom"}, + }, + { + name: "case-insensitive deduplication - base casing preserved", + cfg: CORSConfig{ + AllowedHeaders: []string{"Authorization", "Content-Type"}, + AdditionalHeaders: []string{"authorization", "content-type", "X-New"}, + }, + want: []string{"Authorization", "Content-Type", "X-New"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.EffectiveHeaders() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCORSConfig_EffectiveExposedHeaders(t *testing.T) { + tests := []struct { + name string + cfg CORSConfig + want []string + }{ + { + name: "no additional exposed headers", + cfg: CORSConfig{ + ExposedHeaders: []string{"Link"}, + AdditionalExposedHeaders: nil, + }, + want: []string{"Link"}, + }, + { + name: "with additional exposed headers", + cfg: CORSConfig{ + ExposedHeaders: []string{"Link"}, + AdditionalExposedHeaders: []string{"X-Custom-Exposed"}, + }, + want: []string{"Link", "X-Custom-Exposed"}, + }, + { + name: "empty base with additional", + cfg: CORSConfig{ + ExposedHeaders: nil, + AdditionalExposedHeaders: []string{"X-Custom-Exposed"}, + }, + want: []string{"X-Custom-Exposed"}, + }, + { + name: "case-insensitive deduplication for exposed headers", + cfg: CORSConfig{ + ExposedHeaders: []string{"Link", "X-Request-Id"}, + AdditionalExposedHeaders: []string{"link", "x-request-id", "X-New-Header"}, + }, + want: []string{"Link", "X-Request-Id", "X-New-Header"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.EffectiveExposedHeaders() + assert.Equal(t, tt.want, got) + }) + } +} + +// TestCORSMiddleware_WildcardOrigin tests that the CORS middleware correctly handles +// wildcard origin configuration with credentials enabled. Per CORS spec, when +// credentials are allowed, the response must reflect the actual origin, not "*". +func TestCORSMiddleware_WildcardOrigin(t *testing.T) { + // Configure CORS the same way as newHTTPServer does + cfg := CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Authorization", "Content-Type"}, + ExposedHeaders: []string{"Link"}, + AllowCredentials: true, + MaxAge: 3600, + } + + // Create the CORS handler using the same pattern as server.go + corsHandler := cors.New(cors.Options{ + AllowOriginFunc: func(_ *http.Request, origin string) bool { + for _, allowedOrigin := range cfg.AllowedOrigins { + if allowedOrigin == "*" { + return true + } + if strings.EqualFold(origin, allowedOrigin) { + return true + } + } + return false + }, + AllowedMethods: cfg.EffectiveMethods(), + AllowedHeaders: cfg.EffectiveHeaders(), + ExposedHeaders: cfg.EffectiveExposedHeaders(), + AllowCredentials: cfg.AllowCredentials, + MaxAge: cfg.MaxAge, + }) + + // Create a simple handler wrapped with CORS + handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + origin string + method string + requestHeaders string + wantOrigin string + wantCreds string + }{ + { + name: "preflight from localhost:3000", + origin: "http://localhost:3000", + method: http.MethodOptions, + requestHeaders: "authorization,content-type", + wantOrigin: "http://localhost:3000", + wantCreds: "true", + }, + { + name: "preflight from example.com", + origin: "https://example.com", + method: http.MethodOptions, + requestHeaders: "authorization", + wantOrigin: "https://example.com", + wantCreds: "true", + }, + { + name: "preflight from arbitrary origin", + origin: "https://any-site.io", + method: http.MethodOptions, + requestHeaders: "content-type", + wantOrigin: "https://any-site.io", + wantCreds: "true", + }, + { + name: "actual request from localhost", + origin: "http://localhost:3000", + method: http.MethodGet, + wantOrigin: "http://localhost:3000", + wantCreds: "true", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/test", nil) + req.Header.Set("Origin", tt.origin) + if tt.method == http.MethodOptions { + req.Header.Set("Access-Control-Request-Method", "POST") + if tt.requestHeaders != "" { + req.Header.Set("Access-Control-Request-Headers", tt.requestHeaders) + } + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Verify origin is reflected back (not "*" since credentials are enabled) + gotOrigin := rr.Header().Get("Access-Control-Allow-Origin") + require.Equal(t, tt.wantOrigin, gotOrigin, + "Origin should be reflected back, not '*', when credentials are enabled") + + // Verify credentials header + gotCreds := rr.Header().Get("Access-Control-Allow-Credentials") + require.Equal(t, tt.wantCreds, gotCreds) + + // For preflight, verify allowed headers + if tt.method == http.MethodOptions { + gotHeaders := rr.Header().Get("Access-Control-Allow-Headers") + require.NotEmpty(t, gotHeaders, "Preflight should include allowed headers") + } + }) + } +} + +// TestCORSMiddleware_WildcardWithSpecificOrigins tests that wildcard takes precedence +// when mixed with specific origins - all origins are allowed if "*" is in the list. +func TestCORSMiddleware_WildcardWithSpecificOrigins(t *testing.T) { + cfg := CORSConfig{ + AllowedOrigins: []string{"https://specific.com", "*", "https://another.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedHeaders: []string{"Authorization"}, + AllowCredentials: true, + } + + corsHandler := cors.New(cors.Options{ + AllowOriginFunc: func(_ *http.Request, origin string) bool { + for _, allowedOrigin := range cfg.AllowedOrigins { + if allowedOrigin == "*" { + return true + } + if strings.EqualFold(origin, allowedOrigin) { + return true + } + } + return false + }, + AllowedMethods: cfg.AllowedMethods, + AllowedHeaders: cfg.AllowedHeaders, + AllowCredentials: cfg.AllowCredentials, + }) + + handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // When "*" is in the list, ANY origin should be allowed + tests := []struct { + name string + origin string + wantOrigin string + }{ + { + name: "specific origin still works", + origin: "https://specific.com", + wantOrigin: "https://specific.com", + }, + { + name: "random origin allowed due to wildcard", + origin: "https://random-site.io", + wantOrigin: "https://random-site.io", + }, + { + name: "evil origin also allowed due to wildcard", + origin: "https://evil.com", + wantOrigin: "https://evil.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + req.Header.Set("Origin", tt.origin) + req.Header.Set("Access-Control-Request-Method", "GET") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + gotOrigin := rr.Header().Get("Access-Control-Allow-Origin") + assert.Equal(t, tt.wantOrigin, gotOrigin, + "Wildcard in list should allow ALL origins") + }) + } +} + +// TestCORSMiddleware_SpecificOrigins tests CORS with specific origin list (not wildcard) +func TestCORSMiddleware_SpecificOrigins(t *testing.T) { + cfg := CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://allowed.com", "https://also-allowed.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedHeaders: []string{"Authorization"}, + AllowCredentials: true, + } + + corsHandler := cors.New(cors.Options{ + AllowOriginFunc: func(_ *http.Request, origin string) bool { + for _, allowedOrigin := range cfg.AllowedOrigins { + if allowedOrigin == "*" { + return true + } + if strings.EqualFold(origin, allowedOrigin) { + return true + } + } + return false + }, + AllowedMethods: cfg.AllowedMethods, + AllowedHeaders: cfg.AllowedHeaders, + AllowCredentials: cfg.AllowCredentials, + }) + + handler := corsHandler.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + origin string + wantOrigin string + }{ + { + name: "allowed origin", + origin: "https://allowed.com", + wantOrigin: "https://allowed.com", + }, + { + name: "also allowed origin", + origin: "https://also-allowed.com", + wantOrigin: "https://also-allowed.com", + }, + { + name: "disallowed origin - no CORS headers", + origin: "https://evil.com", + wantOrigin: "", // No Access-Control-Allow-Origin header + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + req.Header.Set("Origin", tt.origin) + req.Header.Set("Access-Control-Request-Method", "GET") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + gotOrigin := rr.Header().Get("Access-Control-Allow-Origin") + assert.Equal(t, tt.wantOrigin, gotOrigin) + }) + } +} diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index d4eadbaaa..65ee1f8e0 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -25,6 +25,11 @@ type StartConfig struct { configLoaderOrder []string trustKeyManagerCtxs []trust.NamedKeyManagerCtxFactory + + // CORS additive configuration - appended to YAML/env config values + additionalCORSHeaders []string + additionalCORSMethods []string + additionalCORSExposedHeaders []string } // Deprecated: Use WithConfigKey @@ -162,3 +167,53 @@ func WithTrustKeyManagerCtxFactories(factories ...trust.NamedKeyManagerCtxFactor return c } } + +// WithAdditionalCORSHeaders appends additional request headers to allow via CORS. +// These are merged with headers from YAML config (server.cors.allowedheaders and +// server.cors.additionalheaders). Deduplication is handled automatically with +// case-insensitive comparison per RFC 7230. +// +// Example: +// +// server.Start( +// server.WithAdditionalCORSHeaders("X-Custom-Header", "X-Another-Header"), +// ) +func WithAdditionalCORSHeaders(headers ...string) StartOptions { + return func(c StartConfig) StartConfig { + c.additionalCORSHeaders = append(c.additionalCORSHeaders, headers...) + return c + } +} + +// WithAdditionalCORSMethods appends additional HTTP methods to allow via CORS. +// These are merged with methods from YAML config (server.cors.allowedmethods and +// server.cors.additionalmethods). Deduplication is handled automatically. +// +// Example: +// +// server.Start( +// server.WithAdditionalCORSMethods("CUSTOM", "SPECIAL"), +// ) +func WithAdditionalCORSMethods(methods ...string) StartOptions { + return func(c StartConfig) StartConfig { + c.additionalCORSMethods = append(c.additionalCORSMethods, methods...) + return c + } +} + +// WithAdditionalCORSExposedHeaders appends additional response headers to expose via CORS. +// These are merged with headers from YAML config (server.cors.exposedheaders and +// server.cors.additionalexposedheaders). Deduplication is handled automatically with +// case-insensitive comparison per RFC 7230. +// +// Example: +// +// server.Start( +// server.WithAdditionalCORSExposedHeaders("X-Request-Id", "X-Trace-Id"), +// ) +func WithAdditionalCORSExposedHeaders(headers ...string) StartOptions { + return func(c StartConfig) StartConfig { + c.additionalCORSExposedHeaders = append(c.additionalCORSExposedHeaders, headers...) + return c + } +} diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 3c74e129b..1cf129c3d 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -161,6 +161,21 @@ func Start(f ...StartOptions) error { cfg.Server.Auth.Policy.Adapter = startConfig.casbinAdapter } + // Apply additional CORS configuration from programmatic options + // These are appended to the YAML config values; deduplication happens in Effective*() methods + if len(startConfig.additionalCORSHeaders) > 0 { + logger.Info("additional CORS headers added via options", slog.Any("headers", startConfig.additionalCORSHeaders)) + cfg.Server.CORS.AdditionalHeaders = append(cfg.Server.CORS.AdditionalHeaders, startConfig.additionalCORSHeaders...) + } + if len(startConfig.additionalCORSMethods) > 0 { + logger.Info("additional CORS methods added via options", slog.Any("methods", startConfig.additionalCORSMethods)) + cfg.Server.CORS.AdditionalMethods = append(cfg.Server.CORS.AdditionalMethods, startConfig.additionalCORSMethods...) + } + if len(startConfig.additionalCORSExposedHeaders) > 0 { + logger.Info("additional CORS exposed headers added via options", slog.Any("headers", startConfig.additionalCORSExposedHeaders)) + cfg.Server.CORS.AdditionalExposedHeaders = append(cfg.Server.CORS.AdditionalExposedHeaders, startConfig.additionalCORSExposedHeaders...) + } + // Create new server for grpc & http. Also will support in process grpc potentially too logger.Debug("initializing opentdf server") cfg.Server.WellKnownConfigRegister = wellknown.RegisterConfiguration diff --git a/service/pkg/server/testdata/all-no-config.yaml b/service/pkg/server/testdata/all-no-config.yaml index cdb3592cb..7692d2dda 100644 --- a/service/pkg/server/testdata/all-no-config.yaml +++ b/service/pkg/server/testdata/all-no-config.yaml @@ -98,6 +98,10 @@ server: allowcredentials: true # Sets the maximum age (in seconds) of a specific CORS preflight request maxage: 3600 + # Additive fields - append to base lists without replacing defaults + additionalmethods: [] + additionalheaders: [] + additionalexposedheaders: [] grpc: reflectionEnabled: true # Default is false port: 8080 diff --git a/test/service-cors.bats b/test/service-cors.bats index bab764c3b..8c781ca0d 100755 --- a/test/service-cors.bats +++ b/test/service-cors.bats @@ -110,3 +110,25 @@ fi # Verify Connect-Protocol-Version is in allowed headers [[ "$output" =~ [Aa]ccess-[Cc]ontrol-[Aa]llow-[Hh]eaders:.*[Cc]onnect-[Pp]rotocol-[Vv]ersion ]] } + +@test "CORS: verify all default headers are allowed" { + # Tests that all default headers from CORSConfig are allowed + # Default headers: Accept, Accept-Encoding, Authorization, Connect-Protocol-Version, + # Content-Length, Content-Type, Dpop, X-CSRF-Token, X-Requested-With, X-Rewrap-Additional-Context + # Note: Additional headers can be added via 'additionalheaders' config without replacing defaults + run curl -i -X OPTIONS $CURL_OPTIONS \ + -H "Origin: http://localhost:3000" \ + -H "Access-Control-Request-Method: POST" \ + -H "Access-Control-Request-Headers: accept,authorization,content-type,dpop,x-csrf-token" \ + ${BASE_URL}/policy.namespaces.NamespaceService/GetNamespace + + echo "$output" + + # Verify 200 OK response + [[ "$output" =~ "HTTP/2 200" ]] || [[ "$output" =~ "HTTP/1.1 200 OK" ]] + + # Verify key default headers are in allowed headers + [[ "$output" =~ [Aa]ccess-[Cc]ontrol-[Aa]llow-[Hh]eaders:.*[Aa]uthorization ]] + [[ "$output" =~ [Aa]ccess-[Cc]ontrol-[Aa]llow-[Hh]eaders:.*[Cc]ontent-[Tt]ype ]] + [[ "$output" =~ [Aa]ccess-[Cc]ontrol-[Aa]llow-[Hh]eaders:.*[Dd]pop ]] +}