Skip to content

Commit

Permalink
Merge branch 'main' into micn/fix-load-shell-gui
Browse files Browse the repository at this point in the history
* main:
  chore: remove gpt-3.5-turbo UI suggestion, as it is deprecated (#959)
  chore: remove o1-mini suggestion from UI add model view (#957)
  fix: missing field in request (#956)
  docs: update provider docs, fix rate limit link (#943)
  fix: clarify linux cli install only (#927)
  feat: update ui for ollama host (#912)
  feat: add CONFIGURE=false option in install script (#920)
  fix: truncation agent token calculations (#915)
  fix: request payload for o1 models (#921)
  • Loading branch information
michaelneale committed Jan 30, 2025
2 parents 9325fcf + 3b75872 commit a7ea005
Show file tree
Hide file tree
Showing 29 changed files with 491 additions and 418 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/deploy-docs-and-extensions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ on:
push:
branches:
- main

pull_request:
paths:
- 'documentation/**'

jobs:
deploy:
Expand Down
8 changes: 6 additions & 2 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,12 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
.mask('▪')
.interact()?
} else {
cliclack::input(format!("Enter new value for {}", key.name))
.interact()?
let mut input =
cliclack::input(format!("Enter new value for {}", key.name));
if key.default.is_some() {
input = input.default_input(&key.default.clone().unwrap());
}
input.interact()?
};

if key.secret {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@ use serde_json::Value;
use std::collections::HashMap;

#[derive(Serialize)]
struct SecretResponse {
struct ConfigResponse {
error: bool,
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct SecretRequest {
struct ConfigRequest {
key: String,
value: String,
is_secret: bool,
}

async fn store_secret(
async fn store_config(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<SecretRequest>,
) -> Result<Json<SecretResponse>, StatusCode> {
Json(request): Json<ConfigRequest>,
) -> Result<Json<ConfigResponse>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
Expand All @@ -42,18 +42,18 @@ async fn store_secret(
config.set(&request.key, Value::String(request.value))
};
match result {
Ok(_) => Ok(Json(SecretResponse { error: false })),
Err(_) => Ok(Json(SecretResponse { error: true })),
Ok(_) => Ok(Json(ConfigResponse { error: false })),
Err(_) => Ok(Json(ConfigResponse { error: true })),
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderSecretRequest {
pub struct ProviderConfigRequest {
pub providers: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SecretStatus {
pub struct ConfigStatus {
pub is_set: bool,
pub location: Option<String>,
}
Expand All @@ -64,7 +64,7 @@ pub struct ProviderResponse {
pub name: Option<String>,
pub description: Option<String>,
pub models: Option<Vec<String>>,
pub secret_status: HashMap<String, SecretStatus>,
pub config_status: HashMap<String, ConfigStatus>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -80,30 +80,33 @@ static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json")
});

fn check_key_status(key: &str) -> (bool, Option<String>) {
fn check_key_status(config: &Config, key: &str) -> (bool, Option<String>) {
if let Ok(_value) = std::env::var(key) {
(true, Some("env".to_string()))
} else if Config::global().get_secret::<String>(key).is_ok() {
} else if config.get::<String>(key).is_ok() {
(true, Some("yaml".to_string()))
} else if config.get_secret::<String>(key).is_ok() {
(true, Some("keyring".to_string()))
} else {
(false, None)
}
}

async fn check_provider_secrets(
Json(request): Json<ProviderSecretRequest>,
async fn check_provider_configs(
Json(request): Json<ProviderConfigRequest>,
) -> Result<Json<HashMap<String, ProviderResponse>>, StatusCode> {
let mut response = HashMap::new();
let config = Config::global();

for provider_name in request.providers {
if let Some(provider_config) = PROVIDER_ENV_REQUIREMENTS.get(&provider_name) {
let mut secret_status = HashMap::new();
let mut config_status = HashMap::new();

for key in &provider_config.required_keys {
let (key_set, key_location) = check_key_status(key);
secret_status.insert(
let (key_set, key_location) = check_key_status(config, key);
config_status.insert(
key.to_string(),
SecretStatus {
ConfigStatus {
is_set: key_set,
location: key_location,
},
Expand All @@ -117,7 +120,7 @@ async fn check_provider_secrets(
name: Some(provider_config.name.clone()),
description: Some(provider_config.description.clone()),
models: Some(provider_config.models.clone()),
secret_status,
config_status,
},
);
} else {
Expand All @@ -128,7 +131,7 @@ async fn check_provider_secrets(
name: None,
description: None,
models: None,
secret_status: HashMap::new(),
config_status: HashMap::new(),
},
);
}
Expand All @@ -138,14 +141,16 @@ async fn check_provider_secrets(
}

#[derive(Deserialize)]
struct DeleteSecretRequest {
#[serde(rename_all = "camelCase")]
struct DeleteConfigRequest {
key: String,
is_secret: bool,
}

async fn delete_secret(
async fn delete_config(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<DeleteSecretRequest>,
Json(request): Json<DeleteConfigRequest>,
) -> Result<StatusCode, StatusCode> {
// Verify secret key
let secret_key = headers
Expand All @@ -158,17 +163,23 @@ async fn delete_secret(
}

// Attempt to delete the key
match Config::global().delete_secret(&request.key) {
let config = Config::global();
let result = if request.is_secret {
config.delete_secret(&request.key)
} else {
config.delete(&request.key)
};
match result {
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(_) => Err(StatusCode::NOT_FOUND),
}
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/secrets/providers", post(check_provider_secrets))
.route("/secrets/store", post(store_secret))
.route("/secrets/delete", delete(delete_secret))
.route("/configs/providers", post(check_provider_configs))
.route("/configs/store", post(store_config))
.route("/configs/delete", delete(delete_config))
.with_state(state)
}

Expand All @@ -179,12 +190,12 @@ mod tests {
#[tokio::test]
async fn test_unsupported_provider() {
// Setup
let request = ProviderSecretRequest {
let request = ProviderConfigRequest {
providers: vec!["unsupported_provider".to_string()],
};

// Execute
let result = check_provider_secrets(Json(request)).await;
let result = check_provider_configs(Json(request)).await;

// Assert
assert!(result.is_ok());
Expand All @@ -194,6 +205,6 @@ mod tests {
.get("unsupported_provider")
.expect("Provider should exist");
assert!(!provider_status.supported);
assert!(provider_status.secret_status.is_empty());
assert!(provider_status.config_status.is_empty());
}
}
4 changes: 2 additions & 2 deletions crates/goose-server/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Export route modules
pub mod agent;
pub mod configs;
pub mod extension;
pub mod health;
pub mod reply;
pub mod secrets;

use axum::Router;

Expand All @@ -14,5 +14,5 @@ pub fn configure(state: crate::state::AppState) -> Router {
.merge(reply::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(extension::routes(state.clone()))
.merge(secrets::routes(state))
.merge(configs::routes(state))
}
2 changes: 1 addition & 1 deletion crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"name": "Ollama",
"description": "Lorem ipsum",
"models": ["qwen2.5"],
"required_keys": []
"required_keys": ["OLLAMA_HOST"]
},
"openrouter": {
"name": "OpenRouter",
Expand Down
39 changes: 31 additions & 8 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ impl TruncateAgent {
&self,
messages: &mut Vec<Message>,
estimate_factor: f32,
system_prompt: &str,
tools: &mut Vec<Tool>,
) -> anyhow::Result<()> {
// Model's actual context limit
let context_limit = self
Expand All @@ -57,20 +59,37 @@ impl TruncateAgent {
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
let context_limit = (context_limit as f32 * estimate_factor) as usize;

// Calculate current token count
// Take into account the system prompt, and our tools input and subtract that from the
// remaining context limit
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());

// Check if system prompt + tools exceed our context limit
let remaining_tokens = context_limit
.checked_sub(system_prompt_token_count)
.and_then(|remaining| remaining.checked_sub(tools_token_count))
.ok_or_else(|| {
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
})?;

let context_limit = remaining_tokens;

// Calculate current token count of each message, use count_chat_tokens to ensure we
// capture the full content of the message, include ToolRequests and ToolResponses
let mut token_counts: Vec<usize> = messages
.iter()
.map(|msg| self.token_counter.count_tokens(&msg.as_concat_text()))
.map(|msg| {
self.token_counter
.count_chat_tokens("", std::slice::from_ref(msg), &[])
})
.collect();

let _ = truncate_messages(
truncate_messages(
messages,
&mut token_counts,
context_limit,
&OldestFirstTruncation,
);

Ok(())
)
}
}

Expand Down Expand Up @@ -229,7 +248,7 @@ impl Agent for TruncateAgent {
// Create an error message & terminate the stream
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate.");
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
break;
}

Expand All @@ -243,7 +262,11 @@ impl Agent for TruncateAgent {
// release the lock before truncation to prevent deadlock
drop(capabilities);

self.truncate_messages(&mut messages, estimate_factor).await?;
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
break;
}


// Re-acquire the lock
capabilities = self.capabilities.lock().await;
Expand Down
Loading

0 comments on commit a7ea005

Please sign in to comment.