Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mochi-neko committed Apr 13, 2024
1 parent 35da1f2 commit fe0a62e
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 284 deletions.
34 changes: 25 additions & 9 deletions clust_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,36 @@ mod parameter_type;
/// ## Supported arguments
/// - None
/// - e.g. `fn function() -> T`
/// - A type or types that implement(s) `std::str::FromStr`.
/// - e.g.
/// - `fn function(arg1: u32) -> T`
/// - `fn function(arg1: DefinedStruct) -> T` where `DefinedStruct` implements `std::str::FromStr`.
/// - Types that can be represented as JSON object.
/// - Boolean
/// - `bool`
/// - Integer
/// - `i8`, `i16`, `i32`, `i64`, `i128`
/// - `u8`, `u16`, `u32`, `u64`, `u128`
/// - Number
/// - `f32`
/// - `f64`
/// - String
/// - `String`
/// - `&str`
/// - Array
/// - `Vec<T>` where `T` is supported type.
/// - `&[T]` where `T` is supported type.
/// - `&[T; N]` where `T` is supported type and `N` is a constant.
/// - Option
/// - `Option<T>` where `T` is supported type.
///
/// ## Supported return values
/// - A type that implements `std::fmt::Display`.
/// - None
/// - e.g. `fn function()`
/// - A type that can be formatted, i.e. implements `std::fmt::Display`.
/// - e.g.
/// - `fn function() -> u32`
/// - `fn function() -> DefinedStruct` where `DefinedStruct` implements `std::fmt::Display`.
/// - Result<T, E> where T and E implement `std::fmt::Display`.
/// - `fn function() -> DefinedStruct` (where `DefinedStruct` implements `std::fmt::Display`).
/// - Result<T, E> where T and E can be formatted, i.e. implement `std::fmt::Display`.
/// - e.g.
/// - `fn function() -> Result<u32, Error>`
/// - `fn function() -> Result<DefinedStruct, Error>` where `DefinedStruct` and `Error` implement `std::fmt::Display`.
/// - `fn function() -> Result<u32, SomeError>` (where `SomeError` implements `std::fmt::Display`).
/// - `fn function() -> Result<DefinedStruct, SomeError>` (where `DefinedStruct` and `SomeError` implement `std::fmt::Display`).
///
/// ## Supported executions
/// - Synchronous -> implement `clust::messages::Tool`
Expand Down
14 changes: 8 additions & 6 deletions clust_macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::collections::BTreeMap;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::{AttrStyle, Expr, ItemFn, Meta};
use valico::json_schema::PrimitiveType;

