43 changes: 22 additions & 21 deletions tests/trainer/trainer_test.py
@@ -10,6 +10,7 @@
import unittest
from copy import deepcopy
from datetime import datetime
+ from logging import Logger
from typing import Dict
from unittest import mock

@@ -1530,26 +1531,24 @@ def tearDown(self) -> None:
ray.shutdown(_exiting_interpreter=True)

def test_agentscope_tuner(self):
- try:
-     from agentscope.agent import ReActAgent
-     from agentscope.formatter import OpenAIChatFormatter
-     from agentscope.message import Msg
-     from agentscope.model import ChatModelBase
-     from agentscope.tuner import (
-         Algorithm,
-         Dataset,
-         JudgeOutput,
-         TunerChatModel,
-         WorkflowOutput,
-         tune,
-     )
- except ImportError:
-     self.skipTest("agentscope >= 1.0.12 is not installed")
+ from agentscope.agent import ReActAgent
+ from agentscope.formatter import OpenAIChatFormatter
+ from agentscope.message import Msg
+ from agentscope.model import ChatModelBase
+ from agentscope.tuner import (
+     AlgorithmConfig,
+     DatasetConfig,
+     JudgeOutput,
+     TunerModelConfig,
+     WorkflowOutput,
+     tune,
+ )

async def workflow_func(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
+ logger: Logger,
) -> WorkflowOutput:
assert isinstance(model, ChatModelBase)
assert "judge_model" in auxiliary_models
@@ -1563,10 +1562,11 @@ async def workflow_func(
st = time.time()
response = await agent.reply(Msg("user", task["question"], role="user"))
et = time.time()
logger.info(f"Question: {task['question']}\nAnswer: {response.get_text_content()}")
return WorkflowOutput(response=response, metrics={"workflow_time": et - st})

async def judge_func(
- task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase]
+ task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase], logger: Logger
) -> JudgeOutput:
assert "judge_model" in auxiliary_models
judge_model = auxiliary_models["judge_model"]
@@ -1587,6 +1587,7 @@ async def judge_func(
)
)
et = time.time()
logger.info(f"Judge Response: {judge_response.get_text_content()}")
judge_response = judge_response.get_text_content()
if judge_response is not None and "yes" in judge_response.lower():
is_correct = True
@@ -1599,33 +1600,33 @@

gsm8k_dataset = get_unittest_dataset_config("gsm8k")

- dataset = Dataset(
+ dataset = DatasetConfig(
path=gsm8k_dataset.path,
split="train",
total_steps=2,
)
- eval_dataset = Dataset(
+ eval_dataset = DatasetConfig(
path=gsm8k_dataset.path,
split="test",
)

- model = TunerChatModel(
+ model = TunerModelConfig(
model_path=get_model_path(),
max_model_len=4096,
max_tokens=2048,
inference_engine_num=2,
)

auxiliary_models = {
"judge_model": TunerChatModel(
"judge_model": TunerModelConfig(
model_path=get_model_path(),
max_model_len=8192,
max_tokens=2048,
inference_engine_num=2,
)
}

- algorithm = Algorithm(
+ algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
batch_size=4,
group_size=4,
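Note: the removed try/except above previously skipped this test when agentscope was missing. If a version-gated skip is still wanted now that the imports are unconditional, a minimal sketch (assuming the distribution name `agentscope` and the `>= 1.0.12` floor quoted in the removed skip message; the `_require_agentscope` helper is hypothetical, not part of the test suite):

```python
import unittest
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse as parse_version


def _require_agentscope(min_version: str = "1.0.12") -> None:
    """Skip the calling test unless agentscope >= min_version is installed."""
    try:
        installed = version("agentscope")
    except PackageNotFoundError:
        raise unittest.SkipTest("agentscope is not installed")
    if parse_version(installed) < parse_version(min_version):
        raise unittest.SkipTest(f"agentscope >= {min_version} is required, found {installed}")
```

Calling `_require_agentscope()` at the top of `test_agentscope_tuner` would restore the old skip behaviour without a broad `except ImportError` around the imports.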
8 changes: 6 additions & 2 deletions trinity/common/config.py
@@ -432,11 +432,12 @@ class DataProcessorConfig:
@dataclass
class TinkerConfig:
enable: bool = False
- rank: int = 32 # lora rank
+ rank: int = 16 # lora rank
seed: Optional[int] = None
train_mlp: bool = True
train_attn: bool = True
train_unembed: bool = True
+ base_url: Optional[str] = None


@dataclass
@@ -930,12 +931,15 @@ def _flatten(obj, parent_key="", sep="."):

def get_envs(self) -> Dict[str, str]:
"""Get the environment variables from the config."""
- return {
+ envs = {
PLUGIN_DIRS_ENV_VAR: os.getenv(PLUGIN_DIRS_ENV_VAR, ""),
LOG_LEVEL_ENV_VAR: self.log.level,
LOG_DIR_ENV_VAR: self.log.save_dir,
LOG_NODE_IP_ENV_VAR: "1" if self.log.group_by_node else "0",
}
+ if self.model.tinker.base_url:
+     envs["TINKER_BASE_URL"] = self.model.tinker.base_url
+ return envs


def load_config(config_path: str) -> Config:
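The intent of the `get_envs` change is that an explicitly configured Tinker endpoint gets exported to worker processes as `TINKER_BASE_URL`, and the key is omitted entirely when no endpoint is set so downstream code can fall back to its own default. A self-contained sketch of that conditional-export pattern (the dataclasses below are simplified stand-ins, not the real Trinity `Config`/`TinkerConfig`):

```python
import os
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class TinkerSettings:  # stand-in for TinkerConfig
    enable: bool = False
    rank: int = 16  # lora rank
    base_url: Optional[str] = None


@dataclass
class Settings:  # stand-in for the full Config
    tinker: TinkerSettings = field(default_factory=TinkerSettings)
    log_level: str = "INFO"

    def get_envs(self) -> Dict[str, str]:
        """Build the env-var dict; TINKER_BASE_URL is only set when configured."""
        envs = {"LOG_LEVEL": self.log_level}  # placeholder key for illustration
        if self.tinker.base_url:
            envs["TINKER_BASE_URL"] = self.tinker.base_url
        return envs


cfg = Settings(tinker=TinkerSettings(enable=True, base_url="http://localhost:8000"))
os.environ.update(cfg.get_envs())
print(os.environ["TINKER_BASE_URL"])  # -> http://localhost:8000
```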
5 changes: 3 additions & 2 deletions trinity/common/config_validator.py
@@ -396,7 +396,7 @@ def _check_tinker(self, config: Config) -> None:

import tinker

- service_client = tinker.ServiceClient()
+ service_client = tinker.ServiceClient(base_url=config.model.tinker.base_url)
supported_models = {
item.model_name for item in service_client.get_server_capabilities().supported_models
}
@@ -799,7 +799,8 @@ def validate(self, config: Config) -> None:
config.buffer.batch_size * config.algorithm.repeat_times
)
if (
- config.mode in {"train", "both"}
+ not config.model.tinker.enable
+ and config.mode in {"train", "both"}
and config.buffer.train_batch_size % config.cluster.trainer_gpu_num != 0
):
raise ValueError(
1 change: 0 additions & 1 deletion trinity/common/models/model.py
@@ -265,7 +265,6 @@ def __init__(
engine_type (str): The type of the model engine. Default to "vllm".
enable_lora (bool): Whether to enable LoRA. Default to False.
enable_history (bool): Whether to enable history recording. Default to False.
- enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
"""
assert (
engine_type.startswith("vllm") or engine_type == "tinker"
4 changes: 4 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
@@ -36,11 +36,13 @@ def get_api_server(
async_llm,
host=host,
port=port,
+ logger=logger,
model_path=config.model_path, # type: ignore [arg-type]
enable_auto_tool_choice=config.enable_auto_tool_choice,
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
+ chat_template=config.chat_template,
)
)
elif vllm_version == parse_version("0.12.0"):
@@ -59,6 +61,7 @@
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
+ chat_template=config.chat_template,
)
)
else:
@@ -78,5 +81,6 @@
tool_call_parser=config.tool_call_parser,
reasoning_parser=config.reasoning_parser,
enable_log_requests=config.enable_log_requests,
+ chat_template=config.chat_template,
)
)
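For reference, `get_api_server` picks one of the three patched API-server entrypoints based on the installed vLLM version; only the `0.12.0` comparison is visible above, so the other thresholds in this sketch are assumptions. A stripped-down sketch of that version dispatch:

```python
from packaging.version import parse as parse_version


def pick_api_server_runner(vllm_version_str: str) -> str:
    """Map a vLLM version string to the matching patched API-server entrypoint."""
    vllm_version = parse_version(vllm_version_str)
    if vllm_version >= parse_version("0.13.0"):  # assumed threshold for the v13 patch
        return "run_api_server_in_ray_actor_v13"
    if vllm_version == parse_version("0.12.0"):
        return "run_api_server_in_ray_actor_v12"
    return "run_api_server_in_ray_actor"


print(pick_api_server_runner("0.12.0"))  # -> run_api_server_in_ray_actor_v12
```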
6 changes: 6 additions & 0 deletions trinity/common/models/vllm_patch/api_patch.py
@@ -6,6 +6,7 @@
import asyncio
import functools
import json
+ import logging
import time
from typing import Optional, Union

@@ -335,6 +336,8 @@ async def run_api_server_in_ray_actor(
host: str,
port: int,
model_path: str,
+ logger: logging.Logger,
+ chat_template: Optional[str] = None,
enable_auto_tool_choice: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
@@ -369,4 +372,7 @@ async def run_api_server_in_ray_actor(
args = parser.parse_args(cli_args)
if vllm_version >= parse_version("0.11.0"):
args.structured_outputs_config.reasoning_parser = reasoning_parser
+ if chat_template:
+     args.chat_template = chat_template
+ logger.info(f"Starting vLLM OpenAI API server with args: {args}")
await run_server_in_ray(args, async_llm)
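The `chat_template` handling above follows a post-parse override pattern: build the server CLI argument list, parse it, then set selected attributes on the parsed namespace before launching. A generic sketch of that pattern (plain `argparse` here, not vLLM's real parser):

```python
import argparse
from typing import Optional


def build_server_args(model_path: str, chat_template: Optional[str] = None) -> argparse.Namespace:
    """Parse baseline CLI args, then apply optional overrides on the namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--chat-template", type=str, default=None)

    args = parser.parse_args(["--model", model_path])
    if chat_template:
        # Mirrors `args.chat_template = chat_template` in the patch above: the override
        # is applied directly on the namespace instead of being threaded through cli_args.
        args.chat_template = chat_template
    return args


print(build_server_args("/models/example", chat_template="{{ messages }}"))
```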
3 changes: 3 additions & 0 deletions trinity/common/models/vllm_patch/api_patch_v12.py
@@ -127,6 +127,7 @@ async def run_api_server_in_ray_actor_v12(
port: int,
model_path: str,
logger: logging.Logger,
+ chat_template: Optional[str] = None,
enable_auto_tool_choice: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
@@ -161,5 +162,7 @@
args = parser.parse_args(cli_args)
if vllm_version >= parse_version("0.11.0"):
args.structured_outputs_config.reasoning_parser = reasoning_parser
+ if chat_template:
+     args.chat_template = chat_template
logger.info(f"Starting vLLM OpenAI API server with args: {args}")
await run_server_in_ray(args, async_llm, logger)
3 changes: 3 additions & 0 deletions trinity/common/models/vllm_patch/api_patch_v13.py
@@ -137,6 +137,7 @@ async def run_api_server_in_ray_actor_v13(
port: int,
model_path: str,
logger: logging.Logger,
+ chat_template: Optional[str] = None,
enable_auto_tool_choice: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
@@ -170,5 +171,7 @@
cli_args.extend(["--reasoning-parser", reasoning_parser])
args = parser.parse_args(cli_args)
args.structured_outputs_config.reasoning_parser = reasoning_parser
+ if chat_template:
+     args.chat_template = chat_template
logger.info(f"Starting vLLM OpenAI API server with args: {args}")
await run_server_in_ray(args, async_llm, logger)
36 changes: 16 additions & 20 deletions trinity/common/workflows/agentscope_workflow.py
@@ -179,35 +179,31 @@ async def run_async(self) -> List[Experience]:

metrics = {}
workflow_sig = inspect.signature(self.workflow_func)
+ workflow_input_dict = {
+     "task": self.task.raw_task,
+     "model": self.chat_model,
+ }
if "auxiliary_models" in workflow_sig.parameters:
-     workflow_output = await self.workflow_func(
-         task=self.task.raw_task,
-         model=self.chat_model,
-         auxiliary_models=self.auxiliary_chat_models,
-     )
- else:
-     workflow_output = await self.workflow_func(
-         task=self.task.raw_task,
-         model=self.chat_model,
-     )
+     workflow_input_dict["auxiliary_models"] = self.auxiliary_chat_models
+ if "logger" in workflow_sig.parameters:
+     workflow_input_dict["logger"] = self.logger
+ workflow_output = await self.workflow_func(**workflow_input_dict)
if not isinstance(workflow_output, WorkflowOutput):
raise ValueError(
"The 'workflow_func' must return a WorkflowOutput object.",
)
metrics.update(workflow_output.metrics or {})
if self.judge_func is not None:
judge_sig = inspect.signature(self.judge_func)
+ judge_input_dict = {
+     "task": self.task.raw_task,
+     "response": workflow_output.response,
+ }
if "auxiliary_models" in judge_sig.parameters:
-     judge_output = await self.judge_func(
-         task=self.task.raw_task,
-         response=workflow_output.response,
-         auxiliary_models=self.auxiliary_chat_models,
-     )
- else:
-     judge_output = await self.judge_func(
-         task=self.task.raw_task,
-         response=workflow_output.response,
-     )
+     judge_input_dict["auxiliary_models"] = self.auxiliary_chat_models
+ if "logger" in judge_sig.parameters:
+     judge_input_dict["logger"] = self.logger
+ judge_output = await self.judge_func(**judge_input_dict)
if not isinstance(judge_output, JudgeOutput):
raise ValueError(
"The 'judge_func' must return a JudgeOutput object.",
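The refactor above replaces nested branches with a kwargs dict built from whatever optional parameters the user's function declares, so older `workflow_func`/`judge_func` signatures keep working while newer ones can accept `auxiliary_models` and `logger`. A self-contained sketch of that `inspect.signature`-based dispatch (generic names, not the Trinity classes):

```python
import asyncio
import inspect
import logging
from typing import Any, Callable, Dict


async def call_with_optional_kwargs(func: Callable, task: Dict, model: Any) -> Any:
    """Pass auxiliary_models / logger only if the target function declares them."""
    sig = inspect.signature(func)
    kwargs: Dict[str, Any] = {"task": task, "model": model}
    if "auxiliary_models" in sig.parameters:
        kwargs["auxiliary_models"] = {"judge_model": object()}  # placeholder value
    if "logger" in sig.parameters:
        kwargs["logger"] = logging.getLogger(__name__)
    return await func(**kwargs)


async def old_style(task: Dict, model: Any) -> str:  # no logger -> still supported
    return task["question"]


async def new_style(task: Dict, model: Any, logger: logging.Logger) -> str:
    logger.info("handling %s", task["question"])
    return task["question"]


async def main() -> None:
    print(await call_with_optional_kwargs(old_style, {"question": "1 + 1 = ?"}, model=None))
    print(await call_with_optional_kwargs(new_style, {"question": "1 + 1 = ?"}, model=None))


asyncio.run(main())
```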