Skip to content

Commit

Permalink
Add async tool support
Browse files Browse the repository at this point in the history
  • Loading branch information
mochi-neko committed Apr 2, 2024
1 parent 599655a commit 5cbd66f
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 28 deletions.
3 changes: 3 additions & 0 deletions clust_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ quote = "1.0.*"
syn = "2.0.*"
proc-macro2 = "1.0.*"
clust = { path = ".." }

[dev-dependencies]
tokio = { version = "1.37.0", features = ["macros"] }
69 changes: 69 additions & 0 deletions clust_macros/src/check_result.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use syn::Type;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ReturnType {
Value,
Result,
}

pub(crate) fn get_return_type(ty: &Type) -> ReturnType {
if is_result_type(ty) {
ReturnType::Result
} else {
ReturnType::Value
}
}

fn is_result_type(ty: &Type) -> bool {
match ty {
| Type::Path(type_path) => {
let path_segments = &type_path.path.segments;
path_segments.last().map_or(false, |last_segment| {
if last_segment.ident == "Result" {
match &last_segment.arguments {
syn::PathArguments::AngleBracketed(args) => args.args.len() == 2,
_ => false,
}
} else if path_segments.len() >= 2 {
path_segments
.iter()
.rev()
.nth(1)
.map_or(false, |second_last_segment| {
second_last_segment.ident == "result"
&& match &last_segment.arguments {
syn::PathArguments::AngleBracketed(args) => args.args.len() == 2,
_ => false,
}
})
} else {
false
}
})
},
| _ => false,
}
}

#[cfg(test)]
mod tests {
use super::*;

type TestResult = Result<i32, String>;

#[test]
fn test_get_return_type() {
let ty = syn::parse_str::<Type>("Result<i32, String>").unwrap();
assert_eq!(get_return_type(&ty), ReturnType::Result);

let ty =
syn::parse_str::<Type>("std::result::Result<i32, String>").unwrap();
assert_eq!(get_return_type(&ty), ReturnType::Result);

let ty = syn::parse_str::<Type>("i32").unwrap();
assert_eq!(get_return_type(&ty), ReturnType::Value);

let ty: Type = syn::parse_str::<Type>("TestResult").unwrap();
assert_eq!(get_return_type(&ty), ReturnType::Value);
}
}
14 changes: 3 additions & 11 deletions clust_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
use crate::tool::{impl_tool, impl_tool_with_result};
use crate::tool::impl_tool;
use proc_macro::TokenStream;

mod check_result;
mod tool;

#[proc_macro_attribute]
pub fn clust_tool(
attr: TokenStream,
_attr: TokenStream,
item: TokenStream,
) -> TokenStream {
let item_func = syn::parse::<syn::ItemFn>(item).unwrap();
impl_tool(&item_func)
}

#[proc_macro_attribute]
pub fn clust_tool_result(
attr: TokenStream,
item: TokenStream,
) -> TokenStream {
let item_func = syn::parse::<syn::ItemFn>(item).unwrap();
impl_tool_with_result(&item_func)
}
115 changes: 100 additions & 15 deletions clust_macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use std::collections::BTreeMap;

use crate::check_result::{get_return_type, ReturnType};
use quote::{quote, ToTokens};
use syn::{AttrStyle, Expr, ItemFn, Meta};

Expand Down Expand Up @@ -364,12 +365,80 @@ fn quote_call_with_result(
}
}

fn quote_call_async(
func: &ItemFn,
info: &ToolInformation,
) -> proc_macro2::TokenStream {
let name = info.name.clone();
let ident = func.sig.ident.clone();
let parameters = quote_invoke_parameters(info);
let quote_result = quote_result(name.clone());

quote! {
async fn call(&self, function_calls: clust::messages::FunctionCalls)
-> std::result::Result<clust::messages::FunctionResults, clust::messages::ToolCallError> {
if function_calls.invoke.tool_name != #name {
return Err(clust::messages::ToolCallError::ToolNameMismatch);
}

let result = #ident(
#(
#parameters
),*
).await;

#quote_result
}
}
}

fn quote_call_async_with_result(
func: &ItemFn,
info: &ToolInformation,
) -> proc_macro2::TokenStream {
let name = info.name.clone();
let ident = func.sig.ident.clone();
let parameters = quote_invoke_parameters(info);
let quote_result = quote_result_with_match(name.clone());

quote! {
async fn call(&self, function_calls: clust::messages::FunctionCalls)
-> std::result::Result<clust::messages::FunctionResults, clust::messages::ToolCallError> {
if function_calls.invoke.tool_name != #name {
return Err(clust::messages::ToolCallError::ToolNameMismatch);
}

let result = #ident(
#(
#parameters
),*
).await;

