diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index c98862c11c0f..03909c7a326d 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -332,12 +332,14 @@ impl ClaudeCodeProvider { cmd.arg("-p") .arg(messages_json.to_string()) .arg("--system-prompt") - .arg(&filtered_system) - .arg("--model") - .arg(&self.model.model_name) - .arg("--verbose") - .arg("--output-format") - .arg("json"); + .arg(&filtered_system); + + // Only pass model parameter if it's in the known models list + if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { + cmd.arg("--model").arg(&self.model.model_name); + } + + cmd.arg("--verbose").arg("--output-format").arg("json"); // Add permission mode based on GOOSE_MODE setting let config = Config::global(); @@ -541,4 +543,24 @@ mod tests { std::env::remove_var("GOOSE_MODE"); } + + #[test] + fn test_claude_code_invalid_model_no_fallback() { + // Test that an invalid model is kept as-is (no fallback) + let invalid_model = ModelConfig::new("invalid-model".to_string()); + let provider = ClaudeCodeProvider::from_env(invalid_model).unwrap(); + let config = provider.get_model_config(); + + assert_eq!(config.model_name, "invalid-model"); + } + + #[test] + fn test_claude_code_valid_model() { + // Test that a valid model is preserved + let valid_model = ModelConfig::new("sonnet".to_string()); + let provider = ClaudeCodeProvider::from_env(valid_model).unwrap(); + let config = provider.get_model_config(); + + assert_eq!(config.model_name, "sonnet"); + } } diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index d1c60517ae2f..387af4987d08 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -170,11 +170,13 @@ impl GeminiCliProvider { } let mut cmd = Command::new(&self.command); - cmd.arg("-m") - .arg(&self.model.model_name) - .arg("-p") - .arg(&full_prompt) - .arg("--yolo"); + + // Only pass model parameter if it's in the known models list + if GEMINI_CLI_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { + cmd.arg("-m").arg(&self.model.model_name); + } + + cmd.arg("-p").arg(&full_prompt).arg("--yolo"); cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); @@ -370,4 +372,24 @@ mod tests { // Context limit should be set by the ModelConfig assert!(config.context_limit() > 0); } + + #[test] + fn test_gemini_cli_invalid_model_no_fallback() { + // Test that an invalid model is kept as-is (no fallback) + let invalid_model = ModelConfig::new("invalid-model".to_string()); + let provider = GeminiCliProvider::from_env(invalid_model).unwrap(); + let config = provider.get_model_config(); + + assert_eq!(config.model_name, "invalid-model"); + } + + #[test] + fn test_gemini_cli_valid_model() { + // Test that a valid model is preserved + let valid_model = ModelConfig::new(GEMINI_CLI_DEFAULT_MODEL.to_string()); + let provider = GeminiCliProvider::from_env(valid_model).unwrap(); + let config = provider.get_model_config(); + + assert_eq!(config.model_name, GEMINI_CLI_DEFAULT_MODEL); + } }