diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index ea71583ca..f6458256e 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -4,28 +4,18 @@ import ( "context" "fmt" "strings" + "time" "github.com/spf13/cobra" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/authz" cfg "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/container" - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/environment" - "github.com/stacklok/toolhive/pkg/ignore" + "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" - "github.com/stacklok/toolhive/pkg/process" - "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/runner/retriever" "github.com/stacklok/toolhive/pkg/telemetry" "github.com/stacklok/toolhive/pkg/transport" - "github.com/stacklok/toolhive/pkg/transport/types" -) - -const ( - defaultTransportType = "streamable-http" ) // RunFlags holds the configuration for running MCP servers @@ -104,6 +94,161 @@ type RunFlags struct { OAuthParams map[string]string } +// GetEnableRemoteAuth returns whether remote authentication is enabled +func (r *RemoteAuthFlags) GetEnableRemoteAuth() bool { return r.EnableRemoteAuth } + +// GetRemoteAuthClientID returns the remote authentication client ID +func (r *RemoteAuthFlags) GetRemoteAuthClientID() string { return r.RemoteAuthClientID } + +// GetRemoteAuthClientSecret returns the remote authentication client secret +func (r *RemoteAuthFlags) GetRemoteAuthClientSecret() string { return r.RemoteAuthClientSecret } + +// GetRemoteAuthClientSecretFile returns the remote authentication client secret file path +func (r *RemoteAuthFlags) GetRemoteAuthClientSecretFile() string { return r.RemoteAuthClientSecretFile } + +// GetRemoteAuthScopes returns the remote authentication scopes +func (r *RemoteAuthFlags) GetRemoteAuthScopes() []string { return r.RemoteAuthScopes } + +// GetRemoteAuthSkipBrowser returns whether to skip browser for remote authentication +func (r *RemoteAuthFlags) GetRemoteAuthSkipBrowser() bool { return r.RemoteAuthSkipBrowser } + +// GetRemoteAuthTimeout returns the remote authentication timeout +func (r *RemoteAuthFlags) GetRemoteAuthTimeout() time.Duration { return r.RemoteAuthTimeout } + +// GetRemoteAuthCallbackPort returns the remote authentication callback port +func (r *RemoteAuthFlags) GetRemoteAuthCallbackPort() int { return r.RemoteAuthCallbackPort } + +// GetRemoteAuthIssuer returns the remote authentication issuer +func (r *RemoteAuthFlags) GetRemoteAuthIssuer() string { return r.RemoteAuthIssuer } + +// GetRemoteAuthAuthorizeURL returns the remote authentication authorize URL +func (r *RemoteAuthFlags) GetRemoteAuthAuthorizeURL() string { return r.RemoteAuthAuthorizeURL } + +// GetRemoteAuthTokenURL returns the remote authentication token URL +func (r *RemoteAuthFlags) GetRemoteAuthTokenURL() string { return r.RemoteAuthTokenURL } + +// GetName returns the server name +func (r *RunFlags) GetName() string { return r.Name } + +// GetGroup returns the server group +func (r *RunFlags) GetGroup() string { return r.Group } + +// GetTransport returns the transport type +func (r *RunFlags) GetTransport() string { return r.Transport } + +// GetProxyMode returns the proxy mode +func (r *RunFlags) GetProxyMode() string { return r.ProxyMode } + +// GetHost returns the host +func (r *RunFlags) GetHost() string { return r.Host } + +// GetProxyPort returns the proxy port +func (r *RunFlags) GetProxyPort() int { return r.ProxyPort } + +// GetTargetPort returns the target port +func (r *RunFlags) GetTargetPort() int { return r.TargetPort } + +// GetTargetHost returns the target host +func (r *RunFlags) GetTargetHost() string { return r.TargetHost } + +// GetEnv returns the environment variables +func (r *RunFlags) GetEnv() []string { return r.Env } + +// GetVolumes returns the volumes +func (r *RunFlags) GetVolumes() []string { return r.Volumes } + +// GetSecrets returns the secrets +func (r *RunFlags) GetSecrets() []string { return r.Secrets } + +// GetEnvFile returns the environment file path +func (r *RunFlags) GetEnvFile() string { return r.EnvFile } + +// GetEnvFileDir returns the environment file directory +func (r *RunFlags) GetEnvFileDir() string { return r.EnvFileDir } + +// GetPermissionProfile returns the permission profile +func (r *RunFlags) GetPermissionProfile() string { return r.PermissionProfile } + +// GetAuthzConfig returns the authorization config +func (r *RunFlags) GetAuthzConfig() string { return r.AuthzConfig } + +// GetAuditConfig returns the audit config +func (r *RunFlags) GetAuditConfig() string { return r.AuditConfig } + +// GetEnableAudit returns whether audit is enabled +func (r *RunFlags) GetEnableAudit() bool { return r.EnableAudit } + +// GetCACertPath returns the CA certificate path +func (r *RunFlags) GetCACertPath() string { return r.CACertPath } + +// GetVerifyImage returns the image verification setting +func (r *RunFlags) GetVerifyImage() string { return r.VerifyImage } + +// GetRemoteURL returns the remote URL +func (r *RunFlags) GetRemoteURL() string { return r.RemoteURL } + +// GetRemoteAuthFlags returns the remote authentication flags +func (r *RunFlags) GetRemoteAuthFlags() runner.RemoteAuthFlagsInterface { return &r.RemoteAuthFlags } + +// GetOAuthParams returns the OAuth parameters +func (r *RunFlags) GetOAuthParams() map[string]string { return r.OAuthParams } + +// GetIsolateNetwork returns whether network isolation is enabled +func (r *RunFlags) GetIsolateNetwork() bool { return r.IsolateNetwork } + +// GetLabels returns the labels +func (r *RunFlags) GetLabels() []string { return r.Labels } + +// GetToolsFilter returns the tools filter +func (r *RunFlags) GetToolsFilter() []string { return r.ToolsFilter } + +// GetForeground returns whether to run in foreground +func (r *RunFlags) GetForeground() bool { return r.Foreground } + +// GetThvCABundle returns the THV CA bundle +func (r *RunFlags) GetThvCABundle() string { return r.ThvCABundle } + +// GetJWKSAuthTokenFile returns the JWKS auth token file +func (r *RunFlags) GetJWKSAuthTokenFile() string { return r.JWKSAuthTokenFile } + +// GetJWKSAllowPrivateIP returns whether to allow private IP for JWKS +func (r *RunFlags) GetJWKSAllowPrivateIP() bool { return r.JWKSAllowPrivateIP } + +// GetResourceURL returns the resource URL +func (r *RunFlags) GetResourceURL() string { return r.ResourceURL } + +// GetOtelEndpoint returns the OpenTelemetry endpoint +func (r *RunFlags) GetOtelEndpoint() string { return r.OtelEndpoint } + +// GetOtelServiceName returns the OpenTelemetry service name +func (r *RunFlags) GetOtelServiceName() string { return r.OtelServiceName } + +// GetOtelSamplingRate returns the OpenTelemetry sampling rate +func (r *RunFlags) GetOtelSamplingRate() float64 { return r.OtelSamplingRate } + +// GetOtelHeaders returns the OpenTelemetry headers +func (r *RunFlags) GetOtelHeaders() []string { return r.OtelHeaders } + +// GetOtelInsecure returns whether OpenTelemetry is insecure +func (r *RunFlags) GetOtelInsecure() bool { return r.OtelInsecure } + +// GetOtelEnablePrometheusMetricsPath returns whether Prometheus metrics path is enabled +func (r *RunFlags) GetOtelEnablePrometheusMetricsPath() bool { + return r.OtelEnablePrometheusMetricsPath +} + +// GetOtelEnvironmentVariables returns the OpenTelemetry environment variables +func (r *RunFlags) GetOtelEnvironmentVariables() []string { return r.OtelEnvironmentVariables } + +// GetK8sPodPatch returns the Kubernetes pod patch +func (r *RunFlags) GetK8sPodPatch() string { return r.K8sPodPatch } + +// GetIgnoreGlobally returns whether to ignore globally +func (r *RunFlags) GetIgnoreGlobally() bool { return r.IgnoreGlobally } + +// GetPrintOverlays returns whether to print overlays +func (r *RunFlags) GetPrintOverlays() bool { return r.PrintOverlays } + // AddRunFlags adds all the run flags to a command func AddRunFlags(cmd *cobra.Command, config *RunFlags) { cmd.Flags().StringVar(&config.Transport, "transport", "", "Transport mode (sse, streamable-http or stdio)") @@ -207,7 +352,7 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { "Debug: show resolved container paths for tmpfs overlays") } -// BuildRunnerConfig creates a runner.RunConfig from the configuration +// BuildRunnerConfig creates a runner.RunConfig from the configuration using the new PreRunConfig approach func BuildRunnerConfig( ctx context.Context, runFlags *RunFlags, @@ -231,38 +376,35 @@ func BuildRunnerConfig( // Setup telemetry configuration telemetryConfig := setupTelemetryConfiguration(cmd, runFlags) - // Setup runtime and validation - rt, envVarValidator, err := setupRuntimeAndValidation(ctx) + // Parse and classify the input using PreRunConfig + preConfig, err := runner.ParsePreRunConfig(serverOrImage, runFlags.FromConfig) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse run configuration: %w", err) } - // If --remote flag is provided, use it as the serverOrImage - if runFlags.RemoteURL != "" { - return buildRunnerConfig(ctx, runFlags, cmdArgs, debugMode, validatedHost, rt, runFlags.RemoteURL, nil, - nil, envVarValidator, oidcConfig, telemetryConfig) - } + logger.Debugf("Parsed PreRunConfig: %s", preConfig.String()) - // Handle image retrieval - imageURL, serverMetadata, err := handleImageRetrieval(ctx, serverOrImage, runFlags) + // Create the transformer + transformer, err := runner.NewPreRunTransformer(ctx) if err != nil { - return nil, err - } - - // Validate and setup proxy mode - if err := validateAndSetupProxyMode(runFlags); err != nil { - return nil, err + return nil, fmt.Errorf("failed to create transformer: %w", err) } - // Parse environment variables - envVars, err := environment.ParseEnvironmentVariables(runFlags.Env) + // Transform PreRunConfig to RunConfig + runConfig, err := transformer.TransformToRunConfig( + preConfig, + runFlags, // runFlags implements RunFlagsInterface + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + ) if err != nil { - return nil, fmt.Errorf("failed to parse environment variables: %v", err) + return nil, fmt.Errorf("failed to transform to run config: %w", err) } - // Build the runner config - return buildRunnerConfig(ctx, runFlags, cmdArgs, debugMode, validatedHost, rt, imageURL, serverMetadata, - envVars, envVarValidator, oidcConfig, telemetryConfig) + return runConfig, nil } // setupOIDCConfiguration sets up OIDC configuration and validates URLs @@ -295,242 +437,6 @@ func setupTelemetryConfiguration(cmd *cobra.Command, runFlags *RunFlags) *teleme finalOtelEnvironmentVariables) } -// setupRuntimeAndValidation creates container runtime and selects environment variable validator -func setupRuntimeAndValidation(ctx context.Context) (runtime.Deployer, runner.EnvVarValidator, error) { - rt, err := container.NewFactory().Create(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to create container runtime: %v", err) - } - - var envVarValidator runner.EnvVarValidator - if process.IsDetached() || runtime.IsKubernetesRuntime() { - envVarValidator = &runner.DetachedEnvVarValidator{} - } else { - envVarValidator = &runner.CLIEnvVarValidator{} - } - - return rt, envVarValidator, nil -} - -// handleImageRetrieval handles image retrieval and metadata fetching -func handleImageRetrieval( - ctx context.Context, - serverOrImage string, - runFlags *RunFlags, -) ( - string, - registry.ServerMetadata, - error, -) { - - // Try to get server from registry (container or remote) or direct URL - imageURL, serverMetadata, err := retriever.GetMCPServer( - ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage) - if err != nil { - return "", nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err) - } - - // Check if we have a remote server - if serverMetadata != nil && serverMetadata.IsRemote() { - return imageURL, serverMetadata, nil - } - - // Only pull image if we are not running in Kubernetes mode. - // This split will go away if we implement a separate command or binary - // for running MCP servers in Kubernetes. - if !runtime.IsKubernetesRuntime() { - // Take the MCP server we were supplied and either fetch the image, or - // build it from a protocol scheme. If the server URI refers to an image - // in our trusted registry, we will also fetch the image metadata. - if serverMetadata != nil { - return imageURL, serverMetadata, nil - } - } - return serverOrImage, nil, nil -} - -// validateAndSetupProxyMode validates and sets default proxy mode if needed -func validateAndSetupProxyMode(runFlags *RunFlags) error { - if !types.IsValidProxyMode(runFlags.ProxyMode) { - if runFlags.ProxyMode == "" { - runFlags.ProxyMode = types.ProxyModeSSE.String() // default to SSE for backward compatibility - } else { - return fmt.Errorf("invalid value for --proxy-mode: %s", runFlags.ProxyMode) - } - } - return nil -} - -// buildRunnerConfig creates the final RunnerConfig using the builder pattern -func buildRunnerConfig( - ctx context.Context, - runFlags *RunFlags, - cmdArgs []string, - debugMode bool, - validatedHost string, - rt runtime.Deployer, - imageURL string, - serverMetadata registry.ServerMetadata, - envVars map[string]string, - envVarValidator runner.EnvVarValidator, - oidcConfig *auth.TokenValidatorConfig, - telemetryConfig *telemetry.Config, -) (*runner.RunConfig, error) { - // Determine transport type - transportType := defaultTransportType - if runFlags.Transport != "" { - transportType = runFlags.Transport - } else if serverMetadata != nil { - transportType = serverMetadata.GetTransport() - } - // Create a builder for the RunConfig - builder := runner.NewRunConfigBuilder(). - WithRuntime(rt). - WithCmdArgs(cmdArgs). - WithName(runFlags.Name). - WithImage(imageURL). - WithRemoteURL(runFlags.RemoteURL). - WithHost(validatedHost). - WithTargetHost(runFlags.TargetHost). - WithDebug(debugMode). - WithVolumes(runFlags.Volumes). - WithSecrets(runFlags.Secrets). - WithAuthzConfigPath(runFlags.AuthzConfig). - WithAuditConfigPath(runFlags.AuditConfig). - WithPermissionProfileNameOrPath(runFlags.PermissionProfile). - WithNetworkIsolation(runFlags.IsolateNetwork). - WithK8sPodPatch(runFlags.K8sPodPatch). - WithProxyMode(types.ProxyMode(runFlags.ProxyMode)). - WithTransportAndPorts(transportType, runFlags.ProxyPort, runFlags.TargetPort). - WithAuditEnabled(runFlags.EnableAudit, runFlags.AuditConfig). - WithLabels(runFlags.Labels). - WithGroup(runFlags.Group). - WithIgnoreConfig(&ignore.Config{ - LoadGlobal: runFlags.IgnoreGlobally, - PrintOverlays: runFlags.PrintOverlays, - }) - - // Configure middleware from flags - builder = builder.WithMiddlewareFromFlags( - oidcConfig, - runFlags.ToolsFilter, - telemetryConfig, - runFlags.AuthzConfig, - runFlags.EnableAudit, - runFlags.AuditConfig, - runFlags.Name, - runFlags.Transport, - ) - - if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok { - if remoteAuthConfig := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata); remoteAuthConfig != nil { - builder = builder.WithRemoteAuth(remoteAuthConfig) - } - } - if runFlags.RemoteURL != "" { - if remoteAuthConfig := getRemoteAuthFromRunFlags(runFlags); remoteAuthConfig != nil { - builder = builder.WithRemoteAuth(remoteAuthConfig) - } - } - - // Load authz config if path is provided - if runFlags.AuthzConfig != "" { - if authzConfigData, err := authz.LoadConfig(runFlags.AuthzConfig); err == nil { - builder = builder.WithAuthzConfig(authzConfigData) - } - // Note: Path is already set via WithAuthzConfigPath above - } - - // Get OIDC and telemetry values for legacy configuration - oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret := extractOIDCValues(oidcConfig) - finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := extractTelemetryValues(telemetryConfig) - - // Set additional configurations that are still needed in old format for other parts of the system - builder = builder.WithOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret, - runFlags.ThvCABundle, runFlags.JWKSAuthTokenFile, runFlags.ResourceURL, runFlags.JWKSAllowPrivateIP). - WithTelemetryConfig(finalOtelEndpoint, runFlags.OtelEnablePrometheusMetricsPath, runFlags.OtelServiceName, - finalOtelSamplingRate, runFlags.OtelHeaders, runFlags.OtelInsecure, finalOtelEnvironmentVariables). - WithToolsFilter(runFlags.ToolsFilter) - - imageMetadata, _ := serverMetadata.(*registry.ImageMetadata) - // Process environment files - var err error - if runFlags.EnvFile != "" { - builder, err = builder.WithEnvFile(runFlags.EnvFile) - if err != nil { - return nil, fmt.Errorf("failed to process env file %s: %v", runFlags.EnvFile, err) - } - } - if runFlags.EnvFileDir != "" { - builder, err = builder.WithEnvFilesFromDirectory(runFlags.EnvFileDir) - if err != nil { - return nil, fmt.Errorf("failed to process env files from directory %s: %v", runFlags.EnvFileDir, err) - } - } - - return builder.Build(ctx, imageMetadata, envVars, envVarValidator) -} - -// extractOIDCValues extracts OIDC values from the OIDC config for legacy configuration -func extractOIDCValues(config *auth.TokenValidatorConfig) (string, string, string, string, string, string) { - if config == nil { - return "", "", "", "", "", "" - } - return config.Issuer, config.Audience, config.JWKSURL, config.IntrospectionURL, config.ClientID, config.ClientSecret -} - -// extractTelemetryValues extracts telemetry values from the telemetry config for legacy configuration -func extractTelemetryValues(config *telemetry.Config) (string, float64, []string) { - if config == nil { - return "", 0.0, nil - } - return config.Endpoint, config.SamplingRate, config.EnvironmentVariables -} - -// getRemoteAuthFromRemoteServerMetadata creates RemoteAuthConfig from RemoteServerMetadata -func getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata *registry.RemoteServerMetadata) *runner.RemoteAuthConfig { - if remoteServerMetadata == nil { - return nil - } - - if remoteServerMetadata.OAuthConfig != nil { - return &runner.RemoteAuthConfig{ - ClientID: runFlags.RemoteAuthFlags.RemoteAuthClientID, - ClientSecret: runFlags.RemoteAuthFlags.RemoteAuthClientSecret, - Scopes: remoteServerMetadata.OAuthConfig.Scopes, - SkipBrowser: runFlags.RemoteAuthFlags.RemoteAuthSkipBrowser, - Timeout: runFlags.RemoteAuthFlags.RemoteAuthTimeout, - CallbackPort: remoteServerMetadata.OAuthConfig.CallbackPort, - Issuer: remoteServerMetadata.OAuthConfig.Issuer, - AuthorizeURL: remoteServerMetadata.OAuthConfig.AuthorizeURL, - TokenURL: remoteServerMetadata.OAuthConfig.TokenURL, - OAuthParams: remoteServerMetadata.OAuthConfig.OAuthParams, - Headers: remoteServerMetadata.Headers, - EnvVars: remoteServerMetadata.EnvVars, - } - } - return nil -} - -// getRemoteAuthFromRunFlags creates RemoteAuthConfig from RunFlags -func getRemoteAuthFromRunFlags(runFlags *RunFlags) *runner.RemoteAuthConfig { - if runFlags.RemoteAuthFlags.EnableRemoteAuth || runFlags.RemoteAuthFlags.RemoteAuthClientID != "" { - return &runner.RemoteAuthConfig{ - ClientID: runFlags.RemoteAuthFlags.RemoteAuthClientID, - ClientSecret: runFlags.RemoteAuthFlags.RemoteAuthClientSecret, - Scopes: runFlags.RemoteAuthFlags.RemoteAuthScopes, - SkipBrowser: runFlags.RemoteAuthFlags.RemoteAuthSkipBrowser, - Timeout: runFlags.RemoteAuthFlags.RemoteAuthTimeout, - CallbackPort: runFlags.RemoteAuthFlags.RemoteAuthCallbackPort, - Issuer: runFlags.RemoteAuthFlags.RemoteAuthIssuer, - AuthorizeURL: runFlags.RemoteAuthFlags.RemoteAuthAuthorizeURL, - TokenURL: runFlags.RemoteAuthFlags.RemoteAuthTokenURL, - OAuthParams: runFlags.OAuthParams, - } - } - return nil -} - // getOidcFromFlags extracts OIDC configuration from command flags func getOidcFromFlags(cmd *cobra.Command) (string, string, string, string, string, string) { oidcIssuer := GetStringFlagOrEmpty(cmd, "oidc-issuer") diff --git a/pkg/runner/prerun_config.go b/pkg/runner/prerun_config.go new file mode 100644 index 000000000..9adebf69f --- /dev/null +++ b/pkg/runner/prerun_config.go @@ -0,0 +1,233 @@ +package runner + +import ( + "fmt" + "net/url" + + nameref "github.com/google/go-containerregistry/pkg/name" + + "github.com/stacklok/toolhive/pkg/container/templates" + "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/registry" +) + +// PreRunConfigType represents the different types of MCP server sources +type PreRunConfigType string + +const ( + // PreRunConfigTypeRegistry indicates the source is a server name from the registry + PreRunConfigTypeRegistry PreRunConfigType = "registry" + // PreRunConfigTypeContainerImage indicates the source is a direct container image reference + PreRunConfigTypeContainerImage PreRunConfigType = "container_image" + // PreRunConfigTypeProtocolScheme indicates the source is a protocol scheme (uvx://, npx://, go://) + PreRunConfigTypeProtocolScheme PreRunConfigType = "protocol_scheme" + // PreRunConfigTypeRemoteURL indicates the source is a remote HTTP/HTTPS URL + PreRunConfigTypeRemoteURL PreRunConfigType = "remote_url" + // PreRunConfigTypeConfigFile indicates the source is a configuration file + PreRunConfigTypeConfigFile PreRunConfigType = "config_file" +) + +// PreRunConfig represents the parsed and classified input before building a RunConfig +type PreRunConfig struct { + // Type indicates what kind of source this is + Type PreRunConfigType + + // Source is the original input string + Source string + + // ParsedSource contains type-specific parsed information + ParsedSource interface{} + + // Metadata contains any additional metadata discovered during parsing + Metadata map[string]interface{} +} + +// RegistrySource represents a server from the registry +type RegistrySource struct { + ServerName string + IsRemote bool +} + +// ContainerImageSource represents a direct container image reference +type ContainerImageSource struct { + ImageRef string + Registry string + Name string + Tag string +} + +// ProtocolSchemeSource represents a protocol scheme build +type ProtocolSchemeSource struct { + ProtocolTransportType templates.TransportType // TransportTypeUVX, TransportTypeNPX, TransportTypeGO + Package string // everything after the :// + IsLocalPath bool // true for go://./local-path +} + +// RemoteURLSource represents a remote MCP server URL +type RemoteURLSource struct { + URL string + ParsedURL *url.URL +} + +// ConfigFileSource represents a configuration file +type ConfigFileSource struct { + FilePath string +} + +// ParsePreRunConfig analyzes the input and creates a PreRunConfig +func ParsePreRunConfig(source string, fromConfigPath string) (*PreRunConfig, error) { + preConfig := &PreRunConfig{ + Source: source, + Metadata: make(map[string]interface{}), + } + + // 1. Check if loading from config file + if fromConfigPath != "" { + preConfig.Type = PreRunConfigTypeConfigFile + preConfig.ParsedSource = &ConfigFileSource{ + FilePath: fromConfigPath, + } + return preConfig, nil + } + + // 2. Check if it's a remote URL + if networking.IsURL(source) { + parsedURL, err := url.Parse(source) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + preConfig.Type = PreRunConfigTypeRemoteURL + preConfig.ParsedSource = &RemoteURLSource{ + URL: source, + ParsedURL: parsedURL, + } + return preConfig, nil + } + + // 3. Check if it's a protocol scheme + if IsImageProtocolScheme(source) { + transportType, packageName, err := parseProtocolScheme(source) + if err != nil { + return nil, fmt.Errorf("failed to parse protocol scheme: %w", err) + } + + isLocal := transportType == templates.TransportTypeGO && isLocalGoPath(packageName) + + preConfig.Type = PreRunConfigTypeProtocolScheme + preConfig.ParsedSource = &ProtocolSchemeSource{ + ProtocolTransportType: transportType, + Package: packageName, + IsLocalPath: isLocal, + } + return preConfig, nil + } + + // 4. Try to find in registry + provider, err := registry.GetDefaultProvider() + if err == nil { + server, err := provider.GetServer(source) + if err == nil { + preConfig.Type = PreRunConfigTypeRegistry + preConfig.ParsedSource = &RegistrySource{ + ServerName: source, + IsRemote: server.IsRemote(), + } + return preConfig, nil + } + } + + // 5. Default to container image reference + registryName, name, tag, err := parseContainerImageRef(source) + if err != nil { + return nil, fmt.Errorf("failed to parse container image reference: %w", err) + } + + preConfig.Type = PreRunConfigTypeContainerImage + preConfig.ParsedSource = &ContainerImageSource{ + ImageRef: source, + Registry: registryName, + Name: name, + Tag: tag, + } + + return preConfig, nil +} + +// parseContainerImageRef parses a container image reference using go-containerregistry +func parseContainerImageRef(imageRef string) (registryName, name, tag string, err error) { + ref, err := nameref.ParseReference(imageRef) + if err != nil { + return "", "", "", fmt.Errorf("invalid image reference: %w", err) + } + + registryName = ref.Context().RegistryStr() + name = ref.Context().RepositoryStr() + + // Check if it's a tagged reference or digest + if taggedRef, ok := ref.(nameref.Tag); ok { + tag = taggedRef.TagStr() + } else if digestRef, ok := ref.(nameref.Digest); ok { + tag = digestRef.DigestStr() + } else { + tag = "latest" + } + + return registryName, name, tag, nil +} + +// String returns a human-readable description of the PreRunConfig +func (p *PreRunConfig) String() string { + switch p.Type { + case PreRunConfigTypeRegistry: + src := p.ParsedSource.(*RegistrySource) + if src.IsRemote { + return fmt.Sprintf("Registry remote server: %s", src.ServerName) + } + return fmt.Sprintf("Registry container server: %s", src.ServerName) + case PreRunConfigTypeContainerImage: + src := p.ParsedSource.(*ContainerImageSource) + return fmt.Sprintf("Container image: %s:%s", src.Name, src.Tag) + case PreRunConfigTypeProtocolScheme: + src := p.ParsedSource.(*ProtocolSchemeSource) + if src.IsLocalPath { + return fmt.Sprintf("Protocol scheme %s (local): %s", src.ProtocolTransportType, src.Package) + } + return fmt.Sprintf("Protocol scheme %s: %s", src.ProtocolTransportType, src.Package) + case PreRunConfigTypeRemoteURL: + src := p.ParsedSource.(*RemoteURLSource) + return fmt.Sprintf("Remote URL: %s", src.URL) + case PreRunConfigTypeConfigFile: + src := p.ParsedSource.(*ConfigFileSource) + return fmt.Sprintf("Config file: %s", src.FilePath) + default: + return fmt.Sprintf("Unknown type: %s", p.Source) + } +} + +// IsRemote returns true if this PreRunConfig represents a remote MCP server +func (p *PreRunConfig) IsRemote() bool { + switch p.Type { + case PreRunConfigTypeRemoteURL: + return true + case PreRunConfigTypeRegistry: + src := p.ParsedSource.(*RegistrySource) + return src.IsRemote + case PreRunConfigTypeContainerImage, PreRunConfigTypeProtocolScheme, PreRunConfigTypeConfigFile: + return false + default: + return false + } +} + +// RequiresBuild returns true if this PreRunConfig requires building a container image +func (p *PreRunConfig) RequiresBuild() bool { + switch p.Type { + case PreRunConfigTypeProtocolScheme: + return true + case PreRunConfigTypeRegistry, PreRunConfigTypeContainerImage, PreRunConfigTypeRemoteURL, PreRunConfigTypeConfigFile: + return false + default: + return false + } +} diff --git a/pkg/runner/prerun_config_test.go b/pkg/runner/prerun_config_test.go new file mode 100644 index 000000000..28be0f442 --- /dev/null +++ b/pkg/runner/prerun_config_test.go @@ -0,0 +1,327 @@ +package runner + +import ( + "testing" + + "github.com/stacklok/toolhive/pkg/container/templates" +) + +func TestParsePreRunConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + fromConfigPath string + expectedType PreRunConfigType + expectError bool + }{ + { + name: "Config file", + source: "", + fromConfigPath: "/path/to/config.json", + expectedType: PreRunConfigTypeConfigFile, + expectError: false, + }, + { + name: "Remote URL", + source: "https://example.com/mcp-server", + expectedType: PreRunConfigTypeRemoteURL, + expectError: false, + }, + { + name: "UVX protocol scheme", + source: "uvx://some-package", + expectedType: PreRunConfigTypeProtocolScheme, + expectError: false, + }, + { + name: "NPX protocol scheme", + source: "npx://@scope/package", + expectedType: PreRunConfigTypeProtocolScheme, + expectError: false, + }, + { + name: "Go protocol scheme", + source: "go://github.com/example/package", + expectedType: PreRunConfigTypeProtocolScheme, + expectError: false, + }, + { + name: "Go local path", + source: "go://./local-package", + expectedType: PreRunConfigTypeProtocolScheme, + expectError: false, + }, + { + name: "Container image", + source: "ghcr.io/example/mcp-server:latest", + expectedType: PreRunConfigTypeContainerImage, + expectError: false, + }, + { + name: "Simple image name", + source: "alpine:latest", + expectedType: PreRunConfigTypeContainerImage, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + preConfig, err := ParsePreRunConfig(tt.source, tt.fromConfigPath) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if preConfig.Type != tt.expectedType { + t.Errorf("Expected type %s, got %s", tt.expectedType, preConfig.Type) + } + + if preConfig.Source != tt.source { + t.Errorf("Expected source %s, got %s", tt.source, preConfig.Source) + } + }) + } +} + +func TestProtocolSchemeSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + expectedTransportType templates.TransportType + expectedPackage string + expectedIsLocal bool + }{ + { + name: "UVX package", + source: "uvx://some-package", + expectedTransportType: templates.TransportTypeUVX, + expectedPackage: "some-package", + expectedIsLocal: false, + }, + { + name: "NPX scoped package", + source: "npx://@scope/package", + expectedTransportType: templates.TransportTypeNPX, + expectedPackage: "@scope/package", + expectedIsLocal: false, + }, + { + name: "Go remote package", + source: "go://github.com/example/package", + expectedTransportType: templates.TransportTypeGO, + expectedPackage: "github.com/example/package", + expectedIsLocal: false, + }, + { + name: "Go local path", + source: "go://./local-package", + expectedTransportType: templates.TransportTypeGO, + expectedPackage: "./local-package", + expectedIsLocal: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + preConfig, err := ParsePreRunConfig(tt.source, "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if preConfig.Type != PreRunConfigTypeProtocolScheme { + t.Errorf("Expected type %s, got %s", PreRunConfigTypeProtocolScheme, preConfig.Type) + return + } + + src := preConfig.ParsedSource.(*ProtocolSchemeSource) + if src.ProtocolTransportType != tt.expectedTransportType { + t.Errorf("Expected transport type %s, got %s", tt.expectedTransportType, src.ProtocolTransportType) + } + + if src.Package != tt.expectedPackage { + t.Errorf("Expected package %s, got %s", tt.expectedPackage, src.Package) + } + + if src.IsLocalPath != tt.expectedIsLocal { + t.Errorf("Expected IsLocalPath %t, got %t", tt.expectedIsLocal, src.IsLocalPath) + } + }) + } +} + +func TestContainerImageSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + expectedRegistry string + expectedName string + expectedTag string + }{ + { + name: "Full image reference", + source: "ghcr.io/example/mcp-server:v1.0.0", + expectedRegistry: "ghcr.io", + expectedName: "example/mcp-server", + expectedTag: "v1.0.0", + }, + { + name: "Docker Hub image", + source: "alpine:latest", + expectedRegistry: "index.docker.io", + expectedName: "library/alpine", + expectedTag: "latest", + }, + { + name: "Image without tag", + source: "ubuntu", + expectedRegistry: "index.docker.io", + expectedName: "library/ubuntu", + expectedTag: "latest", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + preConfig, err := ParsePreRunConfig(tt.source, "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if preConfig.Type != PreRunConfigTypeContainerImage { + t.Errorf("Expected type %s, got %s", PreRunConfigTypeContainerImage, preConfig.Type) + return + } + + src := preConfig.ParsedSource.(*ContainerImageSource) + if src.Registry != tt.expectedRegistry { + t.Errorf("Expected registry %s, got %s", tt.expectedRegistry, src.Registry) + } + + if src.Name != tt.expectedName { + t.Errorf("Expected name %s, got %s", tt.expectedName, src.Name) + } + + if src.Tag != tt.expectedTag { + t.Errorf("Expected tag %s, got %s", tt.expectedTag, src.Tag) + } + }) + } +} + +func TestPreRunConfigString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + fromConfigPath string + expectedString string + }{ + { + name: "Config file", + source: "", + fromConfigPath: "/path/to/config.json", + expectedString: "Config file: /path/to/config.json", + }, + { + name: "Remote URL", + source: "https://example.com/mcp-server", + expectedString: "Remote URL: https://example.com/mcp-server", + }, + { + name: "Protocol scheme", + source: "uvx://some-package", + expectedString: "Protocol scheme uvx: some-package", + }, + { + name: "Go local path", + source: "go://./local-package", + expectedString: "Protocol scheme go (local): ./local-package", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + preConfig, err := ParsePreRunConfig(tt.source, tt.fromConfigPath) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + result := preConfig.String() + if result != tt.expectedString { + t.Errorf("Expected string %q, got %q", tt.expectedString, result) + } + }) + } +} + +func TestPreRunConfigHelperMethods(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + expectedRemote bool + expectedBuild bool + }{ + { + name: "Remote URL", + source: "https://example.com/mcp-server", + expectedRemote: true, + expectedBuild: false, + }, + { + name: "Protocol scheme", + source: "uvx://some-package", + expectedRemote: false, + expectedBuild: true, + }, + { + name: "Container image", + source: "alpine:latest", + expectedRemote: false, + expectedBuild: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + preConfig, err := ParsePreRunConfig(tt.source, "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if preConfig.IsRemote() != tt.expectedRemote { + t.Errorf("Expected IsRemote %t, got %t", tt.expectedRemote, preConfig.IsRemote()) + } + + if preConfig.RequiresBuild() != tt.expectedBuild { + t.Errorf("Expected RequiresBuild %t, got %t", tt.expectedBuild, preConfig.RequiresBuild()) + } + }) + } +} diff --git a/pkg/runner/prerun_transformer.go b/pkg/runner/prerun_transformer.go new file mode 100644 index 000000000..c6d55d901 --- /dev/null +++ b/pkg/runner/prerun_transformer.go @@ -0,0 +1,593 @@ +package runner + +import ( + "context" + "fmt" + "os" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/authz" + "github.com/stacklok/toolhive/pkg/container" + "github.com/stacklok/toolhive/pkg/container/images" + "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/environment" + "github.com/stacklok/toolhive/pkg/ignore" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/process" + "github.com/stacklok/toolhive/pkg/registry" + "github.com/stacklok/toolhive/pkg/telemetry" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +const ( + defaultTransportType = "streamable-http" +) + +// PreRunTransformer handles the transformation from PreRunConfig to RunConfig +type PreRunTransformer struct { + ctx context.Context + rt runtime.Deployer + envVarValidator EnvVarValidator + imageManager images.ImageManager +} + +// NewPreRunTransformer creates a new PreRunTransformer +func NewPreRunTransformer(ctx context.Context) (*PreRunTransformer, error) { + rt, err := container.NewFactory().Create(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create container runtime: %v", err) + } + + var envVarValidator EnvVarValidator + if process.IsDetached() || runtime.IsKubernetesRuntime() { + envVarValidator = &DetachedEnvVarValidator{} + } else { + envVarValidator = &CLIEnvVarValidator{} + } + + imageManager := images.NewImageManager(ctx) + + return &PreRunTransformer{ + ctx: ctx, + rt: rt, + envVarValidator: envVarValidator, + imageManager: imageManager, + }, nil +} + +// TransformToRunConfig converts a PreRunConfig to a RunConfig +func (t *PreRunTransformer) TransformToRunConfig( + preConfig *PreRunConfig, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) (*RunConfig, error) { + switch preConfig.Type { + case PreRunConfigTypeConfigFile: + return t.transformFromConfigFile(preConfig) + case PreRunConfigTypeRemoteURL: + return t.transformFromRemoteURL(preConfig, runFlags, cmdArgs, debugMode, validatedHost, oidcConfig, telemetryConfig) + case PreRunConfigTypeProtocolScheme: + return t.transformFromProtocolScheme(preConfig, runFlags, cmdArgs, debugMode, validatedHost, oidcConfig, telemetryConfig) + case PreRunConfigTypeRegistry: + return t.transformFromRegistry(preConfig, runFlags, cmdArgs, debugMode, validatedHost, oidcConfig, telemetryConfig) + case PreRunConfigTypeContainerImage: + return t.transformFromContainerImage(preConfig, runFlags, cmdArgs, debugMode, validatedHost, oidcConfig, telemetryConfig) + default: + return nil, fmt.Errorf("unsupported PreRunConfig type: %s", preConfig.Type) + } +} + +// transformFromConfigFile handles loading from a configuration file +func (t *PreRunTransformer) transformFromConfigFile(preConfig *PreRunConfig) (*RunConfig, error) { + src := preConfig.ParsedSource.(*ConfigFileSource) + + configFile, err := os.Open(src.FilePath) + if err != nil { + return nil, fmt.Errorf("failed to open configuration file '%s': %w", src.FilePath, err) + } + defer configFile.Close() + + runConfig, err := ReadJSON(configFile) + if err != nil { + return nil, fmt.Errorf("failed to parse configuration file: %w", err) + } + + // Set the runtime in the config + runConfig.Deployer = t.rt + + return runConfig, nil +} + +// transformFromRemoteURL handles remote MCP servers +func (t *PreRunTransformer) transformFromRemoteURL( + preConfig *PreRunConfig, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) (*RunConfig, error) { + src := preConfig.ParsedSource.(*RemoteURLSource) + + logger.Debugf("Creating RunConfig for remote URL: %s", src.URL) + + return t.buildRunConfigFromSource( + src.URL, // imageURL (will be empty for remote) + nil, // serverMetadata + runFlags, + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + src.URL, // remoteURL + ) +} + +// transformFromProtocolScheme handles protocol scheme builds +func (t *PreRunTransformer) transformFromProtocolScheme( + preConfig *PreRunConfig, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) (*RunConfig, error) { + src := preConfig.ParsedSource.(*ProtocolSchemeSource) + + logger.Debugf("Building container from protocol scheme: %s://%s", src.ProtocolTransportType, src.Package) + + // Build the container image from the protocol scheme + generatedImage, err := HandleProtocolScheme(t.ctx, t.imageManager, preConfig.Source, runFlags.GetCACertPath()) + if err != nil { + return nil, fmt.Errorf("failed to build image from protocol scheme: %w", err) + } + + logger.Debugf("Built image: %s", generatedImage) + + return t.buildRunConfigFromSource( + generatedImage, // imageURL + nil, // serverMetadata + runFlags, + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + "", // remoteURL + ) +} + +// transformFromRegistry handles registry server lookups +func (t *PreRunTransformer) transformFromRegistry( + preConfig *PreRunConfig, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) (*RunConfig, error) { + src := preConfig.ParsedSource.(*RegistrySource) + + provider, err := registry.GetDefaultProvider() + if err != nil { + return nil, fmt.Errorf("failed to get registry provider: %v", err) + } + + server, err := provider.GetServer(src.ServerName) + if err != nil { + return nil, fmt.Errorf("failed to get server from registry: %v", err) + } + + if server.IsRemote() { + // Remote server from registry + remoteServerMetadata, ok := server.(*registry.RemoteServerMetadata) + if !ok { + return nil, fmt.Errorf("server marked as remote but is not RemoteServerMetadata type") + } + logger.Infof("Found remote server in registry: %s -> %s", src.ServerName, remoteServerMetadata.URL) + return t.buildRunConfigFromSource( + "", // imageURL (empty for remote servers - no Docker image) + server, // serverMetadata + runFlags, + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + remoteServerMetadata.URL, // remoteURL from registry metadata + ) + } + + // Container server from registry + imageMetadata, err := provider.GetImageServer(src.ServerName) + if err != nil { + return nil, fmt.Errorf("failed to get image metadata from registry: %v", err) + } + + logger.Infof("Found container server in registry: %s -> %s", src.ServerName, imageMetadata.Image) + + // Pull the image if necessary + if err := t.pullImageIfNeeded(imageMetadata.Image); err != nil { + return nil, fmt.Errorf("failed to pull image: %v", err) + } + + return t.buildRunConfigFromSource( + imageMetadata.Image, // imageURL + imageMetadata, // serverMetadata + runFlags, + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + "", // remoteURL + ) +} + +// transformFromContainerImage handles direct container image references +func (t *PreRunTransformer) transformFromContainerImage( + preConfig *PreRunConfig, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) (*RunConfig, error) { + src := preConfig.ParsedSource.(*ContainerImageSource) + + logger.Debugf("Using direct container image: %s", src.ImageRef) + + // Pull the image if necessary + if err := t.pullImageIfNeeded(src.ImageRef); err != nil { + return nil, fmt.Errorf("failed to pull image: %v", err) + } + + return t.buildRunConfigFromSource( + src.ImageRef, // imageURL + nil, // serverMetadata + runFlags, + cmdArgs, + debugMode, + validatedHost, + oidcConfig, + telemetryConfig, + "", // remoteURL + ) +} + +// buildRunConfigFromSource builds the final RunConfig using the existing builder pattern +func (t *PreRunTransformer) buildRunConfigFromSource( + imageURL string, + serverMetadata registry.ServerMetadata, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost string, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, + remoteURL string, +) (*RunConfig, error) { + // Parse environment variables + envVars, err := environment.ParseEnvironmentVariables(runFlags.GetEnv()) + if err != nil { + return nil, fmt.Errorf("failed to parse environment variables: %v", err) + } + + // Determine transport type and proxy mode + transportType, proxyMode, err := t.determineTransportAndProxyMode(runFlags, serverMetadata) + if err != nil { + return nil, err + } + + // Create base builder + builder := t.createBaseBuilder(imageURL, runFlags, cmdArgs, debugMode, validatedHost, remoteURL, transportType, proxyMode) + + // Configure middleware and authentication + builder = t.configureMiddlewareAndAuth(builder, serverMetadata, runFlags, oidcConfig, telemetryConfig, remoteURL) + + // Configure legacy OIDC and telemetry + builder = t.configureLegacyOIDCAndTelemetry(builder, runFlags, oidcConfig, telemetryConfig) + + // Process environment files + builder, err = t.processEnvironmentFiles(builder, runFlags) + if err != nil { + return nil, err + } + + imageMetadata, _ := serverMetadata.(*registry.ImageMetadata) + return builder.Build(t.ctx, imageMetadata, envVars, t.envVarValidator) +} + +// determineTransportAndProxyMode determines the transport type and validates proxy mode +func (*PreRunTransformer) determineTransportAndProxyMode( + runFlags RunFlagsInterface, + serverMetadata registry.ServerMetadata, +) (string, string, error) { + // Determine transport type + transportType := defaultTransportType + if runFlags.GetTransport() != "" { + transportType = runFlags.GetTransport() + } else if serverMetadata != nil { + transportType = serverMetadata.GetTransport() + } + + // Validate and setup proxy mode + proxyMode := runFlags.GetProxyMode() + if !types.IsValidProxyMode(proxyMode) { + if proxyMode == "" { + proxyMode = types.ProxyModeSSE.String() // default to SSE for backward compatibility + } else { + return "", "", fmt.Errorf("invalid value for --proxy-mode: %s", proxyMode) + } + } + + return transportType, proxyMode, nil +} + +// createBaseBuilder creates the base RunConfig builder with core settings +func (t *PreRunTransformer) createBaseBuilder( + imageURL string, + runFlags RunFlagsInterface, + cmdArgs []string, + debugMode bool, + validatedHost, remoteURL, transportType, proxyMode string, +) *RunConfigBuilder { + return NewRunConfigBuilder(). + WithRuntime(t.rt). + WithCmdArgs(cmdArgs). + WithName(runFlags.GetName()). + WithImage(imageURL). + WithRemoteURL(remoteURL). + WithHost(validatedHost). + WithTargetHost(runFlags.GetTargetHost()). + WithDebug(debugMode). + WithVolumes(runFlags.GetVolumes()). + WithSecrets(runFlags.GetSecrets()). + WithAuthzConfigPath(runFlags.GetAuthzConfig()). + WithAuditConfigPath(runFlags.GetAuditConfig()). + WithPermissionProfileNameOrPath(runFlags.GetPermissionProfile()). + WithNetworkIsolation(runFlags.GetIsolateNetwork()). + WithK8sPodPatch(runFlags.GetK8sPodPatch()). + WithProxyMode(types.ProxyMode(proxyMode)). + WithTransportAndPorts(transportType, runFlags.GetProxyPort(), runFlags.GetTargetPort()). + WithAuditEnabled(runFlags.GetEnableAudit(), runFlags.GetAuditConfig()). + WithLabels(runFlags.GetLabels()). + WithGroup(runFlags.GetGroup()). + WithIgnoreConfig(&ignore.Config{ + LoadGlobal: runFlags.GetIgnoreGlobally(), + PrintOverlays: runFlags.GetPrintOverlays(), + }) +} + +// configureMiddlewareAndAuth configures middleware and authentication settings +func (*PreRunTransformer) configureMiddlewareAndAuth( + builder *RunConfigBuilder, + serverMetadata registry.ServerMetadata, + runFlags RunFlagsInterface, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, + remoteURL string, +) *RunConfigBuilder { + // Configure middleware from flags + builder = builder.WithMiddlewareFromFlags( + oidcConfig, + runFlags.GetToolsFilter(), + telemetryConfig, + runFlags.GetAuthzConfig(), + runFlags.GetEnableAudit(), + runFlags.GetAuditConfig(), + runFlags.GetName(), + runFlags.GetTransport(), + ) + + // Handle remote authentication if applicable + if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok { + if remoteAuthConfig := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata, runFlags); remoteAuthConfig != nil { + builder = builder.WithRemoteAuth(remoteAuthConfig) + } + } + if remoteURL != "" { + if remoteAuthConfig := getRemoteAuthFromRunFlags(runFlags); remoteAuthConfig != nil { + builder = builder.WithRemoteAuth(remoteAuthConfig) + } + } + + // Load authz config if path is provided + if runFlags.GetAuthzConfig() != "" { + if authzConfigData, err := authz.LoadConfig(runFlags.GetAuthzConfig()); err == nil { + builder = builder.WithAuthzConfig(authzConfigData) + } + // Note: Path is already set via WithAuthzConfigPath above + } + + return builder +} + +// configureLegacyOIDCAndTelemetry configures legacy OIDC and telemetry settings +func (*PreRunTransformer) configureLegacyOIDCAndTelemetry( + builder *RunConfigBuilder, + runFlags RunFlagsInterface, + oidcConfig *auth.TokenValidatorConfig, + telemetryConfig *telemetry.Config, +) *RunConfigBuilder { + // Get OIDC and telemetry values for legacy configuration + oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret := extractOIDCValues(oidcConfig) + finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := extractTelemetryValues(telemetryConfig) + + // Set additional configurations that are still needed in old format for other parts of the system + return builder.WithOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret, + runFlags.GetThvCABundle(), runFlags.GetJWKSAuthTokenFile(), runFlags.GetResourceURL(), runFlags.GetJWKSAllowPrivateIP()). + WithTelemetryConfig(finalOtelEndpoint, runFlags.GetOtelEnablePrometheusMetricsPath(), runFlags.GetOtelServiceName(), + finalOtelSamplingRate, runFlags.GetOtelHeaders(), runFlags.GetOtelInsecure(), finalOtelEnvironmentVariables). + WithToolsFilter(runFlags.GetToolsFilter()) +} + +// processEnvironmentFiles processes environment files and directories +func (*PreRunTransformer) processEnvironmentFiles( + builder *RunConfigBuilder, + runFlags RunFlagsInterface, +) (*RunConfigBuilder, error) { + var err error + + // Process environment files + if runFlags.GetEnvFile() != "" { + builder, err = builder.WithEnvFile(runFlags.GetEnvFile()) + if err != nil { + return nil, fmt.Errorf("failed to process env file %s: %v", runFlags.GetEnvFile(), err) + } + } + if runFlags.GetEnvFileDir() != "" { + builder, err = builder.WithEnvFilesFromDirectory(runFlags.GetEnvFileDir()) + if err != nil { + return nil, fmt.Errorf("failed to process env files from directory %s: %v", runFlags.GetEnvFileDir(), err) + } + } + + return builder, nil +} + +// pullImageIfNeeded pulls an image if it doesn't exist locally or has the latest tag +func (t *PreRunTransformer) pullImageIfNeeded(imageRef string) error { + if !runtime.IsKubernetesRuntime() { + return t.pullImage(imageRef) + } + return nil +} + +// pullImage is a copy of the pullImage function from retriever package +// since it's not exported. This pulls an image from a remote registry if it has the "latest" tag +// or if it doesn't exist locally. +func (t *PreRunTransformer) pullImage(image string) error { + // Check if the image has the "latest" tag + isLatestTag := t.hasLatestTag(image) + + if isLatestTag { + // For "latest" tag, try to pull first + logger.Infof("Image %s has 'latest' tag, pulling to ensure we have the most recent version...", image) + err := t.imageManager.PullImage(t.ctx, image) + if err != nil { + // Pull failed, check if it exists locally + logger.Infof("Pull failed, checking if image exists locally: %s", image) + imageExists, checkErr := t.imageManager.ImageExists(t.ctx, image) + if checkErr != nil { + return fmt.Errorf("failed to check if image exists: %v", checkErr) + } + + if imageExists { + logger.Debugf("Using existing local image: %s", image) + } else { + return fmt.Errorf("image not found: %s", image) + } + } else { + logger.Infof("Successfully pulled image: %s", image) + } + } else { + // For non-latest tags, check locally first + logger.Debugf("Checking if image exists locally: %s", image) + imageExists, err := t.imageManager.ImageExists(t.ctx, image) + logger.Debugf("ImageExists locally: %t", imageExists) + if err != nil { + return fmt.Errorf("failed to check if image exists locally: %v", err) + } + + if imageExists { + logger.Debugf("Using existing local image: %s", image) + } else { + // Image doesn't exist locally, try to pull + logger.Infof("Image %s not found locally, pulling...", image) + if err := t.imageManager.PullImage(t.ctx, image); err != nil { + return fmt.Errorf("image not found: %s", image) + } + logger.Infof("Successfully pulled image: %s", image) + } + } + + return nil +} + +// hasLatestTag checks if the given image reference has the "latest" tag or no tag (which defaults to "latest") +func (*PreRunTransformer) hasLatestTag(imageRef string) bool { + // We can reuse the existing logic from parseContainerImageRef + _, _, tag, err := parseContainerImageRef(imageRef) + if err != nil { + logger.Warnf("Warning: Failed to parse image reference: %v", err) + return false + } + return tag == "latest" +} + +// Helper functions that need to be implemented or imported from existing code + +// extractOIDCValues extracts OIDC values from the OIDC config for legacy configuration +func extractOIDCValues(config *auth.TokenValidatorConfig) (string, string, string, string, string, string) { + if config == nil { + return "", "", "", "", "", "" + } + return config.Issuer, config.Audience, config.JWKSURL, config.IntrospectionURL, config.ClientID, config.ClientSecret +} + +// extractTelemetryValues extracts telemetry values from the telemetry config for legacy configuration +func extractTelemetryValues(config *telemetry.Config) (string, float64, []string) { + if config == nil { + return "", 0.0, nil + } + return config.Endpoint, config.SamplingRate, config.EnvironmentVariables +} + +// getRemoteAuthFromRemoteServerMetadata creates RemoteAuthConfig from RemoteServerMetadata +func getRemoteAuthFromRemoteServerMetadata( + remoteServerMetadata *registry.RemoteServerMetadata, + runFlags RunFlagsInterface, +) *RemoteAuthConfig { + if remoteServerMetadata == nil { + return nil + } + + if remoteServerMetadata.OAuthConfig != nil { + remoteAuthFlags := runFlags.GetRemoteAuthFlags() + return &RemoteAuthConfig{ + ClientID: remoteAuthFlags.GetRemoteAuthClientID(), + ClientSecret: remoteAuthFlags.GetRemoteAuthClientSecret(), + Scopes: remoteServerMetadata.OAuthConfig.Scopes, + SkipBrowser: remoteAuthFlags.GetRemoteAuthSkipBrowser(), + Timeout: remoteAuthFlags.GetRemoteAuthTimeout(), + CallbackPort: remoteServerMetadata.OAuthConfig.CallbackPort, + Issuer: remoteServerMetadata.OAuthConfig.Issuer, + AuthorizeURL: remoteServerMetadata.OAuthConfig.AuthorizeURL, + TokenURL: remoteServerMetadata.OAuthConfig.TokenURL, + OAuthParams: remoteServerMetadata.OAuthConfig.OAuthParams, + Headers: remoteServerMetadata.Headers, + EnvVars: remoteServerMetadata.EnvVars, + } + } + return nil +} + +// getRemoteAuthFromRunFlags creates RemoteAuthConfig from RunFlags +func getRemoteAuthFromRunFlags(runFlags RunFlagsInterface) *RemoteAuthConfig { + remoteAuthFlags := runFlags.GetRemoteAuthFlags() + if remoteAuthFlags.GetEnableRemoteAuth() || remoteAuthFlags.GetRemoteAuthClientID() != "" { + return &RemoteAuthConfig{ + ClientID: remoteAuthFlags.GetRemoteAuthClientID(), + ClientSecret: remoteAuthFlags.GetRemoteAuthClientSecret(), + Scopes: remoteAuthFlags.GetRemoteAuthScopes(), + SkipBrowser: remoteAuthFlags.GetRemoteAuthSkipBrowser(), + Timeout: remoteAuthFlags.GetRemoteAuthTimeout(), + CallbackPort: remoteAuthFlags.GetRemoteAuthCallbackPort(), + Issuer: remoteAuthFlags.GetRemoteAuthIssuer(), + AuthorizeURL: remoteAuthFlags.GetRemoteAuthAuthorizeURL(), + TokenURL: remoteAuthFlags.GetRemoteAuthTokenURL(), + OAuthParams: runFlags.GetOAuthParams(), + } + } + return nil +} diff --git a/pkg/runner/prerun_transformer_test.go b/pkg/runner/prerun_transformer_test.go new file mode 100644 index 000000000..6f18e946c --- /dev/null +++ b/pkg/runner/prerun_transformer_test.go @@ -0,0 +1,159 @@ +package runner + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/registry" +) + +func TestRemoteServerMetadata_IsRemote(t *testing.T) { + t.Parallel() + + // Test that RemoteServerMetadata correctly identifies as remote + remoteServer := ®istry.RemoteServerMetadata{ + BaseServerMetadata: registry.BaseServerMetadata{ + Name: "test-remote-server", + Description: "Test remote server", + Transport: "sse", + Tools: []string{"test-tool"}, + Tier: "test", + Status: "active", + }, + URL: "https://example.com/mcp", + } + + assert.True(t, remoteServer.IsRemote(), "RemoteServerMetadata should identify as remote") + assert.Equal(t, "https://example.com/mcp", remoteServer.URL) + assert.Equal(t, "test-remote-server", remoteServer.GetName()) + assert.Equal(t, "sse", remoteServer.GetTransport()) +} + +func TestImageMetadata_IsRemote(t *testing.T) { + t.Parallel() + + // Test that ImageMetadata correctly identifies as not remote + containerServer := ®istry.ImageMetadata{ + BaseServerMetadata: registry.BaseServerMetadata{ + Name: "test-container-server", + Description: "Test container server", + Transport: "stdio", + Tools: []string{"test-tool"}, + Tier: "test", + Status: "active", + }, + Image: "docker.io/test/server:latest", + } + + assert.False(t, containerServer.IsRemote(), "ImageMetadata should not identify as remote") + assert.Equal(t, "docker.io/test/server:latest", containerServer.Image) + assert.Equal(t, "test-container-server", containerServer.GetName()) + assert.Equal(t, "stdio", containerServer.GetTransport()) +} + +func TestPreRunConfigType_Constants(t *testing.T) { + t.Parallel() + + // Test that the PreRunConfigType constants are defined correctly + assert.Equal(t, PreRunConfigType("registry"), PreRunConfigTypeRegistry) + assert.Equal(t, PreRunConfigType("container_image"), PreRunConfigTypeContainerImage) + assert.Equal(t, PreRunConfigType("protocol_scheme"), PreRunConfigTypeProtocolScheme) + assert.Equal(t, PreRunConfigType("remote_url"), PreRunConfigTypeRemoteURL) + assert.Equal(t, PreRunConfigType("config_file"), PreRunConfigTypeConfigFile) +} + +func TestRegistrySource_Structure(t *testing.T) { + t.Parallel() + + // Test the RegistrySource structure + registrySource := &RegistrySource{ + ServerName: "test-server", + IsRemote: true, + } + + assert.Equal(t, "test-server", registrySource.ServerName) + assert.True(t, registrySource.IsRemote) + + // Test with container server + containerSource := &RegistrySource{ + ServerName: "container-server", + IsRemote: false, + } + + assert.Equal(t, "container-server", containerSource.ServerName) + assert.False(t, containerSource.IsRemote) +} + +func TestPreRunConfig_Structure(t *testing.T) { + t.Parallel() + + // Test PreRunConfig structure for registry source + preConfig := &PreRunConfig{ + Type: PreRunConfigTypeRegistry, + Source: "test-server", + ParsedSource: &RegistrySource{ + ServerName: "test-server", + IsRemote: true, + }, + Metadata: map[string]interface{}{ + "discovered": true, + }, + } + + assert.Equal(t, PreRunConfigTypeRegistry, preConfig.Type) + assert.Equal(t, "test-server", preConfig.Source) + + registrySource, ok := preConfig.ParsedSource.(*RegistrySource) + require.True(t, ok, "ParsedSource should be a RegistrySource") + assert.Equal(t, "test-server", registrySource.ServerName) + assert.True(t, registrySource.IsRemote) + + assert.Equal(t, true, preConfig.Metadata["discovered"]) +} + +// Test the type assertion logic that was fixed in the transformer +func TestRemoteServerTypeAssertion(t *testing.T) { + t.Parallel() + + // Test successful type assertion + var serverMetadata registry.ServerMetadata = ®istry.RemoteServerMetadata{ + BaseServerMetadata: registry.BaseServerMetadata{ + Name: "remote-server", + }, + URL: "https://example.com/mcp", + } + + if serverMetadata.IsRemote() { + remoteServer, ok := serverMetadata.(*registry.RemoteServerMetadata) + require.True(t, ok, "Should be able to cast to RemoteServerMetadata") + assert.Equal(t, "https://example.com/mcp", remoteServer.URL) + } + + // Test failed type assertion (this is what our fix handles) + var invalidRemoteServer registry.ServerMetadata = &mockInvalidRemoteServer{} + + if invalidRemoteServer.IsRemote() { + _, ok := invalidRemoteServer.(*registry.RemoteServerMetadata) + assert.False(t, ok, "Should not be able to cast invalid server to RemoteServerMetadata") + // This is the error case our fix handles + } +} + +// Mock server that claims to be remote but isn't RemoteServerMetadata +// This tests the error case we fixed in the transformer +type mockInvalidRemoteServer struct{} + +func (*mockInvalidRemoteServer) GetName() string { return "invalid-remote" } +func (*mockInvalidRemoteServer) GetDescription() string { return "Invalid remote server" } +func (*mockInvalidRemoteServer) GetTier() string { return "test" } +func (*mockInvalidRemoteServer) GetStatus() string { return "active" } +func (*mockInvalidRemoteServer) GetTransport() string { return "sse" } +func (*mockInvalidRemoteServer) GetTools() []string { return []string{"test-tool"} } +func (*mockInvalidRemoteServer) GetMetadata() *registry.Metadata { return nil } +func (*mockInvalidRemoteServer) GetRepositoryURL() string { return "" } +func (*mockInvalidRemoteServer) GetTags() []string { return nil } +func (*mockInvalidRemoteServer) GetCustomMetadata() map[string]any { return nil } +func (*mockInvalidRemoteServer) IsRemote() bool { return true } // Claims to be remote +func (*mockInvalidRemoteServer) GetEnvVars() []*registry.EnvVar { return nil } diff --git a/pkg/runner/run_flags_interface.go b/pkg/runner/run_flags_interface.go new file mode 100644 index 000000000..96153119a --- /dev/null +++ b/pkg/runner/run_flags_interface.go @@ -0,0 +1,83 @@ +package runner + +import "time" + +// RemoteAuthFlagsInterface defines the interface for remote authentication flags +type RemoteAuthFlagsInterface interface { + GetEnableRemoteAuth() bool + GetRemoteAuthClientID() string + GetRemoteAuthClientSecret() string + GetRemoteAuthClientSecretFile() string + GetRemoteAuthScopes() []string + GetRemoteAuthSkipBrowser() bool + GetRemoteAuthTimeout() time.Duration + GetRemoteAuthCallbackPort() int + GetRemoteAuthIssuer() string + GetRemoteAuthAuthorizeURL() string + GetRemoteAuthTokenURL() string +} + +// RunFlagsInterface defines the interface for run flags that the PreRunTransformer needs +type RunFlagsInterface interface { + // Basic configuration + GetName() string + GetGroup() string + GetTransport() string + GetProxyMode() string + GetHost() string + GetProxyPort() int + GetTargetPort() int + GetTargetHost() string + + // Environment and volumes + GetEnv() []string + GetVolumes() []string + GetSecrets() []string + GetEnvFile() string + GetEnvFileDir() string + + // Security and permissions + GetPermissionProfile() string + GetAuthzConfig() string + GetAuditConfig() string + GetEnableAudit() bool + GetCACertPath() string + GetVerifyImage() string + + // Remote configuration + GetRemoteURL() string + GetRemoteAuthFlags() RemoteAuthFlagsInterface + GetOAuthParams() map[string]string + + // Network and isolation + GetIsolateNetwork() bool + GetLabels() []string + + // Tools and filtering + GetToolsFilter() []string + + // Execution mode + GetForeground() bool + + // OIDC configuration + GetThvCABundle() string + GetJWKSAuthTokenFile() string + GetJWKSAllowPrivateIP() bool + GetResourceURL() string + + // Telemetry configuration + GetOtelEndpoint() string + GetOtelServiceName() string + GetOtelSamplingRate() float64 + GetOtelHeaders() []string + GetOtelInsecure() bool + GetOtelEnablePrometheusMetricsPath() bool + GetOtelEnvironmentVariables() []string + + // Kubernetes specific + GetK8sPodPatch() string + + // Ignore functionality + GetIgnoreGlobally() bool + GetPrintOverlays() bool +}