#quote_result
}
}
}

fn impl_tool_for_function(
func: &ItemFn,
info: ToolInformation,
) -> proc_macro2::TokenStream {
let description_quote = quote_description(&info);
let call_quote = quote_call(func, &info);
let impl_description = quote_description(&info);

let impl_call = match func.sig.output.clone() {
| syn::ReturnType::Default => {
panic!("Function must have a displayable return type")
},
| syn::ReturnType::Type(_, _type) => {
let return_type = get_return_type(&_type);

match return_type {
| ReturnType::Value => quote_call(func, &info),
| ReturnType::Result => quote_call_with_result(func, &info),
}
},
};

let struct_name = Ident::new(
&format!("ClustTool_{}", info.name),
Span::call_site(),
Expand All @@ -384,18 +453,34 @@ fn impl_tool_for_function(

// Implement Tool trait for generated tool struct
impl clust::messages::Tool for #struct_name {
#description_quote
#call_quote
#impl_description
#impl_call
}
}
}

fn impl_tool_for_function_with_result(
fn impl_tool_for_async_function(
func: &ItemFn,
info: ToolInformation,
) -> proc_macro2::TokenStream {
let description_quote = quote_description(&info);
let call_quote = quote_call_with_result(func, &info);
let impl_description = quote_description(&info);

let impl_call = match func.sig.output.clone() {
| syn::ReturnType::Default => {
panic!("Function must have a displayable return type")
},
| syn::ReturnType::Type(_, _type) => {
let return_type = get_return_type(&_type);

match return_type {
| ReturnType::Value => quote_call_async(func, &info),
| ReturnType::Result => {
quote_call_async_with_result(func, &info)
},
}
},
};

let struct_name = Ident::new(
&format!("ClustTool_{}", info.name),
Span::call_site(),
Expand All @@ -409,21 +494,21 @@ fn impl_tool_for_function_with_result(
pub struct #struct_name;

// Implement Tool trait for generated tool struct
impl clust::messages::Tool for #struct_name {
#description_quote
#call_quote
impl clust::messages::AsyncTool for #struct_name {
#impl_description
#impl_call
}
}
}

pub(crate) fn impl_tool(func: &ItemFn) -> TokenStream {
let tool_information = get_tool_information(func);
impl_tool_for_function(func, tool_information).into()
}

pub(crate) fn impl_tool_with_result(func: &ItemFn) -> TokenStream {
let tool_information = get_tool_information(func);
impl_tool_for_function_with_result(func, tool_information).into()
if func.sig.asyncness.is_some() {
impl_tool_for_async_function(func, tool_information).into()
} else {
impl_tool_for_function(func, tool_information).into()
}
}

#[cfg(test)]
Expand Down
62 changes: 62 additions & 0 deletions clust_macros/tests/tool_async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::collections::BTreeMap;

use clust::messages::{AsyncTool, FunctionCalls, FunctionResults, Invoke};

use clust_macros::clust_tool;

/// An asynchronous function for testing.
///
/// ## Arguments
/// - `arg1` - First argument.
#[clust_tool]
async fn test_function(arg1: i32) -> i32 {
arg1 + 1
}

#[test]
fn test_description() {
let tool = ClustTool_test_function {};

assert_eq!(
tool.description().to_string(),
r#"
<tool_description>
<tool_name>test_function</tool_name>
<description>An asynchronous function for testing.</description>
<parameters>
<parameter>
<name>arg1</name>
<type>i32</type>
<description>First argument.</description>
</parameter>
</parameters>
</tool_description>"#
);
}

#[tokio::test]
async fn test_call() {
let tool = ClustTool_test_function {};

let function_calls = FunctionCalls {
invoke: Invoke {
tool_name: String::from("test_function"),
parameters: BTreeMap::from_iter(vec![(
"arg1".to_string(),
"42".to_string(),
)]),
},
};

let result = tool
.call(function_calls)
.await
.unwrap();

if let FunctionResults::Result(result) = result {
assert_eq!(result.tool_name, "test_function");
assert_eq!(result.stdout, "43");
} else {
panic!("Expected FunctionResults::Result");
}
}
Loading

0 comments on commit 5cbd66f

Please sign in to comment.