Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions crates/goose/src/agents/chat_recall_extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,352 @@
use crate::agents::extension::PlatformExtensionContext;
use crate::agents::mcp_client::{Error, McpClientTrait};
use crate::session::SessionManager;
use anyhow::Result;
use async_trait::async_trait;
use indoc::indoc;
use rmcp::model::{
CallToolResult, Content, GetPromptResult, Implementation, InitializeResult, JsonObject,
ListPromptsResult, ListResourcesResult, ListToolsResult, ProtocolVersion, ReadResourceResult,
ServerCapabilities, ServerNotification, Tool, ToolAnnotations, ToolsCapability,
};
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

pub static EXTENSION_NAME: &str = "chat_recall";

/// Parameters for the chat_recall tool
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct ChatRecallParams {
/// Search keywords. Use multiple related terms/synonyms (e.g., 'database postgres sql'). Mutually exclusive with session_id.
#[serde(skip_serializing_if = "Option::is_none")]
query: Option<String>,
/// Session ID to load. Returns first/last 3 messages. Mutually exclusive with query.
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
/// Max results (default: 10, max: 50). Search mode only.
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<i64>,
/// ISO 8601 date (e.g., '2025-10-01T00:00:00Z'). Search mode only.
#[serde(skip_serializing_if = "Option::is_none")]
after_date: Option<String>,
/// ISO 8601 date (e.g., '2025-10-15T23:59:59Z'). Search mode only.
#[serde(skip_serializing_if = "Option::is_none")]
before_date: Option<String>,
}

pub struct ChatRecallClient {
info: InitializeResult,
context: PlatformExtensionContext,
}

impl ChatRecallClient {
pub fn new(context: PlatformExtensionContext) -> Result<Self> {
let info = InitializeResult {
protocol_version: ProtocolVersion::V_2025_03_26,
capabilities: ServerCapabilities {
tools: Some(ToolsCapability {
list_changed: Some(false),
}),
resources: None,
prompts: None,
completions: None,
experimental: None,
logging: None,
},
server_info: Implementation {
name: EXTENSION_NAME.to_string(),
title: Some("Chat Recall".to_string()),
version: "1.0.0".to_string(),
icons: None,
website_url: None,
},
instructions: Some(indoc! {r#"
Chat Recall

Search past conversations and load session summaries when the user expects some memory or context.

Two modes:
- Search mode: Use query with keywords/synonyms to find relevant messages
- Load mode: Use session_id to get first and last messages of a specific session
"#}.to_string()),
};

Ok(Self { info, context })
}

#[allow(clippy::too_many_lines)]
async fn handle_chat_recall(
&self,
arguments: Option<JsonObject>,
) -> Result<Vec<Content>, String> {
let arguments = arguments.ok_or("Missing arguments")?;

let session_id = arguments
.get("session_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string());

if let Some(sid) = session_id {
// LOAD MODE: Get session summary (first and last few messages)
match SessionManager::get_session(&sid, true).await {
Ok(loaded_session) => {
let conversation = loaded_session.conversation.as_ref();

if conversation.is_none() {
return Ok(vec![Content::text(format!(
"Session {} has no conversation.",
sid
))]);
}

let msgs = conversation.unwrap().messages();
let total = msgs.len();

if total == 0 {
return Ok(vec![Content::text(format!(
"Session {} has no messages.",
sid
))]);
}

let mut output = format!(
"Session: {} (ID: {})\nWorking Dir: {}\nTotal Messages: {}\n\n",
loaded_session.description,
sid,
loaded_session.working_dir.display(),
total
);

// Show first 3 messages
let first_count = std::cmp::min(3, total);
output.push_str("--- First Few Messages ---\n\n");
for (idx, msg) in msgs.iter().take(first_count).enumerate() {
output.push_str(&format!("{}. [{:?}] ", idx + 1, msg.role));
for content in &msg.content {
if let Some(text) = content.as_text() {
output.push_str(text);
output.push('\n');
}
}
output.push('\n');
}

// Show last 3 messages (if different from first)
if total > first_count {
output.push_str("--- Last Few Messages ---\n\n");
let last_count = std::cmp::min(3, total);
let skip_count = total.saturating_sub(last_count);
for (idx, msg) in msgs.iter().skip(skip_count).enumerate() {
output.push_str(&format!(
"{}. [{:?}] ",
skip_count + idx + 1,
msg.role
));
for content in &msg.content {
if let Some(text) = content.as_text() {
output.push_str(text);
output.push('\n');
}
}
output.push('\n');
}
}

Ok(vec![Content::text(output)])
}
Err(e) => Err(format!("Failed to load session: {}", e)),
}
} else {
// SEARCH MODE: Search across all sessions
let query = arguments
.get("query")
.and_then(|v| v.as_str())
.ok_or("Missing required parameter: query or session_id")?
.to_string();

let limit = arguments
.get("limit")
.and_then(|v| v.as_i64())
.map(|l| l as usize)
.unwrap_or(10)
.min(50);

let after_date = arguments
.get("after_date")
.and_then(|v| v.as_str())
.and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc));

let before_date = arguments
.get("before_date")
.and_then(|v| v.as_str())
.and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc));

// Exclude current session from results to avoid self-referential loops
let exclude_session_id = self.context.session_id.clone();

match SessionManager::search_chat_history(
&query,
Some(limit),
after_date,
before_date,
exclude_session_id,
)
.await
{
Ok(results) => {
let formatted_results = if results.total_matches == 0 {
format!("No results found for query: '{}'", query)
} else {
let mut output = format!(
"Found {} matching message(s) across {} session(s) for query: '{}'\n\n",
results.total_matches,
results.results.len(),
query
);
for (idx, result) in results.results.iter().enumerate() {
output.push_str(&format!(
"{}. Session: {} (ID: {})\n Working Dir: {}\n Last Activity: {}\n Showing {} of {} total message(s) in session:\n\n",
idx + 1,
result.session_description,
result.session_id,
result.session_working_dir,
result.last_activity.format("%Y-%m-%d"),
result.messages.len(),
result.total_messages_in_session
));

for (msg_idx, message) in result.messages.iter().enumerate() {
output.push_str(&format!(
" {}.{} [{}]\n {}\n\n",
idx + 1,
msg_idx + 1,
message.role,
message
.content
.lines()
.map(|line| format!(" {}", line))
.collect::<Vec<_>>()
.join("\n")
));
}
}
output
};
Ok(vec![Content::text(formatted_results)])
}
Err(e) => Err(format!("Chat recall failed: {}", e)),
}
}
}

fn get_tools() -> Vec<Tool> {
// Generate JSON schema from the ChatRecallParams struct
let schema = schema_for!(ChatRecallParams);
let schema_value =
serde_json::to_value(schema).expect("Failed to serialize ChatRecallParams schema");

let input_schema = schema_value
.as_object()
.expect("Schema should be an object")
.clone();

vec![Tool::new(
"chat_recall".to_string(),
indoc! {r#"
Search past chat or load session summaries. Use when it is clear user expects some memory or context.

search mode (query): Use multiple keywords/synonyms. Returns messages grouped by session, ordered by recency. Supports date filters.
load mode (session_id): Returns first/last 3 messages of a session.
"#}
.to_string(),
input_schema,
)
.annotate(ToolAnnotations {
title: Some("Recall past conversations".to_string()),
read_only_hint: Some(true),
destructive_hint: Some(false),
idempotent_hint: Some(true),
open_world_hint: Some(false),
})]
}
}

#[async_trait]
impl McpClientTrait for ChatRecallClient {
async fn list_resources(
&self,
_next_cursor: Option<String>,
_cancellation_token: CancellationToken,
) -> Result<ListResourcesResult, Error> {
Err(Error::TransportClosed)
}

async fn read_resource(
&self,
_uri: &str,
_cancellation_token: CancellationToken,
) -> Result<ReadResourceResult, Error> {
Err(Error::TransportClosed)
}

async fn list_tools(
&self,
_next_cursor: Option<String>,
_cancellation_token: CancellationToken,
) -> Result<ListToolsResult, Error> {
Ok(ListToolsResult {
tools: Self::get_tools(),
next_cursor: None,
})
}

async fn call_tool(
&self,
name: &str,
arguments: Option<JsonObject>,
_cancellation_token: CancellationToken,
) -> Result<CallToolResult, Error> {
let content = match name {
"chat_recall" => self.handle_chat_recall(arguments).await,
_ => Err(format!("Unknown tool: {}", name)),
};

match content {
Ok(content) => Ok(CallToolResult::success(content)),
Err(error) => Ok(CallToolResult::error(vec![Content::text(format!(
"Error: {}",
error
))])),
}
}

async fn list_prompts(
&self,
_next_cursor: Option<String>,
_cancellation_token: CancellationToken,
) -> Result<ListPromptsResult, Error> {
Err(Error::TransportClosed)
}

async fn get_prompt(
&self,
_name: &str,
_arguments: Value,
_cancellation_token: CancellationToken,
) -> Result<GetPromptResult, Error> {
Err(Error::TransportClosed)
}

async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
mpsc::channel(1).1
}

fn get_info(&self) -> Option<&InitializeResult> {
Some(&self.info)
}
}
14 changes: 14 additions & 0 deletions crates/goose/src/agents/extension.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::agents::chat_recall_extension;
use crate::agents::todo_extension;
use std::collections::HashMap;

Expand Down Expand Up @@ -49,6 +50,19 @@ pub static PLATFORM_EXTENSIONS: Lazy<HashMap<&'static str, PlatformExtensionDef>
},
);

map.insert(
chat_recall_extension::EXTENSION_NAME,
PlatformExtensionDef {
name: chat_recall_extension::EXTENSION_NAME,
description:
"Enable chat recall to search past conversations and load session summaries",
default_enabled: true,
client_factory: |ctx| {
Box::new(chat_recall_extension::ChatRecallClient::new(ctx).unwrap())
},
},
);

map
});

Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod agent;
pub(crate) mod chat_recall_extension;
pub mod extension;
pub mod extension_malware_check;
pub mod extension_manager;
Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/agents/platform_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ pub fn manage_schedule_tool() -> Tool {
PLATFORM_MANAGE_SCHEDULE_TOOL_NAME.to_string(),
indoc! {r#"
Manage scheduled recipe execution for this goose instance.

Actions:
- "list": List all scheduled jobs
- "create": Create a new scheduled job from a recipe file
- "run_now": Execute a scheduled job immediately
- "run_now": Execute a scheduled job immediately
- "pause": Pause a scheduled job
- "unpause": Resume a paused job
- "delete": Remove a scheduled job
Expand Down
Loading
Loading