diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index c6fd000983..9ac918ab12 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -10,6 +10,7 @@ import unittest
 from copy import deepcopy
 from datetime import datetime
+from logging import Logger
 from typing import Dict
 from unittest import mock
@@ -1530,26 +1531,24 @@ def tearDown(self) -> None:
         ray.shutdown(_exiting_interpreter=True)

     def test_agentscope_tuner(self):
-        try:
-            from agentscope.agent import ReActAgent
-            from agentscope.formatter import OpenAIChatFormatter
-            from agentscope.message import Msg
-            from agentscope.model import ChatModelBase
-            from agentscope.tuner import (
-                Algorithm,
-                Dataset,
-                JudgeOutput,
-                TunerChatModel,
-                WorkflowOutput,
-                tune,
-            )
-        except ImportError:
-            self.skipTest("agentscope >= 1.0.12 is not installed")
+        from agentscope.agent import ReActAgent
+        from agentscope.formatter import OpenAIChatFormatter
+        from agentscope.message import Msg
+        from agentscope.model import ChatModelBase
+        from agentscope.tuner import (
+            AlgorithmConfig,
+            DatasetConfig,
+            JudgeOutput,
+            TunerModelConfig,
+            WorkflowOutput,
+            tune,
+        )

         async def workflow_func(
             task: Dict,
             model: ChatModelBase,
             auxiliary_models: Dict[str, ChatModelBase],
+            logger: Logger,
         ) -> WorkflowOutput:
             assert isinstance(model, ChatModelBase)
             assert "judge_model" in auxiliary_models
@@ -1563,10 +1562,11 @@ async def workflow_func(
             st = time.time()
             response = await agent.reply(Msg("user", task["question"], role="user"))
             et = time.time()
+            logger.info(f"Question: {task['question']}\nAnswer: {response.get_text_content()}")
             return WorkflowOutput(response=response, metrics={"workflow_time": et - st})

         async def judge_func(
-            task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase]
+            task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase], logger: Logger
         ) -> JudgeOutput:
             assert "judge_model" in auxiliary_models
             judge_model = auxiliary_models["judge_model"]
@@ -1587,6 +1587,7 @@ async def judge_func(
                 )
             )
             et = time.time()
+            logger.info(f"Judge Response: {judge_response.get_text_content()}")
             judge_response = judge_response.get_text_content()
             if judge_response is not None and "yes" in judge_response.lower():
                 is_correct = True
@@ -1599,17 +1600,17 @@
         gsm8k_dataset = get_unittest_dataset_config("gsm8k")

-        dataset = Dataset(
+        dataset = DatasetConfig(
             path=gsm8k_dataset.path,
             split="train",
             total_steps=2,
         )
-        eval_dataset = Dataset(
+        eval_dataset = DatasetConfig(
             path=gsm8k_dataset.path,
             split="test",
         )

-        model = TunerChatModel(
+        model = TunerModelConfig(
             model_path=get_model_path(),
             max_model_len=4096,
             max_tokens=2048,
@@ -1617,7 +1618,7 @@ async def judge_func(
         )

         auxiliary_models = {
-            "judge_model": TunerChatModel(
+            "judge_model": TunerModelConfig(
                 model_path=get_model_path(),
                 max_model_len=8192,
                 max_tokens=2048,
@@ -1625,7 +1626,7 @@ async def judge_func(
             )
         }

-        algorithm = Algorithm(
+        algorithm = AlgorithmConfig(
             algorithm_type="multi_step_grpo",
             batch_size=4,
             group_size=4,
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 8547e0374b..5b712a2730 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -432,11 +432,12 @@ class DataProcessorConfig:
 @dataclass
 class TinkerConfig:
     enable: bool = False
-    rank: int = 32  # lora rank
+    rank: int = 16  # lora rank
     seed: Optional[int] = None
     train_mlp: bool = True
     train_attn: bool = True
     train_unembed: bool = True
+    base_url: Optional[str] = None


 @dataclass
@@ -930,12 +931,15 @@ def _flatten(obj, parent_key="", sep="."):

     def get_envs(self) -> Dict[str, str]:
         """Get the environment variables from the config."""
-        return {
+        envs = {
             PLUGIN_DIRS_ENV_VAR: os.getenv(PLUGIN_DIRS_ENV_VAR, ""),
             LOG_LEVEL_ENV_VAR: self.log.level,
             LOG_DIR_ENV_VAR: self.log.save_dir,
             LOG_NODE_IP_ENV_VAR: "1" if self.log.group_by_node else "0",
         }
+        if self.model.tinker.base_url:
+            envs["TINKER_BASE_URL"] = self.model.tinker.base_url
+        return envs


 def load_config(config_path: str) -> Config:
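The `get_envs` change switches from returning a dict literal to building one imperatively, so the Tinker endpoint is exported to worker processes only when `base_url` is actually configured. A minimal runnable sketch of the pattern, assuming a trimmed-down config; the `LOG_LEVEL` key and the localhost URL are illustrative stand-ins, not the real env-var constants:

```python
import os
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class TinkerConfig:
    enable: bool = False
    rank: int = 16  # LoRA rank
    base_url: Optional[str] = None  # None -> use the tinker default endpoint


def get_envs(tinker: TinkerConfig) -> Dict[str, str]:
    # Unconditional entries first, then optional ones, then return.
    envs = {"LOG_LEVEL": os.getenv("LOG_LEVEL", "INFO")}
    if tinker.base_url:  # export only when explicitly configured
        envs["TINKER_BASE_URL"] = tinker.base_url
    return envs


print(get_envs(TinkerConfig()))                                   # no TINKER_BASE_URL
print(get_envs(TinkerConfig(base_url="http://localhost:10610")))  # with TINKER_BASE_URL
```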
diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py
index 0b131cc6c1..8e565c730e 100644
--- a/trinity/common/config_validator.py
+++ b/trinity/common/config_validator.py
@@ -396,7 +396,7 @@ def _check_tinker(self, config: Config) -> None:
         import tinker

-        service_client = tinker.ServiceClient()
+        service_client = tinker.ServiceClient(base_url=config.model.tinker.base_url)
         supported_models = {
             item.model_name for item in service_client.get_server_capabilities().supported_models
         }
@@ -799,7 +799,8 @@ def validate(self, config: Config) -> None:
             config.buffer.batch_size * config.algorithm.repeat_times
         )
         if (
-            config.mode in {"train", "both"}
+            not config.model.tinker.enable
+            and config.mode in {"train", "both"}
             and config.buffer.train_batch_size % config.cluster.trainer_gpu_num != 0
         ):
             raise ValueError(
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 81b57aa74e..86305b6658 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -265,7 +265,6 @@ def __init__(
             engine_type (str): The type of the model engine. Default to "vllm".
             enable_lora (bool): Whether to enable LoRA. Default to False.
             enable_history (bool): Whether to enable history recording. Default to False.
-            enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
         """
         assert (
             engine_type.startswith("vllm") or engine_type == "tinker"
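`_check_tinker` validates the configured model against the service's capabilities, now honoring a custom endpoint. The sketch below isolates that check, using only the tinker client calls already present in the patch (`ServiceClient(base_url=...)`, `get_server_capabilities()`); the `check_model_supported` helper itself is hypothetical:

```python
from typing import Optional

import tinker  # the Tinker SDK, as imported in _check_tinker


def check_model_supported(model_path: str, base_url: Optional[str] = None) -> None:
    # base_url=None falls back to the client's default endpoint.
    service_client = tinker.ServiceClient(base_url=base_url)
    supported_models = {
        item.model_name
        for item in service_client.get_server_capabilities().supported_models
    }
    if model_path not in supported_models:
        raise ValueError(
            f"Model {model_path!r} is not supported by the Tinker service; "
            f"supported models: {sorted(supported_models)}"
        )
```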
""" assert ( engine_type.startswith("vllm") or engine_type == "tinker" diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py index 4253127a68..e97b1da07c 100644 --- a/trinity/common/models/vllm_patch/__init__.py +++ b/trinity/common/models/vllm_patch/__init__.py @@ -36,11 +36,13 @@ def get_api_server( async_llm, host=host, port=port, + logger=logger, model_path=config.model_path, # type: ignore [arg-type] enable_auto_tool_choice=config.enable_auto_tool_choice, tool_call_parser=config.tool_call_parser, reasoning_parser=config.reasoning_parser, enable_log_requests=config.enable_log_requests, + chat_template=config.chat_template, ) ) elif vllm_version == parse_version("0.12.0"): @@ -59,6 +61,7 @@ def get_api_server( tool_call_parser=config.tool_call_parser, reasoning_parser=config.reasoning_parser, enable_log_requests=config.enable_log_requests, + chat_template=config.chat_template, ) ) else: @@ -78,5 +81,6 @@ def get_api_server( tool_call_parser=config.tool_call_parser, reasoning_parser=config.reasoning_parser, enable_log_requests=config.enable_log_requests, + chat_template=config.chat_template, ) ) diff --git a/trinity/common/models/vllm_patch/api_patch.py b/trinity/common/models/vllm_patch/api_patch.py index 623f9c04af..0a3b02e654 100644 --- a/trinity/common/models/vllm_patch/api_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -6,6 +6,7 @@ import asyncio import functools import json +import logging import time from typing import Optional, Union @@ -335,6 +336,8 @@ async def run_api_server_in_ray_actor( host: str, port: int, model_path: str, + logger: logging.Logger, + chat_template: Optional[str] = None, enable_auto_tool_choice: bool = False, tool_call_parser: Optional[str] = None, reasoning_parser: Optional[str] = None, @@ -369,4 +372,7 @@ async def run_api_server_in_ray_actor( args = parser.parse_args(cli_args) if vllm_version >= parse_version("0.11.0"): args.structured_outputs_config.reasoning_parser = reasoning_parser + if chat_template: + args.chat_template = chat_template + logger.info(f"Starting vLLM OpenAI API server with args: {args}") await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py index 1419184a90..14b87108b3 100644 --- a/trinity/common/models/vllm_patch/api_patch_v12.py +++ b/trinity/common/models/vllm_patch/api_patch_v12.py @@ -127,6 +127,7 @@ async def run_api_server_in_ray_actor_v12( port: int, model_path: str, logger: logging.Logger, + chat_template: Optional[str] = None, enable_auto_tool_choice: bool = False, tool_call_parser: Optional[str] = None, reasoning_parser: Optional[str] = None, @@ -161,5 +162,7 @@ async def run_api_server_in_ray_actor_v12( args = parser.parse_args(cli_args) if vllm_version >= parse_version("0.11.0"): args.structured_outputs_config.reasoning_parser = reasoning_parser + if chat_template: + args.chat_template = chat_template logger.info(f"Starting vLLM OpenAI API server with args: {args}") await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/common/models/vllm_patch/api_patch_v13.py b/trinity/common/models/vllm_patch/api_patch_v13.py index 480ad5424e..8f6dea10f4 100644 --- a/trinity/common/models/vllm_patch/api_patch_v13.py +++ b/trinity/common/models/vllm_patch/api_patch_v13.py @@ -137,6 +137,7 @@ async def run_api_server_in_ray_actor_v13( port: int, model_path: str, logger: logging.Logger, + chat_template: Optional[str] = None, enable_auto_tool_choice: bool = 
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
index feccfc5932..c3a42f4f62 100644
--- a/trinity/common/workflows/agentscope_workflow.py
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -179,17 +179,15 @@ async def run_async(self) -> List[Experience]:
         metrics = {}
         workflow_sig = inspect.signature(self.workflow_func)
+        workflow_input_dict = {
+            "task": self.task.raw_task,
+            "model": self.chat_model,
+        }
         if "auxiliary_models" in workflow_sig.parameters:
-            workflow_output = await self.workflow_func(
-                task=self.task.raw_task,
-                model=self.chat_model,
-                auxiliary_models=self.auxiliary_chat_models,
-            )
-        else:
-            workflow_output = await self.workflow_func(
-                task=self.task.raw_task,
-                model=self.chat_model,
-            )
+            workflow_input_dict["auxiliary_models"] = self.auxiliary_chat_models
+        if "logger" in workflow_sig.parameters:
+            workflow_input_dict["logger"] = self.logger
+        workflow_output = await self.workflow_func(**workflow_input_dict)
         if not isinstance(workflow_output, WorkflowOutput):
             raise ValueError(
                 "The 'workflow_func' must return a WorkflowOutput object.",
@@ -197,17 +195,15 @@ async def run_async(self) -> List[Experience]:
         metrics.update(workflow_output.metrics or {})
         if self.judge_func is not None:
             judge_sig = inspect.signature(self.judge_func)
+            judge_input_dict = {
+                "task": self.task.raw_task,
+                "response": workflow_output.response,
+            }
             if "auxiliary_models" in judge_sig.parameters:
-                judge_output = await self.judge_func(
-                    task=self.task.raw_task,
-                    response=workflow_output.response,
-                    auxiliary_models=self.auxiliary_chat_models,
-                )
-            else:
-                judge_output = await self.judge_func(
-                    task=self.task.raw_task,
-                    response=workflow_output.response,
-                )
+                judge_input_dict["auxiliary_models"] = self.auxiliary_chat_models
+            if "logger" in judge_sig.parameters:
+                judge_input_dict["logger"] = self.logger
+            judge_output = await self.judge_func(**judge_input_dict)
             if not isinstance(judge_output, JudgeOutput):
                 raise ValueError(
                     "The 'judge_func' must return a JudgeOutput object.",
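This refactor collapses duplicated call sites into a kwargs dict driven by `inspect.signature`: optional parameters (`auxiliary_models`, `logger`) are forwarded only to functions that declare them, so existing workflow functions without those parameters keep working. A self-contained sketch of the dispatch pattern:

```python
import asyncio
import inspect
import logging


async def call_with_optional_kwargs(func, required: dict, optional: dict):
    # Always pass the required kwargs; pass each optional kwarg only if the
    # user-supplied function declares a parameter with that name.
    sig = inspect.signature(func)
    kwargs = dict(required)
    for name, value in optional.items():
        if name in sig.parameters:
            kwargs[name] = value
    return await func(**kwargs)


async def workflow(task, model, logger):  # declares "logger" but not "auxiliary_models"
    logger.info(f"task={task}")
    return model


print(asyncio.run(call_with_optional_kwargs(
    workflow,
    required={"task": {"q": "1+1"}, "model": "chat-model"},
    optional={"auxiliary_models": {}, "logger": logging.getLogger("wf")},
)))  # -> "chat-model"; auxiliary_models was silently skipped
```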
diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py
index c58c953045..5ff08e0272 100644
--- a/trinity/trainer/tinker_trainer.py
+++ b/trinity/trainer/tinker_trainer.py
@@ -75,7 +75,8 @@ def _init_algorithm(self):
         self.min_lr_ratio = algorithm_config.optimizer.min_lr_ratio
         assert 0.0 <= self.min_lr_ratio <= 1.0
         self.logger.info(
-            f"Total steps: {self.total_steps}, num_warmup_steps: {self.num_warmup_steps}"
+            f"Total steps: {self.total_steps if self.total_steps != sys.maxsize else 'unlimited'},"
+            f" num_warmup_steps: {self.num_warmup_steps}"
         )

         if self.lr_scheduler_type not in {"constant", "cosine"}:
@@ -124,7 +125,7 @@ def adam_params(self):

     async def prepare(self):
         self.service_client = tinker.ServiceClient()
-
+        self.checkpoint_manager = self.service_client.create_rest_client()
         name_prefix_list = [self.config.project, self.config.group, self.config.name]
         self.tinker_checkpoint_name_prefix = "-".join(
             [prefix for prefix in name_prefix_list if prefix]
@@ -165,6 +166,7 @@ async def prepare(self):
         self.latest_remote_checkpoint_step = 0
         self.latest_remote_checkpoint_path = None
         self._train_step_num = 0
+        self.model_info = await self.actor_client.get_info_async()

         if os.path.exists(self.local_latest_state_dict_iteration):
             with open(self.local_latest_state_dict_iteration, "r") as f:
@@ -177,7 +179,7 @@ async def prepare(self):
             with open(sampler_file_path, "r") as f:
                 self.latest_remote_sampler_path = f.read().strip()
         else:
-            self.latest_remote_sampler_step = 0
+            self.latest_remote_sampler_step = None
             self.latest_remote_sampler_path = None

         self.ref_client = await self.service_client.create_sampling_client_async(
@@ -318,15 +320,16 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:

         return metrics

-    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None:
+    async def save_checkpoint(
+        self, block_until_saved: bool = False, save_as_hf: bool = False
+    ) -> None:
         """Save the checkpoint."""
         if self.train_step_num == self.latest_remote_checkpoint_step:
             return
         self.latest_remote_checkpoint_step = self.train_step_num
         checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-state-{self.train_step_num}"
-        self.latest_remote_checkpoint_path = (
-            self.actor_client.save_state(checkpoint_name).result().path
-        )
+        save_state_future = await self.actor_client.save_state_async(checkpoint_name)
+        self.latest_remote_checkpoint_path = (await save_state_future).path
         local_path = os.path.join(
             self.default_local_dir,
             f"global_step_{self.train_step_num}",
@@ -352,24 +355,38 @@ def sync_weight(self) -> None:
         """Sync the model weight."""
         raise NotImplementedError("Tinker trainer does not support NCCL sync")

-    def upload_state_dict(self) -> None:
+    async def upload_state_dict(self) -> None:
         """Upload the state dict to Synchronizer."""
-        self.save_state_dict()
+        await self.save_state_dict()
         ray.get(
             self.synchronizer.set_model_state_dict.remote(
                 self.latest_remote_sampler_path, self.train_step_num
             )
         )

-    def save_state_dict(self) -> None:
+    async def save_state_dict(self) -> None:
        """Only save the model state dict for Synchronizer."""
         if self.train_step_num == self.latest_remote_sampler_step:
             return
+        self.stale_remote_sampler_step = self.latest_remote_sampler_step
         self.latest_remote_sampler_step = self.train_step_num
-        checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-sampler-{self.train_step_num}"
-        self.latest_remote_sampler_path = (
-            self.actor_client.save_weights_for_sampler(checkpoint_name).result().path
+        current_checkpoint_name = (
+            f"{self.tinker_checkpoint_name_prefix}-sampler-{self.train_step_num}"
+        )
+        save_weights_future = await self.actor_client.save_weights_for_sampler_async(
+            current_checkpoint_name
         )
+        self.latest_remote_sampler_path = (await save_weights_future).path
+        if self.stale_remote_sampler_step is not None:
+            stale_checkpoint_name = (
+                f"{self.tinker_checkpoint_name_prefix}-sampler-{self.stale_remote_sampler_step}"
+            )
+            try:
+                await self.checkpoint_manager.delete_checkpoint_async(
+                    self.model_info.model_id, stale_checkpoint_name
+                )
+            except Exception:
+                self.logger.warning(f"Failed to remove stale state_dict {stale_checkpoint_name}")
         local_path = os.path.join(
             self.default_local_dir,
             f"global_step_{self.train_step_num}",
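Two behavioral changes in `save_state_dict` are easy to miss: `latest_remote_sampler_step` now starts as `None` instead of `0`, so a fresh run cannot mistake step 0 for an already-saved sampler checkpoint, and each save best-effort-deletes the previous sampler checkpoint to cap remote storage. A minimal sketch of this keep-latest rotation; `save_weights` and `delete_checkpoint` are hypothetical stand-ins for the async tinker client calls:

```python
from typing import Callable, Optional


class SamplerCheckpointRotator:
    """Keep only the newest sampler checkpoint; deletion is best-effort."""

    def __init__(self, prefix: str):
        self.prefix = prefix
        self.latest_step: Optional[int] = None  # None -> nothing saved yet

    def save(
        self,
        step: int,
        save_weights: Callable[[str], str],
        delete_checkpoint: Callable[[str], None],
    ) -> str:
        # Remember the previous step, save the new weights, then clean up.
        stale_step, self.latest_step = self.latest_step, step
        path = save_weights(f"{self.prefix}-sampler-{step}")
        if stale_step is not None:
            try:
                delete_checkpoint(f"{self.prefix}-sampler-{stale_step}")
            except Exception:
                pass  # never let cleanup failures break the save itself
        return path


rot = SamplerCheckpointRotator("proj-group-run")
rot.save(1, lambda n: f"tinker://{n}", lambda n: None)  # nothing stale yet
rot.save(2, lambda n: f"tinker://{n}", print)           # prints "proj-group-run-sampler-1"
```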
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 4cb9c52cc4..cb16bf7b64 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -91,7 +91,7 @@ async def train(self) -> str:
                     metrics.update(await self.sync_weight())
                 if self.need_save():
                     metrics.update(
-                        self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always")
+                        await self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always")
                     )
                 if self.config.trainer.enable_preview:
                     self._log_experiences(repr_samples)
@@ -103,7 +103,9 @@ async def train(self) -> str:
                 self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
                 break

-        self.save_checkpoint(block_until_saved=True, save_as_hf=self.save_hf_checkpoint != "never")
+        await self.save_checkpoint(
+            block_until_saved=True, save_as_hf=self.save_hf_checkpoint != "never"
+        )
         await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
         self.logger.info("--------------------\n> Trainer finished.\n--------------------")
         return self.config.trainer.name
@@ -176,9 +178,9 @@ async def sync_weight(self) -> Dict:
             else:
                 self.engine.sync_weight()
         elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
-            self.engine.save_state_dict()
+            await self.engine.save_state_dict()
         elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
-            self.engine.upload_state_dict()
+            await self.engine.upload_state_dict()
         self.last_sync_step = self.train_step_num
         self.last_sync_time = time.time()
         await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
@@ -193,11 +195,15 @@ def _log_experiences(self, samples: List[Dict]) -> None:
         )
         self._sample_exps_to_log.clear()

-    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> Dict:
+    async def save_checkpoint(
+        self, block_until_saved: bool = False, save_as_hf: bool = False
+    ) -> Dict:
         metrics = {}
         with Timer(metrics, "time/save_checkpoint"):
             self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
-            self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
+            await self.engine.save_checkpoint(
+                block_until_saved=block_until_saved, save_as_hf=save_as_hf
+            )
             self.state.save_trainer(
                 current_step=self.train_step_num,
                 sample_strategy_state=self.sample_strategy.state_dict(),
@@ -250,7 +256,9 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:
         """

     @abstractmethod
-    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None:
+    async def save_checkpoint(
+        self, block_until_saved: bool = False, save_as_hf: bool = False
+    ) -> None:
         """Save the checkpoint."""

     @abstractmethod
@@ -258,11 +266,11 @@ def sync_weight(self) -> None:
         """Sync the model weight."""

     @abstractmethod
-    def upload_state_dict(self) -> None:
+    async def upload_state_dict(self) -> None:
         """Upload the state dict to Synchronizer."""

     @abstractmethod
-    def save_state_dict(self) -> None:
+    async def save_state_dict(self) -> None:
         """Only save the model state dict for Synchronizer."""
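This file defines the interface change that drives most of the diff: the engine's checkpoint methods become coroutines, so the trainer can `await` them instead of blocking. A minimal illustration of the new abstract surface (the class name here is illustrative):

```python
from abc import ABC, abstractmethod


class TrainEngineInterface(ABC):
    """Checkpoint-related methods are now coroutines (async def)."""

    @abstractmethod
    async def save_checkpoint(
        self, block_until_saved: bool = False, save_as_hf: bool = False
    ) -> None:
        """Save the checkpoint."""

    @abstractmethod
    async def upload_state_dict(self) -> None:
        """Upload the state dict to Synchronizer."""

    @abstractmethod
    async def save_state_dict(self) -> None:
        """Only save the model state dict for Synchronizer."""
```

The verl changes below follow the same motivation: inside an async Ray actor, `ray.get(ref)` blocks the event loop, while `await ref` suspends only the current task, since Ray `ObjectRef`s are awaitable.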
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 84dfa8e659..cdd74acd5f 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -435,7 +435,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         self.config.actor_rollout_ref.actor.optim.total_training_steps = self.total_training_steps
         self.config.critic.optim.total_training_steps = self.total_training_steps

-    def save_state_dict(self):  # checkpoint sync
+    async def save_state_dict(self):  # checkpoint sync
         actor_local_path = os.path.join(
             self.config.trainer.default_local_dir, f"global_step_{self.global_steps}", "actor"
         )
@@ -443,9 +443,9 @@ def save_state_dict(self):  # checkpoint sync
             actor_local_path,
             global_step=self.global_steps,
         )
-        ray.get(self.checkpoint_monitor.monitor_step.remote(self.global_steps, is_state_dict=True))
+        await self.checkpoint_monitor.monitor_step.remote(self.global_steps, is_state_dict=True)

-    def upload_state_dict(self):  # state dict sync
+    async def upload_state_dict(self):  # state dict sync
         self.actor_rollout_wg.upload_state_dict(self.global_steps)

     async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901
@@ -585,14 +585,16 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901

         return metrics

-    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None:
-        self._save_checkpoint(save_as_hf=save_as_hf)
+    async def save_checkpoint(
+        self, block_until_saved: bool = False, save_as_hf: bool = False
+    ) -> None:
+        await self._save_checkpoint(save_as_hf=save_as_hf)
         if block_until_saved:
             self.actor_rollout_wg.wait_on_save_thread()
             if self.algorithm and self.algorithm.use_critic:
                 self.critic_wg.wait_on_save_thread()

-    def _save_checkpoint(self, save_as_hf: bool = False):
+    async def _save_checkpoint(self, save_as_hf: bool = False):
         # path: given_path + `/global_step_{global_steps}` + `/actor`
         local_global_step_folder = os.path.join(
             self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
@@ -644,7 +646,7 @@ def _save_checkpoint(self, save_as_hf: bool = False):
             max_ckpt_to_keep=max_critic_ckpt_to_keep,
         )

-        ray.get(self.checkpoint_monitor.monitor_step.remote(self.global_steps))
+        await self.checkpoint_monitor.monitor_step.remote(self.global_steps)

     def _load_checkpoint(self):
         if self.config.trainer.resume_mode == "disable":
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index d421e2b279..e225a7ea2d 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -345,21 +345,8 @@ def __init__(
         self.console_logger = get_logger(__name__, in_ray_actor=True)

     def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
-        # Convert pandas DataFrame to SwanLab ECharts Table
-        headers: List[str] = list(experiences_table.columns)
-        # Ensure rows are native Python types
-        rows: List[List[object]] = experiences_table.astype(object).values.tolist()
-        try:
-            tbl = swanlab.echarts.Table()
-            tbl.add(headers, rows)
-            swanlab.log({table_name: tbl}, step=step)
-        except Exception as e:
-            self.console_logger.warning(
-                f"Failed to log table '{table_name}' as echarts, falling back to CSV. Error: {e}"
-            )
-            # Fallback: log as CSV string if echarts table is unavailable
-            csv_str = experiences_table.to_csv(index=False)
-            swanlab.log({table_name: csv_str}, step=step)
+        # Logging tables is not supported yet
+        pass

     def log(self, data: dict, step: int, commit: bool = False) -> None:
         """Log metrics."""