Skip to content

Commit

Permalink
Refactor before and after hooks to take a mut ChatCompletion*
Browse files Browse the repository at this point in the history
  • Loading branch information
timonv committed Dec 5, 2024
1 parent 21f341e commit 651fde0
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 98 deletions.
188 changes: 105 additions & 83 deletions swiftide-agents/src/agent.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#![allow(dead_code)]
use crate::{
default_context::DefaultContext,
hooks::{AfterToolFn, BeforeToolFn, Hook, HookFn, HookTypes, MessageHookFn},
hooks::{
AfterCompletionFn, AfterToolFn, BeforeCompletionFn, BeforeToolFn, Hook, HookFn, HookTypes,
MessageHookFn,
},
state,
system_prompt::SystemPrompt,
tools::control::Stop,
Expand All @@ -11,7 +14,9 @@ use std::{collections::HashSet, sync::Arc};
use anyhow::Result;
use derive_builder::Builder;
use swiftide_core::{
chat_completion::{ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolOutput},
chat_completion::{
ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
},
prompt::Prompt,
AgentContext,
};
Expand Down Expand Up @@ -128,8 +133,8 @@ impl AgentBuilder {
}

/// Add a hook that runs before each completion.
pub fn before_each(&mut self, hook: impl HookFn + 'static) -> &mut Self {
self.add_hook(Hook::BeforeEach(Box::new(hook)))
pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
}

/// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
Expand All @@ -147,8 +152,8 @@ impl AgentBuilder {
}

/// Add a hook that runs after each completion, when all tool calls are finished.
pub fn after_each(&mut self, hook: impl HookFn + 'static) -> &mut Self {
self.add_hook(Hook::AfterEach(Box::new(hook)))
pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
self.add_hook(Hook::AfterCompletion(Box::new(hook)))
}

/// Add a hook that runs when a new message is added to the context. Note that each tool adds a
Expand Down Expand Up @@ -239,7 +244,16 @@ impl Agent {
.add_messages(&[ChatMessage::System(system_prompt.render().await?)])
.await;
}
self.invoke_hooks_matching(HookTypes::BeforeAll).await?;
for hook in self.hooks_by_type(HookTypes::BeforeAll) {
if let Hook::BeforeAll(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::AfterTool)
);
tracing::info!("Calling {} hook", HookTypes::AfterTool);
hook(&*self.context).instrument(span).await?;
}
}
}

if let Some(query) = maybe_query {
Expand Down Expand Up @@ -270,14 +284,12 @@ impl Agent {

#[tracing::instrument(skip_all)]
async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
self.invoke_hooks_matching(HookTypes::BeforeEach).await?;

debug!(
"Running completion for agent with {} messages",
messages.len()
);

let chat_completion_request = ChatCompletionRequest::builder()
let mut chat_completion_request = ChatCompletionRequest::builder()
.messages(messages)
.tools_spec(
self.tools
Expand All @@ -287,6 +299,19 @@ impl Agent {
)
.build()?;

for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
if let Hook::BeforeCompletion(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
);
tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
hook(&*self.context, &mut chat_completion_request)
.instrument(span)
.await?;
}
}

debug!(
"Calling LLM with the following new messages:\n {}",
self.context
Expand All @@ -297,80 +322,96 @@ impl Agent {
.collect::<Vec<_>>()
.join(",\n")
);
let response = self.llm.complete(&chat_completion_request).await?;

let mut response = self.llm.complete(&chat_completion_request).await?;

for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
if let Hook::AfterCompletion(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
);
tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
hook(&*self.context, &mut response).instrument(span).await?;
}
}
self.add_message(ChatMessage::Assistant(
response.message,
response.tool_calls.clone(),
))
.await?;

