diff --git a/tests/template/data/sft_with_tools/sft_with_tools.json b/tests/template/data/sft_with_tools/sft_with_tools.json new file mode 100644 index 0000000000..49d7ab3da5 --- /dev/null +++ b/tests/template/data/sft_with_tools/sft_with_tools.json @@ -0,0 +1,371 @@ +[ + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant with access to various tools. Use them when needed to help users." + }, + { + "role": "user", + "content": "What's the weather like in Beijing today?" + }, + { + "role": "assistant", + "content": "Let me get the weather for you.", + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}", + "tool_call_id": "call_abc123" + }, + { + "role": "assistant", + "content": "The weather in Beijing today is sunny with a temperature of 22°C and humidity at 45%. It's a pleasant day!" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit" + } + }, + "required": ["location"] + } + } + } + ] + }, + { + "messages": [ + { + "role": "system", + "content": "You are a helpful math assistant. Use the calculator tool when needed." + }, + { + "role": "user", + "content": "Can you help me calculate 1584 * 327?" + }, + { + "role": "assistant", + "content": "let me calculate 1584 * 327 using the tools provided.", + "tool_calls": [ + { + "id": "call_def456", + "type": "function", + "function": { + "name": "calculator", + "arguments": "{\"operation\": \"multiply\", \"a\": 1584, \"b\": 327}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"result\": 517968}", + "tool_call_id": "call_def456" + }, + { + "role": "assistant", + "content": "The result of 1584 × 327 is 517,968." + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "calculator", + "description": "Perform basic mathematical operations", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The mathematical operation to perform" + }, + "a": { + "type": "number", + "description": "First operand" + }, + "b": { + "type": "number", + "description": "Second operand" + } + }, + "required": ["operation", "a", "b"] + } + } + } + ] + }, + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant that can search for information and send emails." + }, + { + "role": "user", + "content": "Search for the latest news about AI and then send a summary to john@example.com" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_ghi789", + "type": "function", + "function": { + "name": "web_search", + "arguments": "{\"query\": \"latest AI news 2024\", \"num_results\": 5}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"results\": [{\"title\": \"GPT-5 Announced\", \"snippet\": \"OpenAI announces next generation model...\"}, {\"title\": \"AI Regulation Updates\", \"snippet\": \"EU passes new AI act...\"}]}", + "tool_call_id": "call_ghi789" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_jkl012", + "type": "function", + "function": { + "name": "send_email", + "arguments": "{\"to\": \"john@example.com\", \"subject\": \"Latest AI News Summary\", \"body\": \"Here are the latest AI developments:\\n1. GPT-5 Announced - OpenAI announces next generation model\\n2. AI Regulation Updates - EU passes new AI act\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"status\": \"sent\", \"message_id\": \"msg_12345\"}", + "tool_call_id": "call_jkl012" + }, + { + "role": "assistant", + "content": "I've found the latest AI news and sent a summary to john@example.com. The email includes information about GPT-5 announcement and EU AI regulation updates." + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query" + }, + "num_results": { + "type": "integer", + "description": "Number of results to return", + "default": 10 + } + }, + "required": ["query"] + } + } + }, + { + "type": "function", + "function": { + "name": "send_email", + "description": "Send an email to a recipient", + "parameters": { + "type": "object", + "properties": { + "to": { + "type": "string", + "description": "Email address of the recipient" + }, + "subject": { + "type": "string", + "description": "Email subject" + }, + "body": { + "type": "string", + "description": "Email body content" + } + }, + "required": ["to", "subject", "body"] + } + } + } + ] + }, + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant with access to multiple tools." + }, + { + "role": "user", + "content": "I need to convert 1000 USD to EUR, then calculate 15% tax on that amount, and finally send the result to my email alice@example.com" + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_001", + "type": "function", + "function": { + "name": "currency_converter", + "arguments": "{\"amount\": 1000, \"from_currency\": \"USD\", \"to_currency\": \"EUR\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"converted_amount\": 920.50, \"exchange_rate\": 0.9205}", + "tool_call_id": "call_001" + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_002", + "type": "function", + "function": { + "name": "calculate_percentage", + "arguments": "{\"base_amount\": 920.50, \"percentage\": 15}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"result\": 138.075, \"total_with_percentage\": 1058.575}", + "tool_call_id": "call_002" + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_003", + "type": "function", + "function": { + "name": "send_email", + "arguments": "{\"to\": \"alice@example.com\", \"subject\": \"Currency Conversion and Tax Calculation\", \"body\": \"Here are your requested calculations:\\n\\n1. Currency Conversion: 1000 USD = 920.50 EUR (rate: 0.9205)\\n2. 15% tax on 920.50 EUR = 138.08 EUR\\n3. Total amount with tax: 1058.58 EUR\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"status\": \"sent\", \"message_id\": \"msg_789xyz\"}", + "tool_call_id": "call_003" + }, + { + "role": "assistant", + "content": "I've completed all the calculations and sent the results to alice@example.com:\n\n1. Converted 1000 USD to 920.50 EUR (exchange rate: 0.9205)\n2. Calculated 15% tax: 138.08 EUR\n3. Total with tax: 1058.58 EUR\n\nThe email has been successfully sent (message ID: msg_789xyz)." + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "currency_converter", + "description": "Convert amount from one currency to another", + "parameters": { + "type": "object", + "properties": { + "amount": { + "type": "number", + "description": "The amount to convert" + }, + "from_currency": { + "type": "string", + "description": "Source currency code" + }, + "to_currency": { + "type": "string", + "description": "Target currency code" + } + }, + "required": ["amount", "from_currency", "to_currency"] + } + } + }, + { + "type": "function", + "function": { + "name": "calculate_percentage", + "description": "Calculate percentage of a given amount", + "parameters": { + "type": "object", + "properties": { + "base_amount": { + "type": "number", + "description": "The base amount" + }, + "percentage": { + "type": "number", + "description": "The percentage to calculate" + } + }, + "required": ["base_amount", "percentage"] + } + } + }, + { + "type": "function", + "function": { + "name": "send_email", + "description": "Send an email to a recipient", + "parameters": { + "type": "object", + "properties": { + "to": { + "type": "string", + "description": "Email address of the recipient" + }, + "subject": { + "type": "string", + "description": "Email subject" + }, + "body": { + "type": "string", + "description": "Email body content" + } + }, + "required": ["to", "subject", "body"] + } + } + } + ] + } +] diff --git a/tests/tools.py b/tests/tools.py index c50fa2d365..58108faa77 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -87,6 +87,17 @@ def get_unittest_dataset_config( response_key="response", ), ) + elif dataset_name == "sft_with_tools": + return StorageConfig( + name=dataset_name, + path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"), + split="train", + format=FormatConfig( + prompt_type=PromptType.MESSAGES, + messages_key="messages", + tools_key="tools", + ), + ) elif dataset_name == "dpo": return StorageConfig( name=dataset_name, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index c44410eeea..1b0207a494 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -326,6 +326,34 @@ def tearDown(self): shutil.rmtree(self.config.checkpoint_job_dir) +class TestTrainerToolsSFT(BaseTrainerCase): + def test_trainer_tools(self): + """Test SFT with tools.""" + # test both mode + self.config.mode = "train" + self.config.algorithm.algorithm_type = "sft" + self.config.algorithm.policy_loss_fn = "sft" + self.config.algorithm.policy_loss_fn_args = {} + self.config.algorithm.kl_loss_fn = "none" + self.config.algorithm.entropy_loss_fn = "none" + self.config.synchronizer.sync_interval = 4 + self.config.buffer.train_batch_size = 4 + self.config.buffer.total_epochs = 4 + self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config( + "sft_with_tools" + ) + self.config.check_and_update() + train(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + def run_trainer(config: Config) -> None: ray.init(namespace=config.ray_namespace) train(config) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 4de2e03fc6..be8d37b3d0 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -118,6 +118,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): subset_name = meta.subset_name self.prompt_type = meta.format.prompt_type self.messages_key = meta.format.messages_key + self.tools_key = meta.format.tools_key self.prompt_key = meta.format.prompt_key self.response_key = meta.format.response_key self.read_batch_size = config.train_batch_size @@ -140,12 +141,22 @@ def read( if self.prompt_type == PromptType.MESSAGES: for sample in samples: messages = sample[self.messages_key] - tokens = self.tokenizer.apply_chat_template( - messages, add_generation_prompt=False, return_tensors="pt" - )[0] - prompt_tokens_ids = self.tokenizer.apply_chat_template( - messages[:-1], add_generation_prompt=True, return_tensors="pt" - )[0] + tools = sample.get(self.tools_key, None) + if tools: + tokens = self.tokenizer.apply_chat_template( + messages, tools=tools, add_generation_prompt=False, return_tensors="pt" + )[0] + prompt_tokens_ids = self.tokenizer.apply_chat_template( + messages[:-1], tools=tools, add_generation_prompt=True, return_tensors="pt" + )[0] + else: + tokens = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=False, return_tensors="pt" + )[0] + prompt_tokens_ids = self.tokenizer.apply_chat_template( + messages[:-1], add_generation_prompt=True, return_tensors="pt" + )[0] + experience = Experience( tokens=tokens, prompt_length=len(prompt_tokens_ids), diff --git a/trinity/common/config.py b/trinity/common/config.py index c3fb390577..1ceb7234d0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -34,6 +34,7 @@ class FormatConfig: prompt_key: str = "prompt" response_key: str = "response" messages_key: str = "message" + tools_key: str = "tools" chat_template: str = "" # deprecated system_prompt: Optional[str] = None diff --git a/trinity/common/experience.py b/trinity/common/experience.py index d1c0bdc8cc..d4d9aa6cc2 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -118,6 +118,7 @@ class Experience: # Action mask which indicates which tokens are generated by the model action_mask: Optional[Tensor] = None # [resp_length] messages: Optional[List[dict]] = None # List of messages + tools: Optional[List[dict]] = None # for dpo experiences chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length] @@ -141,6 +142,7 @@ def __init__( # noqa: C901 prompt_text=None, action_mask=None, messages=None, + tools=None, chosen=None, rejected=None, chosen_text=None, @@ -192,6 +194,7 @@ def __init__( # noqa: C901 action_mask = torch.tensor(action_mask, dtype=torch.bool) self.action_mask = action_mask self.messages = messages + self.tools = tools if isinstance(chosen, list): chosen = torch.tensor(chosen, dtype=torch.int32) self.chosen = chosen @@ -236,6 +239,8 @@ def to_dict(self) -> dict: res["response_text"] = self.response_text if self.messages is not None: res["messages"] = self.messages + if self.tools is not None: + res["tools"] = self.tools if self.chosen_text is not None: res["chosen_text"] = self.chosen_text if self.rejected_text is not None: