Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
let spin = spinner();
spin.start("Attempting to fetch supported models...");
let models_res = {
let temp_model_config = goose::model::ModelConfig::new(provider_meta.default_model.clone());
let temp_model_config = goose::model::ModelConfig::new(&provider_meta.default_model)?;
let temp_provider = create(provider_name, temp_model_config)?;
temp_provider.fetch_supported_models_async().await
};
Expand Down Expand Up @@ -373,7 +373,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
.map(|val| val == "1" || val.to_lowercase() == "true")
.unwrap_or(false);

let model_config = goose::model::ModelConfig::new(model.clone())
let model_config = goose::model::ModelConfig::new(&model)?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably just want to skip this model if we can't load it

.with_max_tokens(Some(50))
.with_toolshim(toolshim_enabled)
.with_toolshim_model(std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok());
Expand Down Expand Up @@ -1231,7 +1231,7 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
let model: String = config
.get_param("GOOSE_MODEL")
.expect("No model configured. Please set model first");
let model_config = goose::model::ModelConfig::new(model.clone());
let model_config = goose::model::ModelConfig::new(&model)?;

// Create the agent
let agent = Agent::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-cli/src/commands/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> {
}
};

let model_config = goose::model::ModelConfig::new(model.clone());
let model_config = goose::model::ModelConfig::new(&model)?;

// Create the agent
let agent = Agent::new();
Expand Down
5 changes: 1 addition & 4 deletions crates/goose-cli/src/scenario_tests/scenario_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,7 @@ where

let original_env = setup_environment(config)?;

let inner_provider = create(
&factory_name,
ModelConfig::new(config.model_name.to_string()),
)?;
let inner_provider = create(&factory_name, ModelConfig::new(&config.model_name)?)?;

let test_provider = Arc::new(TestProvider::new_recording(inner_provider, &file_path));
(
Expand Down
8 changes: 6 additions & 2 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {

let temperature = session_config.settings.as_ref().and_then(|s| s.temperature);

let model_config =
goose::model::ModelConfig::new(model_name.clone()).with_temperature(temperature);
let model_config = goose::model::ModelConfig::new(&model_name)
.unwrap_or_else(|e| {
output::render_error(&format!("Failed to create model configuration: {}", e));
process::exit(1);
})
.with_temperature(temperature);

// Create the agent
let agent: Agent = Agent::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1572,7 +1572,7 @@ fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
};

let model_config =
ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT"));
ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT"))?;
let reasoner = create(&provider, model_config)?;

Ok(reasoner)
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async fn update_agent_provider(
.get_param("GOOSE_MODEL")
.expect("Did not find a model on payload or in env to update provider with")
});
let model_config = ModelConfig::new(model);
let model_config = ModelConfig::new(&model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the line above where we have .expect() will kill goosed and it is all over. not sure how we would have gotten here, but still, we should probably do another 500; I also think we should change the description of the 500 from internal server error here to something descriptive; could not update provider, make sure that you have the right config or something

let new_provider = create(&payload.provider, model_config).unwrap();
agent
.update_provider(new_provider)
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ mod tests {

#[tokio::test]
async fn test_reply_endpoint() {
let mock_model_config = ModelConfig::new("test-model".to_string());
let mock_model_config = ModelConfig::new("test-model").unwrap();
let mock_provider = Arc::new(MockProvider {
model_config: mock_model_config,
});
Expand Down
3 changes: 2 additions & 1 deletion crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ impl VectorToolSelector {
env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string());

// Create the provider using the factory
let model_config = ModelConfig::new(embedding_model);
let model_config = ModelConfig::new(embedding_model.as_str())
.context("Failed to create model config for embedding provider")?;
providers::create(&embedding_provider_name, model_config).context(format!(
"Failed to create {} provider for embeddings. If using OpenAI, make sure OPENAI_API_KEY env var is set or that you have configured the OpenAI provider via Goose before.",
embedding_provider_name
Expand Down
15 changes: 8 additions & 7 deletions crates/goose/src/context_mgmt/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,13 @@ mod tests {
}
}

fn create_mock_provider() -> Arc<dyn Provider> {
fn create_mock_provider() -> Result<Arc<dyn Provider>> {
let mock_model_config =
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
Arc::new(MockProvider {
ModelConfig::new_or_fail("test-model").with_context_limit(200_000.into());

Ok(Arc::new(MockProvider {
model_config: mock_model_config,
})
}))
}

fn create_test_messages() -> Vec<Message> {
Expand Down Expand Up @@ -305,7 +306,7 @@ mod tests {

#[tokio::test]
async fn test_summarize_messages_single_chunk() {
let provider = create_mock_provider();
let provider = create_mock_provider().expect("failed to create mock provider");
let token_counter = TokenCounter::new();
let context_limit = 100; // Set a high enough limit to avoid chunking.
let messages = create_test_messages();
Expand Down Expand Up @@ -341,7 +342,7 @@ mod tests {

#[tokio::test]
async fn test_summarize_messages_multiple_chunks() {
let provider = create_mock_provider();
let provider = create_mock_provider().expect("failed to create mock provider");
let token_counter = TokenCounter::new();
let context_limit = 30;
let messages = create_test_messages();
Expand Down Expand Up @@ -377,7 +378,7 @@ mod tests {

#[tokio::test]
async fn test_summarize_messages_empty_input() {
let provider = create_mock_provider();
let provider = create_mock_provider().expect("failed to create mock provider");
let token_counter = TokenCounter::new();
let context_limit = 100;
let messages: Vec<Message> = Vec::new();
Expand Down
2 changes: 2 additions & 0 deletions crates/goose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ pub mod utils;

#[cfg(test)]
mod cron_test;
#[macro_use]
mod macros;
19 changes: 19 additions & 0 deletions crates/goose/src/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#[macro_export]
macro_rules! impl_provider_default {
($provider:ty) => {
impl Default for $provider {
fn default() -> Self {
let model = $crate::model::ModelConfig::new(
&<$provider as $crate::providers::base::Provider>::metadata().default_model,
)
.expect(concat!(
"Failed to create model config for ",
stringify!($provider)
));

<$provider>::from_env(model)
.expect(concat!("Failed to initialize ", stringify!($provider)))
}
}
};
}
Loading
Loading