Merged

Changes from all commits · 32 commits
e501e3b  my hints for goose (katzdave, Aug 20, 2025)
060bbdd  base impl (katzdave, Aug 20, 2025)
8b34e02  update new providers with api (katzdave, Aug 20, 2025)
966e391  add provider defaults (katzdave, Aug 20, 2025)
f95b13d  cleanup (katzdave, Aug 20, 2025)
32256b0  clean comments (katzdave, Aug 20, 2025)
e8cd8fc  Fmt (katzdave, Aug 20, 2025)
3c3c376  parse context limit (katzdave, Aug 20, 2025)
bdb5583  more model config (katzdave, Aug 20, 2025)
352e0c5  rm groq model (katzdave, Aug 20, 2025)
ee5f40b  Reset python files (katzdave, Aug 20, 2025)
0c85d4d  should build now (katzdave, Aug 20, 2025)
3637a1a  with fast abstraction + env set (katzdave, Aug 21, 2025)
5528a29  no fast model on custom config (katzdave, Aug 21, 2025)
0447587  fn comments (katzdave, Aug 21, 2025)
51bfa48  Swap model to modelconfig (katzdave, Aug 21, 2025)
d3a5307  rm extra scripts (katzdave, Aug 21, 2025)
55aeedf  openai output model (katzdave, Aug 21, 2025)
b666551  fix warnings (katzdave, Aug 21, 2025)
f6c2550  fmt (katzdave, Aug 21, 2025)
8cc66f1  support databricks (katzdave, Aug 21, 2025)
4ec0d9b  bring back output (katzdave, Aug 21, 2025)
bdbfcf4  fix databricks (katzdave, Aug 21, 2025)
7a5f390  databricks to 3.7sonnet (katzdave, Aug 21, 2025)
1e6f164  fix clippy (katzdave, Aug 21, 2025)
395f434  summary model -> 1.5flash (katzdave, Aug 21, 2025)
a60b0fd  Merge branch 'main' of github.com:block/goose into dkatz/fast-summarize2 (katzdave, Aug 21, 2025)
f996262  fix titrate (katzdave, Aug 21, 2025)
9e43de9  one more test fix (katzdave, Aug 21, 2025)
2f232f8  fix agent tests (katzdave, Aug 21, 2025)
0cd4189  add fast model exists check (katzdave, Aug 21, 2025)
74a8359  bump sonnet model (katzdave, Aug 21, 2025)
3 changes: 2 additions & 1 deletion crates/goose-server/src/routes/reply.rs
@@ -537,8 +537,9 @@ mod tests {
         goose::providers::base::ProviderMetadata::empty()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[rmcp::model::Tool],
3 changes: 2 additions & 1 deletion crates/goose/src/context_mgmt/auto_compact.rs
@@ -221,8 +221,9 @@ mod tests {
         self.model_config.clone()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
5 changes: 3 additions & 2 deletions crates/goose/src/context_mgmt/summarize.rs
@@ -44,7 +44,7 @@ pub async fn summarize_messages(
 
     // Send the request to the provider and fetch the response
     let (mut response, mut provider_usage) = provider
-        .complete(&system_prompt, &summarization_request, &[])
+        .complete_fast(&system_prompt, &summarization_request, &[])
         .await?;
 
     // Set role to user as it will be used in following conversation as user content
@@ -87,8 +87,9 @@ mod tests {
         self.model_config.clone()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
59 changes: 56 additions & 3 deletions crates/goose/src/model.rs
@@ -28,6 +28,7 @@ static MODEL_SPECIFIC_LIMITS: Lazy<Vec<(&'static str, usize)>> = Lazy::new(|| {
         // anthropic - all 200k
         ("claude", 200_000),
         // google
+        ("gemini-1.5-flash", 1_000_000),
         ("gemini-1", 128_000),
         ("gemini-2", 1_000_000),
         ("gemma-3-27b", 128_000),
@@ -72,6 +73,7 @@ pub struct ModelConfig {
     pub max_tokens: Option<i32>,
     pub toolshim: bool,
     pub toolshim_model: Option<String>,
+    pub fast_model: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -89,7 +91,7 @@ impl ModelConfig {
         model_name: String,
         context_env_var: Option<&str>,
     ) -> Result<Self, ConfigError> {
-        let context_limit = Self::parse_context_limit(&model_name, context_env_var)?;
+        let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
        let temperature = Self::parse_temperature()?;
        let toolshim = Self::parse_toolshim()?;
        let toolshim_model = Self::parse_toolshim_model()?;
@@ -101,13 +103,16 @@ impl ModelConfig {
             max_tokens: None,
             toolshim,
             toolshim_model,
+            fast_model: None,
         })
     }
 
     fn parse_context_limit(
         model_name: &str,
+        fast_model: Option<&str>,
         custom_env_var: Option<&str>,
     ) -> Result<Option<usize>, ConfigError> {
+        // First check if there's an explicit environment variable override
         if let Some(env_var) = custom_env_var {
             if let Ok(val) = std::env::var(env_var) {
                 return Self::validate_context_limit(&val, env_var).map(Some);
@@ -116,7 +121,24 @@
         if let Ok(val) = std::env::var("GOOSE_CONTEXT_LIMIT") {
             return Self::validate_context_limit(&val, "GOOSE_CONTEXT_LIMIT").map(Some);
         }
-        Ok(Self::get_model_specific_limit(model_name))
+
+        // Get the model's limit
+        let model_limit = Self::get_model_specific_limit(model_name);
+
+        // If there's a fast_model, get its limit and use the minimum
+        if let Some(fast_model_name) = fast_model {
+            let fast_model_limit = Self::get_model_specific_limit(fast_model_name);
+
+            // Return the minimum of both limits (if both exist)
+            match (model_limit, fast_model_limit) {
+                (Some(m), Some(f)) => Ok(Some(m.min(f))),
+                (Some(m), None) => Ok(Some(m)),
+                (None, Some(f)) => Ok(Some(f)),
+                (None, None) => Ok(None),
+            }
+        } else {
+            Ok(model_limit)
+        }
     }
 
     fn validate_context_limit(val: &str, env_var: &str) -> Result<usize, ConfigError> {
@@ -231,8 +253,39 @@ impl ModelConfig {
         self
     }
 
+    pub fn with_fast(mut self, fast_model: String) -> Self {
+        self.fast_model = Some(fast_model);
+        self
+    }
+
+    pub fn use_fast_model(&self) -> Self {
+        if let Some(fast_model) = &self.fast_model {
+            let mut config = self.clone();
+            config.model_name = fast_model.clone();
+            config
+        } else {
+            self.clone()
+        }
+    }
+
     pub fn context_limit(&self) -> usize {
-        self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
+        // If we have an explicit context limit set, use it
+        if let Some(limit) = self.context_limit {
+            return limit;
+        }
+
+        // Otherwise, get the model's default limit
+        let main_limit =
+            Self::get_model_specific_limit(&self.model_name).unwrap_or(DEFAULT_CONTEXT_LIMIT);
+
+        // If we have a fast_model, also check its limit and use the minimum
+        if let Some(fast_model) = &self.fast_model {
+            let fast_limit =
+                Self::get_model_specific_limit(fast_model).unwrap_or(DEFAULT_CONTEXT_LIMIT);
+            main_limit.min(fast_limit)
+        } else {
+            main_limit
+        }
     }
 
     pub fn new_or_fail(model_name: &str) -> ModelConfig {
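Reviewer note on the ModelConfig changes above: with_fast attaches a secondary model, use_fast_model materializes a config pointing at it, and context_limit now reports the minimum of the two models' limits so that anything sized against it fits either model. A minimal sketch of how these compose; the model names are illustrative placeholders, not project defaults:

    // Sketch only: model names are placeholders chosen for illustration.
    let config = ModelConfig::new_or_fail("claude-sonnet-4-0")
        .with_fast("claude-3-7-sonnet-latest".to_string());

    // use_fast_model() returns a clone with model_name swapped for fast_model,
    // or an unchanged clone when no fast model is configured.
    let fast = config.use_fast_model();
    assert_eq!(fast.model_name, "claude-3-7-sonnet-latest");

    // context_limit() takes the minimum of the main and fast model limits,
    // so a prompt sized against it is safe to send to either model.
    let available_tokens = config.context_limit();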
3 changes: 2 additions & 1 deletion crates/goose/src/permission/permission_judge.rs
@@ -292,8 +292,9 @@ mod tests {
         self.model_config.clone()
     }
 
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        _model_config: &ModelConfig,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
18 changes: 11 additions & 7 deletions crates/goose/src/providers/anthropic.rs
@@ -23,6 +23,7 @@ use crate::providers::retry::ProviderRetry;
 use rmcp::model::Tool;
 
 const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-0";
+const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-3-7-sonnet-latest";
 const ANTHROPIC_KNOWN_MODELS: &[&str] = &[
     "claude-sonnet-4-0",
     "claude-sonnet-4-20250514",
@@ -50,6 +51,8 @@ impl_provider_default!(AnthropicProvider);
 
 impl AnthropicProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let model = model.with_fast(ANTHROPIC_DEFAULT_FAST_MODEL.to_string());
+
         let config = crate::config::Config::global();
         let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?;
         let host: String = config
@@ -179,16 +182,17 @@ impl Provider for AnthropicProvider {
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools)?;
+        let payload = create_request(model_config, system, messages, tools)?;
 
         let response = self
             .with_retry(|| async { self.post(&payload).await })
@@ -201,9 +205,9 @@
         tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
             usage.input_tokens, usage.output_tokens, usage.total_tokens);
 
-        let model = get_model(&json_response);
+        let response_model = get_model(&json_response);
         emit_debug_trace(&self.model, &payload, &json_response, &usage);
-        let provider_usage = ProviderUsage::new(model, usage);
+        let provider_usage = ProviderUsage::new(response_model, usage);
         tracing::debug!(
             "🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
             provider_usage
@@ -271,7 +275,7 @@
 
         let stream = response.bytes_stream().map_err(io::Error::other);
 
-        let model_config = self.model.clone();
+        let model = self.model.clone();
         Ok(Box::pin(try_stream! {
             let stream_reader = StreamReader::new(stream);
             let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from);
@@ -280,7 +284,7 @@
             pin!(message_stream);
             while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
                 let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
-                emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
+                emit_debug_trace(&model, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
                 yield (message, usage);
             }
         }))
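Worth noting: from_env now installs the Anthropic fast-model default before reading any other configuration. A hedged sketch of what that implies for callers; the calling code below is hypothetical and error handling is elided:

    // Hypothetical call site, assuming from_env stores the modified config.
    let provider = AnthropicProvider::from_env(ModelConfig::new_or_fail("claude-sonnet-4-0"))?;
    // get_model_config().fast_model should now be Some("claude-3-7-sonnet-latest"),
    // so complete_fast routes summarization traffic to the cheaper model.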
13 changes: 7 additions & 6 deletions crates/goose/src/providers/azure.rs
@@ -135,16 +135,17 @@ impl Provider for AzureProvider {
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
+        let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?;
         let response = self
             .with_retry(|| async {
                 let payload_clone = payload.clone();
@@ -157,8 +158,8 @@
                 tracing::debug!("Failed to get usage data");
                 Usage::default()
             });
-        let model = get_model(&response);
-        emit_debug_trace(&self.model, &payload, &response, &usage);
-        Ok((message, ProviderUsage::new(model, usage)))
+        let response_model = get_model(&response);
+        emit_debug_trace(model_config, &payload, &response, &usage);
+        Ok((message, ProviderUsage::new(response_model, usage)))
     }
 }
45 changes: 30 additions & 15 deletions crates/goose/src/providers/base.rs
@@ -317,26 +317,41 @@ pub trait Provider: Send + Sync {
     where
         Self: Sized;
 
-    /// Generate the next message using the configured model and other parameters
-    ///
-    /// # Arguments
-    /// * `system` - The system prompt that guides the model's behavior
-    /// * `messages` - The conversation history as a sequence of messages
-    /// * `tools` - Optional list of tools the model can use
-    ///
-    /// # Returns
-    /// A tuple containing the model's response message and provider usage statistics
-    ///
-    /// # Errors
-    /// ProviderError
-    /// - It's important to raise ContextLengthExceeded correctly since agent handles it
-    async fn complete(
+    // Internal implementation of complete, used by complete_fast and complete
+    // Providers should override this to implement their actual completion logic
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError>;
 
+    // Default implementation: use the provider's configured model
+    async fn complete(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let model_config = self.get_model_config();
+        self.complete_with_model(&model_config, system, messages, tools)
+            .await
+    }
+
+    // Check if a fast model is configured, otherwise fall back to regular model
+    async fn complete_fast(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let model_config = self.get_model_config();
+        let fast_config = model_config.use_fast_model();
+        self.complete_with_model(&fast_config, system, messages, tools)
+            .await
+    }
+
     /// Get the model config from the provider
     fn get_model_config(&self) -> ModelConfig;
 
@@ -418,7 +433,7 @@ pub trait Provider: Send + Sync {
         let prompt = self.create_session_name_prompt(&context);
         let message = Message::user().with_text(&prompt);
         let result = self
-            .complete_fast(
+            .complete_fast(
                 "Reply with only a description in four words or less",
                 &[message],
                 &[],
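The trait split above separates model selection from transport: providers implement only complete_with_model, while complete and complete_fast are default methods that pick a ModelConfig and delegate. A sketch of a call site; the prompts and variable names are illustrative only:

    // Both default methods funnel into the provider's complete_with_model.
    let (reply, usage) = provider
        .complete("You are a helpful assistant.", &messages, &tools)
        .await?; // uses the provider's configured model

    let (summary, _) = provider
        .complete_fast("Summarize the conversation so far.", &messages, &[])
        .await?; // uses fast_model when set, otherwise the same model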
7 changes: 4 additions & 3 deletions crates/goose/src/providers/bedrock.rs
@@ -152,16 +152,17 @@ impl Provider for BedrockProvider {
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let model_name = &self.model.model_name;
+        let model_name = model_config.model_name.clone();
 
         let (bedrock_message, bedrock_usage) = self
             .with_retry(|| self.converse(system, messages, tools))
11 changes: 6 additions & 5 deletions crates/goose/src/providers/claude_code.rs
@@ -474,11 +474,12 @@ impl Provider for ClaudeCodeProvider {
     }
 
     #[tracing::instrument(
-        skip(self, system, messages, tools),
+        skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete(
+    async fn complete_with_model(
         &self,
+        model_config: &ModelConfig,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
@@ -495,7 +496,7 @@
         // Create a dummy payload for debug tracing
         let payload = json!({
             "command": self.command,
-            "model": self.model.model_name,
+            "model": model_config.model_name,
             "system": system,
             "messages": messages.len()
         });
@@ -505,11 +506,11 @@
             "usage": usage
         });
 
-        emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(model_config, &payload, &response, &usage);
 
         Ok((
             message,
-            ProviderUsage::new(self.model.model_name.clone(), usage),
+            ProviderUsage::new(model_config.model_name.clone(), usage),
         ))
     }
 }