use crate::parameter_type::ParameterType;
use crate::return_type::ReturnType;
Expand Down Expand Up @@ -39,7 +40,8 @@ struct ToolInformation {
impl ToolInformation {
fn build_json_schema(&self) -> serde_json::Value {
let mut builder = valico::json_schema::Builder::new();
builder.object();

builder.type_(PrimitiveType::Object);

if let Some(description) = &self.description {
builder.desc(&description.clone());
Expand All @@ -50,15 +52,16 @@ impl ToolInformation {
for parameter in &self.parameters {
builder.properties(|properties| {
properties.insert(&parameter.name, |property| {
if let Some(description) = &parameter.description {
property.desc(&description.clone());
}
property.type_(
parameter
._type
.to_primitive_type(),
);

if let Some(description) = &parameter.description {
property.desc(&description.clone());
}

// "items" for array
if let ParameterType::Array(item_type) =
parameter._type.clone()
Expand Down Expand Up @@ -399,9 +402,8 @@ fn quote_call_with_value_async(

fn quote_return_no_value() -> proc_macro2::TokenStream {
quote! {
Ok(clust::messages::ToolResult::success(
Ok(clust::messages::ToolResult::success_without_content(
tool_use.id,
None,
))
}
}
Expand Down
2 changes: 1 addition & 1 deletion clust_macros/tests/tool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use clust::messages::{TextContentBlock, Tool, ToolResult, ToolUse};
use clust::messages::{Tool, ToolUse};

use clust_macros::clust_tool;

Expand Down
69 changes: 31 additions & 38 deletions clust_macros/tests/tool_async.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::collections::BTreeMap;

use clust::messages::{AsyncTool, FunctionCalls, FunctionResults, Invoke};
use clust::messages::{AsyncTool, ToolUse};

use clust_macros::clust_tool;

/// An asynchronous function for testing.
/// A function for testing.
///
/// ## Arguments
/// - `arg1` - First argument.
Expand All @@ -18,45 +16,40 @@ fn test_description() {
let tool = ClustTool_test_function {};

assert_eq!(
tool.description().to_string(),
r#"
<tool_description>
<tool_name>test_function</tool_name>
<description>An asynchronous function for testing.</description>
<parameters>
<parameter>
<name>arg1</name>
<type>i32</type>
<description>First argument.</description>
</parameter>
</parameters>
</tool_description>"#
tool.definition().to_string(),
r#"{
"name": "test_function",
"description": "A function for testing.",
"input_schema": {
"description": "A function for testing.",
"properties": {
"arg1": {
"description": "First argument.",
"type": "integer"
}
},
"required": [
"arg1"
],
"type": "object"
}
}"#
);
}

#[tokio::test]
async fn test_call() {
let tool = ClustTool_test_function {};

let function_calls = FunctionCalls {
invoke: Invoke {
tool_name: String::from("test_function"),
parameters: BTreeMap::from_iter(vec![(
"arg1".to_string(),
"42".to_string(),
)]),
},
};

let result = tool
.call(function_calls)
.await
.unwrap();

if let FunctionResults::Result(result) = result {
assert_eq!(result.tool_name, "test_function");
assert_eq!(result.stdout, "43");
} else {
panic!("Expected FunctionResults::Result");
}
let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({"arg1": 42}),
);

let result = tool.call(tool_use).await.unwrap();

assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, None);
assert_eq!(result.content.unwrap().text, "43");
}
111 changes: 45 additions & 66 deletions clust_macros/tests/tool_async_with_result.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
use std::collections::BTreeMap;
use std::fmt::Display;

use clust::messages::{AsyncTool, FunctionCalls, FunctionResults, Invoke};
use clust::messages::{AsyncTool, ToolUse};

use clust_macros::clust_tool;

/// An asynchronous function with returning result for testing.
use std::fmt::Display;

/// A function for testing.
///
/// ## Arguments
/// - `arg1` - First argument.
///
/// ## Examples
/// ```rust
/// ```
#[clust_tool]
async fn test_function_with_result(arg1: i32) -> Result<u32, TestError> {
async fn test_function(arg1: i32) -> Result<i32, TestError> {
if arg1 >= 0 {
Ok(arg1 as u32)
Ok(arg1 + 1)
} else {
Err(TestError {
message: "arg1 is negative".to_string(),
})
}
}

#[derive(Debug)]
struct TestError {
message: String,
}
Expand All @@ -39,72 +35,55 @@ impl Display for TestError {

#[test]
fn test_description() {
let tool = ClustTool_test_function_with_result {};
let tool = ClustTool_test_function {};

assert_eq!(
tool.description().to_string(),
r#"
<tool_description>
<tool_name>test_function_with_result</tool_name>
<description>An asynchronous function with returning result for testing.</description>
<parameters>
<parameter>
<name>arg1</name>
<type>i32</type>
<description>First argument.</description>
</parameter>
</parameters>
</tool_description>"#
tool.definition().to_string(),
r#"{
"name": "test_function",
"description": "A function for testing.",
"input_schema": {
"description": "A function for testing.",
"properties": {
"arg1": {
"description": "First argument.",
"type": "integer"
}
},
"required": [
"arg1"
],
"type": "object"
}
}"#
);
}

#[tokio::test]
async fn test_call() {
let tool = ClustTool_test_function_with_result {};
let tool = ClustTool_test_function {};

let function_calls = FunctionCalls {
invoke: Invoke {
tool_name: String::from("test_function_with_result"),
parameters: BTreeMap::from_iter(vec![(
"arg1".to_string(),
"1".to_string(),
)]),
},
};
let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({"arg1": 42}),
);

let result = tool
.call(function_calls)
.await
.unwrap();
let result = tool.call(tool_use).await.unwrap();

if let FunctionResults::Result(result) = result {
assert_eq!(
result.tool_name,
"test_function_with_result"
);
assert_eq!(result.stdout, "1");
} else {
panic!("Expected FunctionResults::Result");
}
assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, None);
assert_eq!(result.content.unwrap().text, "43");

let function_calls = FunctionCalls {
invoke: Invoke {
tool_name: String::from("test_function_with_result"),
parameters: BTreeMap::from_iter(vec![(
"arg1".to_string(),
"-1".to_string(),
)]),
},
};
let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({"arg1": -3}),
);

let result = tool
.call(function_calls)
.await
.unwrap();
let result = tool.call(tool_use).await.unwrap();

if let FunctionResults::Error(error) = result {
assert_eq!(error, "arg1 is negative");
} else {
panic!("Expected FunctionResults::Error");
}
assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, Some(true));
assert_eq!(result.content.unwrap().text, "arg1 is negative");
}
55 changes: 55 additions & 0 deletions clust_macros/tests/tool_with_no_return_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use clust::messages::{Tool, ToolUse};

use clust_macros::clust_tool;

/// A function for testing.
///
/// ## Arguments
/// - `arg1` - First argument.
#[clust_tool]
fn test_function(arg1: i32) {

}

#[test]
fn test_description() {
let tool = ClustTool_test_function {};

assert_eq!(
tool.definition().to_string(),
r#"{
"name": "test_function",
"description": "A function for testing.",
"input_schema": {
"description": "A function for testing.",
"properties": {
"arg1": {
"description": "First argument.",
"type": "integer"
}
},
"required": [
"arg1"
],
"type": "object"
}
}"#
);
}

#[test]
fn test_call() {
let tool = ClustTool_test_function {};

let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({"arg1": 42}),
);

let result = tool.call(tool_use).unwrap();

assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, None);
assert_eq!(result.content, None);
}
Loading

0 comments on commit fe0a62e

Please sign in to comment.