diff --git a/examples/grpo_lora_gsm8k/README.md b/examples/grpo_lora_gsm8k/README.md
new file mode 100644
index 0000000000..e3392cfd3e
--- /dev/null
+++ b/examples/grpo_lora_gsm8k/README.md
@@ -0,0 +1,42 @@
+# GRPO with LoRA
+
+This example shows how to use LoRA for GRPO training on the GSM8K dataset.
+
+## GRPO training
+Compared with full-model fine-tuning, Trinity-RFT enables LoRA when the `lora_configs` field is provided, as follows:
+
+```yaml
+project: "Trinity-RFT-gsm8k"
+name: "qwen2.5-1.5B-gsm8k"
+model:
+  lora_configs:
+    - name: lora
+      lora_rank: 32
+      lora_alpha: 32
+synchronizer:
+  sync_method: 'checkpoint'
+```
+
+Note that `lora_rank` and `lora_alpha` are hyperparameters that need to be tuned. For `lora_rank`, a very small value can lead to slower convergence or worse training performance, while a very large value can increase memory usage and slow down training.
+
+For now, we only support training a single LoRA adapter, and weights must be synchronized via the `checkpoint` method.
+
+## Benchmark with LoRA
+After training, we can evaluate the performance of the saved checkpoints via the `bench` mode. Some key configurations are shown below:
+
+```yaml
+mode: bench
+project: "Trinity-RFT-gsm8k" # same as training
+name: "qwen2.5-1.5B-gsm8k" # same as training
+model:
+  lora_configs: # same as training
+    - name: lora
+      lora_rank: 32
+      lora_alpha: 32
+explorer:
+  rollout_model:
+    engine_num: 2 # ensure all GPUs are used for benchmarking
+    tensor_parallel_size: 4
+synchronizer:
+  sync_method: 'checkpoint'
+```
diff --git a/examples/grpo_lora_gsm8k/gsm8k.yaml b/examples/grpo_lora_gsm8k/gsm8k.yaml
new file mode 100644
index 0000000000..26f19013fa
--- /dev/null
+++ b/examples/grpo_lora_gsm8k/gsm8k.yaml
@@ -0,0 +1,82 @@
+project: "Trinity-RFT-gsm8k"
+name: "qwen2.5-1.5B-gsm8k"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+  max_response_tokens: 1024
+  max_model_len: 1280
+  lora_configs:
+    - name: lora
+      lora_rank: 32
+      lora_alpha: 32
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_epochs: 10
+  batch_size: 96
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: 'openai/gsm8k'
+      subset_name: 'main'
+      split: 'train'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+      - name: gsm8k-eval
+        storage_type: file
+        path: 'openai/gsm8k'
+        subset_name: 'main'
+        split: 'test'
+        format:
+          prompt_key: 'question'
+          response_key: 'answer'
+    default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 10
+  runner_per_model: 16
+  rollout_model:
+    engine_num: 1
+    tensor_parallel_size: 4
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_interval: 1
+  sync_timeout: 1200
+trainer:
+  trainer_type: 'verl'
+  save_interval: 100
+  trainer_config:
+    actor_rollout_ref:
+      model:
+        use_remove_padding: true
+      actor:
+        use_dynamic_bsz: true
+        ppo_max_token_len_per_gpu: 16384
+        ulysses_sequence_parallel_size: 1
+        optim:
+          lr: 1e-5
+        checkpoint:
+          load_contents:
+            - model
+          save_contents:
+            - model
+      ref:
+        log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
+        log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+        ulysses_sequence_parallel_size:
${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size diff --git a/tests/tools.py b/tests/tools.py index bfa045c469..7dc0b0e75f 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -6,7 +6,13 @@ import ray from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from trinity.common.config import Config, FormatConfig, StorageConfig, load_config +from trinity.common.config import ( + Config, + FormatConfig, + LoRAConfig, + StorageConfig, + load_config, +) from trinity.common.constants import ( CHECKPOINT_ROOT_DIR_ENV_VAR, MODEL_PATH_ENV_VAR, @@ -64,11 +70,15 @@ def get_vision_languge_model_path() -> str: return path +def get_lora_config() -> LoRAConfig: + return LoRAConfig(name="lora", lora_rank=16, lora_alpha=16) + + def get_unittest_dataset_config( dataset_name: str = "countdown", split: str = "train" ) -> StorageConfig: - """Countdown dataset with 17 samples.""" if dataset_name == "countdown" or dataset_name == "copy_countdown": + # Countdown dataset with 17 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"), @@ -82,6 +92,7 @@ def get_unittest_dataset_config( default_reward_fn_type="countdown_reward", ) elif dataset_name in {"eval_short", "eval_long"}: + # Eval_short dataset with 2 samples, eval_long dataset with 8 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", dataset_name), @@ -94,6 +105,7 @@ def get_unittest_dataset_config( default_reward_fn_type="math_reward", ) elif dataset_name == "gsm8k": + # GSM8K dataset with 16 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"), @@ -106,6 +118,7 @@ def get_unittest_dataset_config( default_reward_fn_type="math_reward", ) elif dataset_name == "sft_for_gsm8k": + # SFT dataset with 8 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"), @@ -118,6 +131,7 @@ def get_unittest_dataset_config( ), ) elif dataset_name == "sft_with_tools": + # SFT_with_tools dataset with 4 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"), @@ -130,6 +144,7 @@ def get_unittest_dataset_config( ), ) elif dataset_name == "dpo": + # HumanLike DPO dataset with 17 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "human_like"), @@ -142,6 +157,7 @@ def get_unittest_dataset_config( ), ) elif dataset_name == "geometry": + # Multi-modal geometry dataset with 8 samples return StorageConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "geometry"), diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 3dc67b692f..84c3f36c05 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -16,6 +16,7 @@ RayUnittestBase, TensorBoardParser, get_checkpoint_path, + get_lora_config, get_model_path, get_template_config, get_unittest_dataset_config, @@ -724,3 +725,67 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTrainerLoRA(BaseTrainerCase): + def test_trainer(self): + """Test both mode with LoRA request.""" + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + 
self.config.buffer.explorer_input.eval_tasksets.append( + get_unittest_dataset_config("gsm8k", "test") + ) + self.config.model.model_path = get_model_path() + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.kl_loss_fn = "none" + self.config.algorithm.repeat_times = 4 + self.config.buffer.batch_size = 4 + self.config.buffer.total_steps = 2 + self.config.cluster.node_num = 1 + self.config.cluster.gpu_per_node = 4 + self.config.explorer.eval_interval = 2 + self.config.model.lora_configs = [get_lora_config()] + self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT + self.config.synchronizer.sync_interval = 2 + self.config.trainer.save_interval = 2 + self.config.check_and_update() + both(self.config) + # check metrics are available + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 2) + ray.shutdown(_exiting_interpreter=True) + # check save lastest checkpoint + checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=self.config.checkpoint_job_dir, + trainer_type=self.config.trainer.trainer_type, + ) + self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0) + self.assertTrue( + len(os.listdir(os.path.join(checkpoint_step_2, "actor", "lora_adapter"))) > 0 + ) + self.assertEqual(step_num, 2) + + # test bench mode + ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace) + self.config.mode = "bench" + self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT + self.config.explorer.bench_on_latest_checkpoint = False + self.config.check_and_update() + bench(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + for prefix in ["eval", "bench"]: + gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k") + self.assertTrue(len(gsm8k_metrics) > 0) + gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0]) + self.assertEqual([0, 2], gsm8k_metric_steps) + + def tearDown(self): + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/common/config.py b/trinity/common/config.py index 258bd78b06..6946416a27 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -22,6 +22,7 @@ ) from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger +from trinity.utils.lora_utils import create_dummy_lora logger = get_logger(__name__) @@ -81,6 +82,19 @@ class GenerationConfig: n: int = 1 +@dataclass +class LoRAConfig: + """LoRA config, only effective for rollout model, not for auxiliary models.""" + + name: Optional[str] = None + path: Optional[str] = None + base_model_name: Optional[str] = None + lora_rank: int = 32 + lora_alpha: int = 32 + lora_dtype: str = "auto" + target_modules: str = "all-linear" + + @dataclass class StorageConfig: """Storage config.""" @@ -244,6 +258,11 @@ class ModelConfig: # the minimum number of tokens for the response min_response_tokens: int = 1 + # lora config + lora_configs: Optional[List[LoRAConfig]] = None + fully_sharded_loras: 
bool = False + max_cpu_loras: Optional[int] = None + @dataclass class InferenceModelConfig: @@ -294,6 +313,11 @@ class InferenceModelConfig: # ! DO NOT SET bundle_indices: str = "" + # ! DO NOT SET, automatically set from model.lora_configs + enable_lora: bool = False + lora_modules: Optional[List[Dict]] = None + lora_kwargs: Optional[dict] = field(default_factory=dict) + @dataclass class AlgorithmConfig: @@ -903,11 +927,17 @@ def check_and_update(self) -> Config: # noqa: C901 if not self.continue_from_checkpoint and ( os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir) ): - ori_name = self.name - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - self.name = f"{ori_name}_{timestamp}" - self.checkpoint_job_dir = f"{self.checkpoint_job_dir}_{timestamp}" - logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.") + if self.mode == "bench": + logger.warning( + "For bench mode, `continue_from_checkpoint` is set as `true` to enable using existing checkpoints." + ) + self.continue_from_checkpoint = True + else: + ori_name = self.name + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + self.name = f"{ori_name}_{timestamp}" + self.checkpoint_job_dir = f"{self.checkpoint_job_dir}_{timestamp}" + logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.") os.makedirs(self.checkpoint_job_dir, exist_ok=True) # check model @@ -928,6 +958,45 @@ def check_and_update(self) -> Config: # noqa: C901 set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens) set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens) + # for lora configs + if self.model.lora_configs is not None: + self.explorer.rollout_model.enable_lora = True + if len(self.model.lora_configs) > 1: + raise ValueError("Only one lora adapter is supported for now.") + if self.model.lora_configs[0].path is None: + logger.info("Creating dummy lora, since no lora_path is provided.") + lora_path = create_dummy_lora( + model_path=self.model.model_path, + checkpoint_job_dir=self.checkpoint_job_dir, + lora_rank=self.model.lora_configs[0].lora_rank, + lora_alpha=self.model.lora_configs[0].lora_alpha, + target_modules=self.model.lora_configs[0].target_modules, + ) + self.model.lora_configs[0].path = lora_path + self.explorer.rollout_model.lora_modules = [ + { + "lora_int_id": i + 1, + "lora_name": cfg.name, + "lora_path": cfg.path, + "base_model_name": cfg.base_model_name, + } + for i, cfg in enumerate(self.model.lora_configs) + ] + self.explorer.rollout_model.lora_kwargs = { + "max_loras": len(self.model.lora_configs), + "max_lora_rank": max( + ( + model_config.lora_rank + for model_config in self.model.lora_configs + if model_config.lora_rank > 0 + ), + default=0, + ), + "default_lora_path": os.path.join( + self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter" + ), # will be poped later + } + # check synchronizer self.synchronizer.ray_namespace = self.ray_namespace self.synchronizer.explorer_world_size = ( @@ -965,7 +1034,7 @@ def check_and_update(self) -> Config: # noqa: C901 # check buffer self._check_buffer() # check and update trainer - if self.mode in ["train", "both"]: + if self.mode in ["train", "both", "bench"]: if self.trainer.trainer_type == "verl": if self.trainer.trainer_config: from trinity.common.verl_config import veRLConfig diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 80f978f227..7bbdd32272 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py 
@@ -10,6 +10,7 @@ import ray import torch from torch import Tensor +from vllm.lora.request import LoRARequest from trinity.common.constants import RunningStatus from trinity.common.experience import Experience @@ -77,13 +78,20 @@ def sync_wrapper(self, *args, **kwargs): class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" - def __init__(self, model: Any, engine_type: str = "vllm", enable_history: bool = False): + def __init__( + self, + model: Any, + engine_type: str = "vllm", + enable_lora: bool = False, + enable_history: bool = False, + ): assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model self.api_address: str = None self.openai_client: openai.OpenAI = None self.openai_async_client: openai.AsyncOpenAI = None self.logger = get_logger(__name__) + self.enable_lora = enable_lora self.enable_history = enable_history self.history = [] self.status = RunningStatus.RUNNING @@ -124,14 +132,18 @@ def _record_history(self, exps: Union[Experience, List[Experience]]) -> None: @_history_recorder def generate(self, prompts: List[str], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of prompts.""" - results = ray.get([self.model.generate.remote(prompt, **kwargs) for prompt in prompts]) + lora_request = self.get_lora_request() + results = ray.get( + [self.model.generate.remote(prompt, lora_request, **kwargs) for prompt in prompts] + ) return [exp for exps in results for exp in exps] @_history_recorder async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of prompts in async.""" + lora_request = await self.get_lora_request_async() results = await asyncio.gather( - *[self.model.generate.remote(prompt, **kwargs) for prompt in prompts] + *[self.model.generate.remote(prompt, lora_request, **kwargs) for prompt in prompts] ) return [exp for exps in results for exp in exps] @@ -163,12 +175,14 @@ async def generate_mm_async( @_history_recorder def chat(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages.""" - return ray.get(self.model.chat.remote(messages, **kwargs)) + lora_request = self.get_lora_request() + return ray.get(self.model.chat.remote(messages, lora_request, **kwargs)) @_history_recorder async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages in async.""" - return await self.model.chat.remote(messages, **kwargs) + lora_request = await self.get_lora_request_async() + return await self.model.chat.remote(messages, lora_request, **kwargs) @_history_recorder def chat_mm(self, messages: List[dict], raw_mm_data: dict, **kwargs) -> List[Experience]: @@ -206,6 +220,18 @@ async def model_version_async(self) -> int: """Get the version of the model.""" return await self.model.get_model_version.remote() + def get_lora_request(self) -> Optional[LoRARequest]: + if self.enable_lora: + return ray.get(self.model.get_lora_request.remote()) + else: + return None + + async def get_lora_request_async(self) -> Optional[LoRARequest]: + if self.enable_lora: + return await self.model.get_lora_request.remote() + else: + return None + def get_openai_client(self) -> openai.OpenAI: """Get the openai client. 
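The `ModelWrapper` changes above resolve an optional LoRA request once per call and forward it to the engine, so workflows keep calling `generate`/`chat` unchanged. The following is a minimal, illustrative sketch of that injection pattern using stand-in class names (no Ray, no vLLM); it is not the actual Trinity-RFT code.

```python
# Sketch only: mirrors how ModelWrapper threads an optional LoRA request into
# every engine call when `enable_lora` is set. EngineStub stands in for the
# vLLM engine actor, which the real wrapper calls via Ray.
from typing import List, Optional


class EngineStub:
    """Stand-in for the vLLM engine actor."""

    def get_lora_request(self) -> dict:
        # The real engine builds a vllm LoRARequest from config.lora_modules[0].
        return {"lora_name": "lora", "lora_int_id": 1, "lora_path": "/path/to/adapter"}

    def generate(self, prompt: str, lora_request: Optional[dict] = None, **kwargs) -> str:
        suffix = " [LoRA]" if lora_request else ""
        return f"response to {prompt!r}{suffix}"


class WrapperSketch:
    def __init__(self, model: EngineStub, enable_lora: bool = False):
        self.model = model
        self.enable_lora = enable_lora

    def get_lora_request(self) -> Optional[dict]:
        # Mirrors ModelWrapper.get_lora_request: fetch a request only when LoRA is enabled.
        return self.model.get_lora_request() if self.enable_lora else None

    def generate(self, prompts: List[str], **kwargs) -> List[str]:
        lora_request = self.get_lora_request()
        return [self.model.generate(p, lora_request=lora_request, **kwargs) for p in prompts]


print(WrapperSketch(EngineStub(), enable_lora=True).generate(["1 + 1 = ?"]))
```

Because the LoRA request is resolved inside the wrapper, enabling or disabling LoRA requires no changes to workflow code.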
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index d592c98654..1cb2d19f48 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -9,6 +9,7 @@ import vllm from packaging.version import parse as parse_version from transformers import AutoProcessor +from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig @@ -63,6 +64,8 @@ def __init__( self.enable_thinking = config.enable_thinking self.request_id = 0 max_model_len = config.max_model_len + self.enable_lora = config.enable_lora + self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None) engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, @@ -78,6 +81,8 @@ def __init__( gpu_memory_utilization=config.gpu_memory_utilization, enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage + enable_lora=config.enable_lora, + **config.lora_kwargs, ) if get_vllm_version() > parse_version("0.10.0"): engine_args.enable_log_requests = False @@ -98,8 +103,10 @@ def __init__( async def _initialize_tokenizer(self): if self.tokenizer is None: - if self.processor and hasattr(self.processor, "tokenizer"): - self.tokenizer = self.processor.tokenizer + if self.enable_lora: + self.tokenizer = await self.async_llm.get_tokenizer( + lora_request=self.get_lora_request() + ) else: self.tokenizer = await self.async_llm.get_tokenizer() self.tokenizer.truncation_side = "left" @@ -110,7 +117,9 @@ def _initialize_processor(self): ) self.tokenizer = self.processor.tokenizer - async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]: + async def chat( + self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs + ) -> Sequence[Experience]: """Chat with the model with a list of messages in async. Args: @@ -139,9 +148,11 @@ async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]: chat_template=self.chat_template, enable_thinking=self.enable_thinking, ) - return await self.generate(prompt=prompt, **kwargs) + return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs) - async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: + async def generate( + self, prompt: str, lora_request: LoRARequest = None, **kwargs + ) -> Sequence[Experience]: """Generate a response from the provided prompt in async. Args: @@ -156,7 +167,9 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: token_ids = self.tokenizer( # type: ignore prompt, truncation=True, max_length=self.config.max_prompt_tokens, return_tensors="pt" )["input_ids"][0].tolist() - output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs) + output = await self._generate_internal( + prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs + ) experiences = [ Experience( tokens=torch.cat( @@ -287,7 +300,9 @@ async def generate_mm( ] return experiences - async def logprobs(self, token_ids: List[int]) -> torch.Tensor: + async def logprobs( + self, token_ids: List[int], lora_request: LoRARequest = None + ) -> torch.Tensor: """Calculate the logprobs of the given tokens in async. Please slice the result carefully to align with the actual response length. 
@@ -300,6 +315,7 @@ async def logprobs(self, token_ids: List[int]) -> torch.Tensor: """ output = await self._generate_internal( prompt={"prompt_token_ids": token_ids}, + lora_request=lora_request, n=1, max_tokens=1, prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token @@ -309,13 +325,16 @@ async def logprobs(self, token_ids: List[int]) -> torch.Tensor: dtype=torch.float32, ) - async def _generate_internal(self, prompt: Any, **kwargs) -> Any: + async def _generate_internal( + self, prompt: Any, lora_request: LoRARequest = None, **kwargs + ) -> Any: # Send the request to the LLM engine. self.request_id += 1 stream = self.async_llm.generate( request_id=str(self.request_id), prompt=prompt, sampling_params=self._create_sampling_params(**kwargs), + lora_request=lora_request, ) # Consume the stream until the request is finished. @@ -334,7 +353,7 @@ async def convert_messages_to_experience( ) -> Experience: """Convert a list of messages into an experience.""" if self.tokenizer is None: - self.tokenizer = await self.async_llm.get_tokenizer() + await self._initialize_tokenizer() if self.chat_template is None: self.chat_template = self.tokenizer.get_chat_template() token_ids, action_mask, prompt_length = self.action_mask_method( @@ -394,6 +413,20 @@ async def _collective_rpc( async def sync_model(self, model_version: int) -> int: """Sync model weights to vLLM.""" + if self.enable_lora: + # Revise the lora path; no need to sync weights manually. + self.default_lora_path = self.default_lora_path.replace( + f"global_step_{self.model_version}", f"global_step_{model_version}" + ) + self.logger.info( + f"Redirect `lora_path` from old_model_version={self.model_version} to {model_version=} successfully." + ) + lora_int_ids = await self.async_llm.list_loras() + for lora_id in lora_int_ids: + await self.async_llm.remove_lora(lora_id) + await self.async_llm.add_lora(self.get_lora_request(self.default_lora_path)) + self.model_version = model_version + return model_version await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version @@ -465,6 +498,14 @@ async def reset_prefix_cache(self) -> None: def get_model_version(self) -> int: return self.model_version + def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest: + assert self.config.lora_modules is not None + lora_request = LoRARequest(**self.config.lora_modules[0]) + if lora_path is not None: + self.config.lora_modules[0]["lora_path"] = lora_path # for consistency + lora_request.lora_path = lora_path + return lora_request + async def sleep(self, level: int = 1) -> None: await self.async_llm.sleep(level=level) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 639ec4de52..e05773e63b 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -33,6 +33,12 @@ class ActorModel: fused_kernel_options: FusedKernelOptions = field(default_factory=FusedKernelOptions) custom_chat_template: Optional[str] = None enable_activation_offload: bool = False + use_shm: bool = False + + # lora configs + lora_rank: int = 0 # The rank of the LoRA model, default to 0. 
If lora_rank > 0, LoRA module is enabled in trainer + lora_alpha: int = 32 + target_modules: Optional[str] = "all-linear" @dataclass @@ -187,8 +193,10 @@ class Rollout: multi_turn: _MultiTurn = field(default_factory=_MultiTurn) temperature: float = 1.0 n: int = 1 # > 1 for grpo + log_prob_use_dynamic_bsz: bool = True log_prob_micro_batch_size: Optional[int] = None - log_prob_micro_batch_size_per_gpu: int = 1 + log_prob_micro_batch_size_per_gpu: Optional[int] = None + log_prob_max_token_len_per_gpu: Optional[int] = None @dataclass @@ -435,6 +443,24 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 if config.trainer.actor_grad_clip is not None: self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip + # LoRA related config + if config.model.lora_configs is not None: + self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank + self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha + self.actor_rollout_ref.model.target_modules = config.model.lora_configs[ + 0 + ].target_modules + if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]: + logger.warning( + f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp." + ) + self.actor_rollout_ref.actor.strategy = "fsdp" + if self.critic.strategy not in ["fsdp", "fsdp2"]: + logger.warning( + f"Lora is only supported for fsdp and fsdp2, but got {self.critic.strategy} instead, changed to fsdp." + ) + self.critic.strategy = "fsdp" + # Algorithm related config self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none" self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none" @@ -449,6 +475,25 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 if self.actor_rollout_ref.rollout.n != 2: self.actor_rollout_ref.rollout.n = 2 + + # check rollout config (only works for lora) + self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times + self.actor_rollout_ref.rollout.log_prob_use_dynamic_bsz = ( + self.actor_rollout_ref.actor.use_dynamic_bsz + ) + if self.actor_rollout_ref.rollout.log_prob_micro_batch_size is None: + self.actor_rollout_ref.rollout.log_prob_micro_batch_size = ( + self.actor_rollout_ref.actor.ppo_micro_batch_size + ) + if self.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu is None: + self.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu = ( + self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu + ) + if self.actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu is None: + self.actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu = ( + self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu + ) + # TODO: check other fields self.enable_preview = config.trainer.enable_preview diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index e90a82cf8c..a10b523af2 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -73,6 +73,7 @@ def __init__(self, config: Config): # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models + self.enable_lora = self.config.explorer.rollout_model.enable_lora self.model_version = -1 self.last_sync_successful = True self.logger.info("Finished initializing Explorer.") diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 2d1ddddc31..d841e0e625 100644 --- 
a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -39,6 +39,7 @@ def __init__( self.model_wrapper = ModelWrapper( model, config.explorer.rollout_model.engine_type, + enable_lora=config.explorer.rollout_model.enable_lora, enable_history=config.explorer.rollout_model.enable_history, ) self.auxiliary_models = [ diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index c0992f56a5..0e0775693d 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -33,6 +33,7 @@ class Synchronizer: def __init__(self, config: Config, module_ref: ray.actor.ActorHandle): self.logger = get_logger("synchronizer", in_ray_actor=True) self.config = config + self.enable_lora = config.explorer.rollout_model.enable_lora self.trainer_status = RunningStatus.STOPPED self.explorer_status_counts: Dict[RunningStatus, int] = defaultdict(lambda: 0) self._ready_condition = asyncio.Condition() diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 4f6b33c3de..aeaf8ac3a0 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -179,7 +179,6 @@ def init_workers(self): role="ref", ) self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - # create a reward model if reward_fn is None if self.use_rm: # we create a RM here @@ -289,7 +288,10 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob with marked_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) if self.algorithm.use_critic: diff --git a/trinity/utils/lora_utils.py b/trinity/utils/lora_utils.py new file mode 100644 index 0000000000..3e881016aa --- /dev/null +++ b/trinity/utils/lora_utils.py @@ -0,0 +1,26 @@ +import torch +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoConfig, AutoModelForCausalLM + + +def create_dummy_lora( + model_path: str, + checkpoint_job_dir: str, + lora_rank: int, + lora_alpha: int, + target_modules: str, +) -> str: + config = AutoConfig.from_pretrained(model_path) + model = AutoModelForCausalLM.from_config(config) + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": lora_rank, + "lora_alpha": lora_alpha, + "target_modules": target_modules, + "bias": "none", + } + peft_model = get_peft_model(model, LoraConfig(**lora_config)) + peft_model.save_pretrained(f"{checkpoint_job_dir}/dummy_lora") + del model, peft_model + torch.cuda.empty_cache() + return f"{checkpoint_job_dir}/dummy_lora"
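For reference, a usage sketch of the new `create_dummy_lora` helper is shown below. It assumes `peft` and `transformers` are installed and that the model path points at a causal LM; the model name and checkpoint directory are illustrative. In practice, `Config.check_and_update` calls this helper automatically when `model.lora_configs[0].path` is unset, and points the rollout engine's `default_lora_path` at `global_step_0/actor/lora_adapter` so vLLM has a valid adapter to load before the first checkpoint sync.

```python
# Illustrative usage of trinity.utils.lora_utils.create_dummy_lora;
# the model path and checkpoint directory below are example values.
from trinity.utils.lora_utils import create_dummy_lora

adapter_path = create_dummy_lora(
    model_path="Qwen/Qwen2.5-1.5B-Instruct",
    checkpoint_job_dir="./checkpoints/Trinity-RFT-gsm8k/qwen2.5-1.5B-gsm8k",
    lora_rank=32,
    lora_alpha=32,
    target_modules="all-linear",
)
print(adapter_path)  # -> <checkpoint_job_dir>/dummy_lora
```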