if let Some(tool_calls) = response.tool_calls {
debug!("LLM returned tool calls: {:?}", tool_calls);

let mut handles = vec![];
for tool_call in tool_calls {
let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
tracing::warn!("Tool {} not found", tool_call.name());
continue;
};
tracing::info!("Calling tool `{}`", tool_call.name());

let tool_args = tool_call.args().map(String::from);
let context: Arc<dyn AgentContext> = Arc::clone(&self.context);

let tool_name = tool.name().to_string();
let tool_span =
tracing::info_span!("tool", "otel.name" = format!("tool.{}", &tool_name));

for hook in self.hooks_by_type(HookTypes::BeforeTool) {
if let Hook::BeforeTool(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::BeforeTool)
);
tracing::info!("Calling {} hook", HookTypes::BeforeTool);
hook(&*self.context, &tool_call).instrument(span).await?;
}
self.invoke_tools(tool_calls).await?;
};

Ok(())
}

async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
debug!("LLM returned tool calls: {:?}", tool_calls);

let mut handles = vec![];
for tool_call in tool_calls {
let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
tracing::warn!("Tool {} not found", tool_call.name());
continue;
};
tracing::info!("Calling tool `{}`", tool_call.name());

let tool_args = tool_call.args().map(String::from);
let context: Arc<dyn AgentContext> = Arc::clone(&self.context);

let tool_name = tool.name().to_string();
let tool_span =
tracing::info_span!("tool", "otel.name" = format!("tool.{}", &tool_name));

for hook in self.hooks_by_type(HookTypes::BeforeTool) {
if let Hook::BeforeTool(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::BeforeTool)
);
tracing::info!("Calling {} hook", HookTypes::BeforeTool);
hook(&*self.context, &tool_call).instrument(span).await?;
}
}

let handle = tokio::spawn(async move {
let handle = tokio::spawn(async move {
let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;

tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name, "Completed tool call");

Ok(output)
}.instrument(tool_span));

handles.push((handle, tool_call));
}
handles.push((handle, tool_call));
}

for (handle, tool_call) in handles {
let mut output = handle.await?;

for hook in self.hooks_by_type(HookTypes::AfterTool) {
if let Hook::AfterTool(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::AfterTool)
);
tracing::info!("Calling {} hook", HookTypes::AfterTool);
hook(&*self.context, &tool_call, &mut output)
.instrument(span)
.await?;
}
for (handle, tool_call) in handles {
let mut output = handle.await?;

// Invoking hooks feels too verbose and repetitive
for hook in self.hooks_by_type(HookTypes::AfterTool) {
if let Hook::AfterTool(hook) = hook {
let span = tracing::info_span!(
"hook",
"otel.name" = format!("hook.{}", HookTypes::AfterTool)
);
tracing::info!("Calling {} hook", HookTypes::AfterTool);
hook(&*self.context, &tool_call, &mut output)
.instrument(span)
.await?;
}
}

let output = output?;
let output = output?;

self.handle_control_tools(&output);
self.handle_control_tools(&output);

self.add_message(ChatMessage::ToolOutput(tool_call, output))
.await?;
}
};

self.invoke_hooks_matching(HookTypes::AfterEach).await?;
self.add_message(ChatMessage::ToolOutput(tool_call, output))
.await?;
}

Ok(())
}
Expand All @@ -382,25 +423,6 @@ impl Agent {
.collect()
}

async fn invoke_hooks_matching(&self, hook_type: HookTypes) -> Result<()> {
tracing::info!("Invoking {hook_type} hooks");

for hook in self.hooks_by_type(hook_type) {
let span = tracing::info_span!("hook", "otel.name" = format!("hook.{}", hook_type));

match hook {
Hook::BeforeAll(hook) => hook(&*self.context).instrument(span).await?,
Hook::BeforeEach(hook) => hook(&*self.context).instrument(span).await?,
Hook::AfterEach(hook) => hook(&*self.context).instrument(span).await?,
Hook::AfterTool(..) | Hook::OnNewMessage(..) | Hook::BeforeTool(..) => {
debug_assert!(false, "Should not be called here");
}
}
}

Ok(())
}

fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
self.tools
.iter()
Expand Down Expand Up @@ -699,9 +721,9 @@ mod tests {
.llm(&mock_llm)
.no_system_prompt()
.before_all(mock_before_all.hook_fn())
.before_each(mock_before_each.hook_fn())
.before_completion(mock_before_each.before_completion_fn())
.before_tool(mock_before_tool.before_tool_fn())
.after_each(mock_after_each.hook_fn())
.after_completion(mock_after_each.after_completion_fn())
.after_tool(mock_after_tool.after_tool_fn())
.on_new_message(mock_on_message.message_hook_fn())
.build()
Expand Down
64 changes: 57 additions & 7 deletions swiftide-agents/src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@
//! }
//!}
use anyhow::Result;
use std::{future::Future, pin::Pin};
use std::{future::Future, ops::Deref, pin::Pin};

Check failure on line 45 in swiftide-agents/src/hooks.rs

View workflow job for this annotation

GitHub Actions / Lint

unused import: `ops::Deref`

use dyn_clone::DynClone;
use swiftide_core::{
chat_completion::{errors::ToolError, ChatMessage, ToolCall, ToolOutput},
chat_completion::{
errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
ToolOutput,
},
AgentContext,
};

/// Hooks that are call on before each, after each and before all
pub trait HookFn:
for<'a> Fn(&'a dyn AgentContext) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
+ Send
Expand All @@ -61,6 +63,32 @@ pub trait HookFn:

dyn_clone::clone_trait_object!(HookFn);

pub trait BeforeCompletionFn:
for<'a> Fn(
&'a dyn AgentContext,
&mut ChatCompletionRequest,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
+ Send
+ Sync
+ DynClone
{
}

dyn_clone::clone_trait_object!(BeforeCompletionFn);

pub trait AfterCompletionFn:
for<'a> Fn(
&'a dyn AgentContext,
&mut ChatCompletionResponse,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
+ Send
+ Sync
+ DynClone
{
}

dyn_clone::clone_trait_object!(AfterCompletionFn);

/// Hooks that are called after each tool
///
pub trait AfterToolFn:
Expand Down Expand Up @@ -110,11 +138,11 @@ dyn_clone::clone_trait_object!(MessageHookFn);
#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
pub enum Hook {
BeforeAll(Box<dyn HookFn>),
BeforeEach(Box<dyn HookFn>),
BeforeCompletion(Box<dyn BeforeCompletionFn>),
BeforeTool(Box<dyn BeforeToolFn>),
AfterTool(Box<dyn AfterToolFn>),
OnNewMessage(Box<dyn MessageHookFn>),
AfterEach(Box<dyn HookFn>),
AfterCompletion(Box<dyn AfterCompletionFn>),
}

impl<F> HookFn for F where
Expand All @@ -125,6 +153,28 @@ impl<F> HookFn for F where
{
}

impl<F> BeforeCompletionFn for F where
F: for<'a> Fn(
&'a dyn AgentContext,
&mut ChatCompletionRequest,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
+ Send
+ Sync
+ DynClone
{
}

impl<F> AfterCompletionFn for F where
F: for<'a> Fn(
&'a dyn AgentContext,
&mut ChatCompletionResponse,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
+ Send
+ Sync
+ DynClone
{
}

impl<F> BeforeToolFn for F where
F: for<'a> Fn(
&'a dyn AgentContext,
Expand Down Expand Up @@ -166,9 +216,9 @@ mod tests {
fn test_hooks_compile_sync_and_async() {
Agent::builder()
.before_all(|_| Box::pin(async { Ok(()) }))
.before_each(|_| Box::pin(async { Ok(()) }))
.before_completion(|_, _| Box::pin(async { Ok(()) }))
.before_tool(|_, _| Box::pin(async { Ok(()) }))
.after_tool(|_, _, _| Box::pin(async { Ok(()) }))
.after_each(|_| Box::pin(async { Ok(()) }));
.after_completion(|_, _| Box::pin(async { Ok(()) }));
}
}
Loading

0 comments on commit 651fde0

Please sign in to comment.