diff --git a/.gitignore b/.gitignore index 4fbd2436..5fa6202f 100644 --- a/.gitignore +++ b/.gitignore @@ -113,4 +113,7 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt *.swp # ckpt -*.lock \ No newline at end of file +*.lock + +# data +*.parquet diff --git a/examples/data_preprocess/geo3k.py b/examples/data_preprocess/geo3k.py new file mode 100644 index 00000000..cf4a26fc --- /dev/null +++ b/examples/data_preprocess/geo3k.py @@ -0,0 +1,83 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the Geometry3k dataset to parquet format +""" + +import os +import datasets + +from verl.utils.hdfs_io import copy, makedirs +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--local_dir', default='~/data/geo3k') + parser.add_argument('--hdfs_dir', default=None) + + args = parser.parse_args() + + data_source = 'hiyouga/geometry3k' + + dataset = datasets.load_dataset(data_source) + + train_dataset = dataset['train'] + test_dataset = dataset['test'] + + instruction_following = r"Please reason step by step, and put your final answer within \boxed{}." + + # add a row to each data item that represents a unique id + def make_map_fn(split): + + def process_fn(example, idx): + problem = example.pop('problem') + prompt = problem + ' ' + instruction_following + answer = example.pop('answer') + images = example.pop('images') + + data = { + "data_source": data_source, + "prompt": [{ + "role": "user", + "content": prompt, + }], + "images": images, + "ability": "math", + "reward_model": { + "style": "rule", + "ground_truth": answer + }, + "extra_info": { + 'split': split, + 'index': idx, + 'answer': answer, + "question": problem, + } + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) + test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh new file mode 100644 index 00000000..add2c9bc --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh @@ -0,0 +1,43 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1536 \ + data.max_response_length=1536 \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + 
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 1595a9be..c9a669ce 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -17,7 +17,7 @@ import os import torch import argparse -from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq from concurrent.futures import ThreadPoolExecutor from torch.distributed._tensor import DTensor, Shard, Placement @@ -140,6 +140,8 @@ def process_one_shard(rank): auto_model = AutoModelForTokenClassification elif 'ForCausalLM' in config.architectures[0]: auto_model = AutoModelForCausalLM + elif 'ForConditionalGeneration' in config.architectures[0]: + auto_model = AutoModelForVision2Seq else: raise NotImplementedError(f'Unknown architecture {config["architectures"]}') diff --git a/verl/models/registry.py b/verl/models/registry.py index a363cc8b..a735d396 100644 --- a/verl/models/registry.py +++ b/verl/models/registry.py @@ -19,18 +19,25 @@ # Supported models using HF Rmpad # TODO(sgm): HF may supported more than listed here, we should add more after testing -from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config - -_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config} +_MODELS_SUPPORT_RMPAD = {'llama', 'mistral', 'gemma', 'qwen2', 'qwen2_vl', 'qwen2_5_vl'} def check_model_support_rmpad(model_type: str): assert isinstance(model_type, str) - if not model_type in _REOVEPAD_MODELS.keys(): + if not model_type in _MODELS_SUPPORT_RMPAD: raise ValueError(f"Model architecture {model_type} is not supported for now. " - f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}." + f"RMPad supported architectures: {_MODELS_SUPPORT_RMPAD}." 
f"Please set `use_remove_padding=False` in the model config.") + if model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 + + Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward + Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward + print("Qwen2vl patch applied!") + # Supported models in Megatron-LM # Architecture -> (module, class). diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py new file mode 100644 index 00000000..b72fd3b9 --- /dev/null +++ b/verl/models/transformers/qwen2_vl.py @@ -0,0 +1,273 @@ +from typing import Optional, Tuple +import inspect +import torch +import os +from transformers.utils import is_flash_attn_greater_or_equal +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + flash_attn_varlen_func = None + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. 
+ https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 + """ + spatial_merge_size = processor.image_processor.merge_size + tokens_per_second = 2 + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + image_index, video_index = 0, 0 + input_ids = input_ids[attention_mask == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) + else: + position_ids = 
torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + position_ids: torch.Tensor): + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seqlens = torch.cat(( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + )) + max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope + return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) + + +def flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool = True, + position_ids: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + deterministic: Optional[bool] = None, + **kwargs, +): + """ + Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) + """ + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = (_flash_supports_window_size and sliding_window is not None and + key_states.shape[1] > sliding_window) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids[0]) # remove channel dimension + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=kwargs.pop("dropout", 0.0), + softmax_scale=kwargs.pop("softmax_scale", None), + causal=causal, + **flash_kwargs, + ) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=is_causal, + sliding_window=sliding_window, + use_top_left_mask=use_top_left_mask, + deterministic=deterministic, + **kwargs, + ) # do not pass position_ids to old flash_attention_forward + + return attn_output + + +def ulysses_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, None, 
None]: + from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv, apply_multimodal_rotary_pos_emb + + bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size + query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = query_states.size(2) # full_q_len = seq_length + else: + full_q_len = q_len + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, + self.rope_scaling["mrope_section"]) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and + self.layer_idx >= self.config.max_window_layers): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + position_ids=position_ids, # important: pass position ids + ) # (batch_size, seq_length, num_head / sp_size, head_size) + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, None diff --git a/verl/protocol.py b/verl/protocol.py index 737ec140..7f216ad0 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -24,6 +24,7 @@ from typing import Callable, Dict, List, Union import torch +import torch.distributed import tensordict from tensordict import TensorDict from torch.utils.data import DataLoader, Dataset @@ -597,6 +598,33 @@ def repeat(self, repeat_times=2, interleave=True): meta_info=self.meta_info, ) + def broadcast(self, src, group=None): + for key in self.batch.sorted_keys: + torch.distributed.broadcast(self.batch[key], src=src, group=group, async_op=False) + + object_list = [self.non_tensor_batch] + torch.distributed.broadcast_object_list(object_list, src=src, group=group) + self.non_tensor_batch = object_list[0] + + def 
all_gather(self, group=None): + world_size = torch.distributed.get_world_size(group) + output = {} + for key in self.batch.sorted_keys: + value = self.batch[key].contiguous() + output[key] = [torch.empty_like(value) for _ in range(world_size)] + torch.distributed.all_gather(output[key], value, group=group, async_op=False) + output[key] = torch.cat(output[key], dim=0) + + self.batch = TensorDict(output, batch_size=self.batch.batch_size[0] * world_size) + + # all gather non_tensor_batch + all_non_tensor_batch = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_non_tensor_batch, self.non_tensor_batch, group=group) + self.non_tensor_batch = { + key: np.concatenate([batch[key] for batch in all_non_tensor_batch]) for key in self.non_tensor_batch + } + self.check_consistency() + import ray diff --git a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py index a3042cab..c3fd92c8 100644 --- a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py @@ -203,7 +203,11 @@ def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - name = name.replace(weight_name, param_name) + + if "visual" in name: + continue + + name = "language_model." + name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -216,6 +220,11 @@ def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if "visual" in name: + name = name + else: + name = "language_model." 
+ name + param = params_dict[name] local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -355,6 +364,7 @@ def _process_parameter_names(name): "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, + "Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, } diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index bf919089..d6f2e96c 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -10,6 +10,7 @@ data: return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False shuffle: True + image_key: images actor_rollout_ref: hybrid_engine: True diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 9fac8fec..41364d7d 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -46,8 +46,9 @@ def main_task(config, compute_score=None): local_path = copy_to_local(config.actor_rollout_ref.model.path) # instantiate tokenizer - from verl.utils import hf_tokenizer + from verl.utils import hf_tokenizer, hf_processor tokenizer = hf_tokenizer(local_path) + processor = hf_processor(local_path, use_fast=True) # define worker classes if config.actor_rollout_ref.actor.strategy == 'fsdp': @@ -117,6 +118,7 @@ def main_task(config, compute_score=None): trainer = RayPPOTrainer(config=config, tokenizer=tokenizer, + processor=processor, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index fa61f479..09bae2a5 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -350,6 +350,7 @@ class RayPPOTrainer(object): def __init__(self, config, tokenizer, + processor, role_worker_mapping: dict[Role, WorkerType], resource_pool_manager: ResourcePoolManager, ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, @@ -359,6 +360,7 @@ def __init__(self, # assert torch.cuda.is_available(), 'cuda must be available on driver' self.tokenizer = tokenizer + self.processor = processor self.config = config self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn @@ -491,7 +493,9 @@ def _create_dataloader(self): # TODO: we have to make sure the batch size is divisible by the dp size self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files, tokenizer=self.tokenizer, + processor=self.processor, prompt_key=self.config.data.prompt_key, + image_key=self.config.data.image_key, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), @@ -506,13 +510,16 @@ def _create_dataloader(self): self.train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=self.config.data.train_batch_size, + num_workers=8, drop_last=True, collate_fn=collate_fn, sampler=sampler) self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files, tokenizer=self.tokenizer, + processor=self.processor, prompt_key=self.config.data.prompt_key, + image_key=self.config.data.image_key, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), @@ -522,6 +529,7 @@ def 
_create_dataloader(self): # Validation datasets are sent to inference engines as a whole batch, # which will schedule the memory themselves. batch_size=len(self.val_dataset), + num_workers=8, shuffle=True, drop_last=False, collate_fn=collate_fn) @@ -617,7 +625,17 @@ def _validate(self): input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) - test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids']) + if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys(): + test_gen_batch = test_batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + ) + else: + test_gen_batch = test_batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids'], + ) + test_gen_batch.meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, @@ -876,7 +894,16 @@ def fit(self): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + if 'multi_modal_inputs' in batch.non_tensor_batch.keys(): + gen_batch = batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + ) + else: + gen_batch = batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids'], + ) with _timer('step', timing_raw): # generate a batch diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py index e453070a..bc781029 100644 --- a/verl/utils/__init__.py +++ b/verl/utils/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. from . 
import tokenizer -from .tokenizer import * +from .tokenizer import hf_tokenizer, hf_processor __all__ = tokenizer.__all__ \ No newline at end of file diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index 11a8f55f..030c2dc6 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -15,11 +15,11 @@ import shutil from filelock import FileLock import tempfile - +from typing import Union import torch import torch.distributed from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, ProcessorMixin import numpy as np import random @@ -40,14 +40,15 @@ class BaseCheckpointManager: """ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer): + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, processing_class: Union[PreTrainedTokenizer, + ProcessorMixin]): self.previous_global_step = None self.previous_save_local_path = None self.model = model self.optimizer = optimizer self.lr_scheduler = lr_scheduler - self.tokenizer = tokenizer + self.processing_class = processing_class assert isinstance(self.model, FSDP) self.rank = torch.distributed.get_rank() diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 4d269e0c..ad9b0c85 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -16,7 +16,7 @@ import os import warnings - +from typing import Union import torch import torch.distributed from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType @@ -24,7 +24,7 @@ from verl.utils.fs import copy_to_local, is_non_local -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, ProcessorMixin from .checkpoint_manager import BaseCheckpointManager @@ -41,12 +41,13 @@ class FSDPCheckpointManager(BaseCheckpointManager): We save - sharded model states and optimizer states - full lr_scheduler states - - huggingface tokenizer and config for ckpt merge + - huggingface tokenizer/processor and config for ckpt merge """ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs): - super().__init__(model, optimizer, lr_scheduler, tokenizer) + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + processing_class: Union[PreTrainedTokenizer, ProcessorMixin], *args, **kwargs): + super().__init__(model, optimizer, lr_scheduler, processing_class) def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs): if path is None: @@ -142,7 +143,7 @@ def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckp hf_local_path = os.path.join(local_path, 'huggingface') os.makedirs(hf_local_path, exist_ok=True) self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path) - self.tokenizer.save_pretrained(hf_local_path) + self.processing_class.save_pretrained(hf_local_path) torch.distributed.barrier() diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 6e0a5c90..17a97835 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -14,32 +14,29 @@ from omegaconf import ListConfig import os -from typing import List, Union 
+from typing import List, Union, Optional
 import copy
 import pandas as pd
+from collections import defaultdict

 import torch
 import numpy as np
 from torch.utils.data import Dataset
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, ProcessorMixin

 from verl.utils.model import compute_position_id_with_mask
 import verl.utils.torch_functional as verl_F


 def collate_fn(data_list: list[dict]) -> dict:
-    tensors = {}
-    non_tensors = {}
+    tensors = defaultdict(list)
+    non_tensors = defaultdict(list)

     for data in data_list:
         for key, val in data.items():
             if isinstance(val, torch.Tensor):
-                if key not in tensors:
-                    tensors[key] = []
                 tensors[key].append(val)
             else:
-                if key not in non_tensors:
-                    non_tensors[key] = []
                 non_tensors[key].append(val)

     for key, val in tensors.items():
@@ -48,10 +45,31 @@ def collate_fn(data_list: list[dict]) -> dict:
     for key, val in non_tensors.items():
         non_tensors[key] = np.array(val, dtype=object)

-    output = {}
-    output.update(tensors)
-    output.update(non_tensors)
-    return output
+    return {**tensors, **non_tensors}
+
+
+def process_image(image: dict, max_pixels: int = 2048 * 2048, min_pixels: int = 512 * 512):
+    import math
+    from io import BytesIO
+    from PIL import Image
+
+    if isinstance(image, dict):
+        image = Image.open(BytesIO(image['bytes']))
+
+    if (image.width * image.height) > max_pixels:
+        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
+        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        image = image.resize((width, height), resample=Image.Resampling.NEAREST)
+
+    if (image.width * image.height) < min_pixels:
+        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
+        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+        image = image.resize((width, height), resample=Image.Resampling.NEAREST)
+
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+
+    return image


 class RLHFDataset(Dataset):
@@ -62,7 +80,9 @@ class RLHFDataset(Dataset):
     def __init__(self,
                  parquet_files: Union[str, List[str]],
                  tokenizer: PreTrainedTokenizer,
+                 processor: Optional[ProcessorMixin] = None,
                  prompt_key='prompt',
+                 image_key='images',
                  max_prompt_length=1024,
                  filter_prompts=True,
                  cache_dir='~/.cache/verl/rlhf',
@@ -76,8 +96,10 @@ def __init__(self,
         self.original_parquet_files = copy.deepcopy(parquet_files)  # use for resume
         self.cache_dir = os.path.expanduser(cache_dir)
         self.tokenizer = tokenizer
+        self.processor = processor

         self.prompt_key = prompt_key
+        self.image_key = image_key
         self.max_prompt_length = max_prompt_length
         self.filter_prompts = filter_prompts

@@ -132,12 +154,36 @@ def __getitem__(self, item):
         """
         Note that we also return the raw_input_ids so that it can be combined with other chat template
         """
-        row_dict = self.dataframe.iloc[item].to_dict()
+        row_dict: dict = self.dataframe.iloc[item].to_dict()

         chat = row_dict.pop(self.prompt_key)

         prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)

+        if self.image_key in row_dict:  # expand image token
+            raw_prompt = prompt_with_chat_template.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
+            row_dict['multi_modal_data'] = {'image': [process_image(image) for image in row_dict.pop(self.image_key)]}
+            image_inputs = self.processor.image_processor(row_dict['multi_modal_data']['image'], return_tensors='pt')
+            image_grid_thw = image_inputs['image_grid_thw']
+            row_dict['multi_modal_inputs'] = {key: val for key, val in image_inputs.items()}
+
+            if image_grid_thw is not None:
+                merge_length = self.processor.image_processor.merge_size**2
+                index = 0
+                while '<image>' in prompt_with_chat_template:
+                    prompt_with_chat_template = prompt_with_chat_template.replace(
+                        '<image>',
+                        '<|vision_start|>' + '<|placeholder|>' * (image_grid_thw[index].prod() // merge_length) +
+                        '<|vision_end|>',
+                        1,
+                    )
+                    index += 1
+
+                prompt_with_chat_template = prompt_with_chat_template.replace('<|placeholder|>',
+                                                                              self.processor.image_token)
+        else:
+            raw_prompt = prompt_with_chat_template
+
         input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                                                          tokenizer=self.tokenizer,
                                                                          max_length=self.max_prompt_length,
@@ -145,11 +191,22 @@ def __getitem__(self, item):
                                                                          left_pad=True,
                                                                          truncation=self.truncation)

-        position_ids = compute_position_id_with_mask(attention_mask)
+        if self.image_key in row_dict:
+            from verl.models.transformers.qwen2_vl import get_rope_index
+
+            position_ids = get_rope_index(
+                self.processor,
+                input_ids=input_ids[0],
+                image_grid_thw=image_grid_thw,
+                attention_mask=attention_mask[0],
+            )  # (3, seq_len)
+        else:
+            position_ids = compute_position_id_with_mask(attention_mask)

         row_dict['input_ids'] = input_ids[0]
         row_dict['attention_mask'] = attention_mask[0]
         row_dict['position_ids'] = position_ids[0]
+        row_dict['raw_prompt_ids'] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)

         # encode prompts without chat template
         if self.return_raw_chat:
diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py
index 3c5ac1a9..fd61d51a 100644
--- a/verl/utils/flops_counter.py
+++ b/verl/utils/flops_counter.py
@@ -13,9 +13,9 @@
 # limitations under the License.

 import torch
-from transformers import PretrainedConfig, Qwen2Config, LlamaConfig
+from transformers import PretrainedConfig

-VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
+VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}


 def get_device_flops(unit="T"):
@@ -59,18 +59,22 @@ class FlopsCounter:
     """

     def __init__(self, config: PretrainedConfig):
-        if not isinstance(config, VALID_CONFIG_TYPE):
-            print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
+        if not config.model_type in VALID_CONFIG_TYPE:
+            print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. "
                   f"MFU will always be zero.")

-        self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
+        self.estimate_func = {
+            'qwen2': self._estimate_qwen2_flops,
+            'llama': self._estimate_qwen2_flops,
+            'qwen2_vl': self._estimate_qwen2_flops,
+            'qwen2_5_vl': self._estimate_qwen2_flops
+        }
         self.config = config

     def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
         return 0

     def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
-        assert isinstance(self.config, (Qwen2Config, LlamaConfig))
         hidden_size = self.config.hidden_size
         vocab_size = self.config.vocab_size
         num_hidden_layers = self.config.num_hidden_layers
diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py
index d14c6fc7..f90352c3 100644
--- a/verl/utils/reward_score/__init__.py
+++ b/verl/utils/reward_score/__init__.py
@@ -30,6 +30,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N
     elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
         from . import prime_code
         res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
+    elif data_source in ['hiyouga/geometry3k']:
+        from .
import geo3k + res = geo3k.compute_score(solution_str, ground_truth) else: raise NotImplementedError diff --git a/verl/utils/reward_score/geo3k.py b/verl/utils/reward_score/geo3k.py new file mode 100644 index 00000000..37c7c694 --- /dev/null +++ b/verl/utils/reward_score/geo3k.py @@ -0,0 +1,23 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mathruler.grader import extract_boxed_content, grade_answer + + +def compute_score(predict_str: str, ground_truth: str) -> float: + answer = extract_boxed_content(predict_str) + if grade_answer(answer, ground_truth): + return 1.0 # correct answer + + return 0.0 # wrong answer diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index 578a8d53..a55b4e45 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -14,7 +14,7 @@ """Utils for tokenization.""" import warnings -__all__ = ['hf_tokenizer'] +__all__ = ['hf_tokenizer', 'hf_processor'] def set_pad_token_id(tokenizer): @@ -56,4 +56,17 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) if correct_pad_token: set_pad_token_id(tokenizer) - return tokenizer \ No newline at end of file + return tokenizer + + +def hf_processor(model_path, **kwargs): + from transformers import AutoProcessor + try: + processor = AutoProcessor.from_pretrained(model_path, **kwargs) + except Exception: + processor = None + # Avoid load tokenizer, see: + # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 + if processor is not None and "Processor" not in processor.__class__.__name__: + processor = None + return processor diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 6afe69b7..a6cf9940 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -62,11 +62,19 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, log_probs: # (bs, response_len) """ response_length = micro_batch['responses'].size(-1) + multi_modal_inputs = {} + if 'multi_modal_inputs' in micro_batch: + for key in micro_batch['multi_modal_inputs'][0].keys(): + multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], + dim=0) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] position_ids = micro_batch['position_ids'] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), @@ -74,8 +82,13 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad 
= index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + if position_ids.dim() == 3: + position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), + indices).transpose(0, 1).unsqueeze( + 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), + indices).transpose(0, 1) # for compute the log_prob input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) @@ -94,6 +107,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, output = self.actor_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, + **multi_modal_inputs, use_cache=False) # prevent model thinks we are generating logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) @@ -131,6 +145,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, output = self.actor_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + **multi_modal_inputs, use_cache=False) # prevent model thinks we are generating logits = output.logits logits.div_(temperature) @@ -177,8 +192,13 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() - if use_dynamic_bsz: + if has_multi_modal_inputs: + num_micro_batches = data.batch.batch_size[0] // micro_batch_size + non_tensor_select_keys = ['multi_modal_inputs'] + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) @@ -187,6 +207,9 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: log_probs_lst = [] for micro_batch in micro_batches: + if isinstance(micro_batch, DataProto): + micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) log_probs_lst.append(log_probs) @@ -210,17 +233,27 @@ def update_policy(self, data: DataProto): if self.config.use_kl_loss: select_keys.append('ref_log_prob') batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.ppo_mini_batch_size) + if has_multi_modal_inputs: + num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + non_tensor_select_keys = ['multi_modal_inputs'] + dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + else: + dataloader = batch.split(self.config.ppo_mini_batch_size) metrics = {} for epoch in range(self.config.ppo_epochs): for batch_idx, data in enumerate(dataloader): # split batch into micro_batches mini_batch = data - if self.config.use_dynamic_bsz: + if has_multi_modal_inputs: + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: @@ -231,7 +264,11 @@ def update_policy(self, data: DataProto): self.actor_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # actor device is cpu when using offload + if isinstance(data, DataProto): + data = {**data.batch.cuda(), **data.non_tensor_batch} + else: + data = data.cuda() # actor device is cpu when using offload + responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index 11ad02dc..1d167e87 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -49,11 +49,19 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt def _forward_micro_batch(self, micro_batch): response_length = micro_batch['responses'].size(-1) + multi_modal_inputs = {} + if 'multi_modal_inputs' in micro_batch: + for key in micro_batch['multi_modal_inputs'][0].keys(): + multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], + dim=0) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] position_ids = micro_batch['position_ids'] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) if self.use_remove_padding: input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), @@ -61,8 +69,13 @@ def _forward_micro_batch(self, micro_batch): input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) + if position_ids.dim() == 3: + position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), + indices).transpose(0, 1).unsqueeze( + 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: @@ -74,6 +87,7 @@ def _forward_micro_batch(self, micro_batch): output = self.critic_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, + **multi_modal_inputs, use_cache=False) # prevent model thinks we are generating values_rmpad = output.logits values_rmpad = values_rmpad.squeeze(0) # (total_nnz) @@ -92,6 +106,7 @@ def _forward_micro_batch(self, micro_batch): output = self.critic_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + **multi_modal_inputs, use_cache=False) # prevent model thinks we are generating values = output.logits values = values[:, -response_length - 1:-1].squeeze(-1) @@ -113,8 +128,13 @@ def compute_values(self, data: DataProto) -> torch.Tensor: select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] batch = data.select(batch_keys=select_keys).batch use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() - if use_dynamic_bsz: + if has_multi_modal_inputs: + num_micro_batches = data.batch.batch_size[0] // micro_batch_size + non_tensor_select_keys = ['multi_modal_inputs'] + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) @@ -123,6 +143,9 @@ def compute_values(self, data: DataProto) -> torch.Tensor: values_lst = [] for micro_batch in micro_batches: + if isinstance(micro_batch, DataProto): + micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): values = self._forward_micro_batch(micro_batch) values_lst.append(values) @@ -147,15 +170,25 @@ def update_critic(self, data: DataProto): select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() + # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.ppo_mini_batch_size) + if has_multi_modal_inputs: + num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + non_tensor_select_keys = ['multi_modal_inputs'] + dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + else: + dataloader = batch.split(self.config.ppo_mini_batch_size) for epoch in range(self.config.ppo_epochs): for batch_idx, data in enumerate(dataloader): # split batch into micro_batches mini_batch = data - if self.config.use_dynamic_bsz: + if has_multi_modal_inputs: + num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: @@ -165,7 +198,11 @@ def update_critic(self, data: DataProto): self.critic_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # critic device is cpu when using offload + if isinstance(data, DataProto): + data = {**data.batch.cuda(), **data.non_tensor_batch} + else: + data = data.cuda() # critic device is cpu when using offload + input_ids = data['input_ids'] responses = data['responses'] attention_mask = data['attention_mask'] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 4ebeccdc..a63a4088 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -27,7 +27,7 @@ from verl import DataProto from verl.single_controller.base import Worker from verl.single_controller.base.decorator import register, Dispatch -from verl.utils import hf_tokenizer +from verl.utils import hf_tokenizer, hf_processor from verl.utils.debug import log_gpu_memory_usage from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager @@ -151,7 +151,7 @@ def _build_model_optimizer(self, role='actor'): from verl.utils.model import print_model_size, update_model_config, get_generation_config from verl.utils.torch_dtypes import PrecisionType - from transformers import AutoModelForCausalLM, AutoConfig + from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForVision2Seq from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload from torch import optim @@ -163,6 +163,7 @@ def _build_model_optimizer(self, # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. 
Support init with FSDP directly self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) torch_dtype = fsdp_config.get('model_dtype', None) if torch_dtype is None: @@ -198,11 +199,16 @@ def _build_model_optimizer(self, with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=actor_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): + actor_module_class = AutoModelForVision2Seq + else: + actor_module_class = AutoModelForCausalLM + + actor_module = actor_module_class.from_pretrained(pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=actor_model_config, + attn_implementation='flash_attention_2', + trust_remote_code=trust_remote_code) # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance @@ -396,10 +402,11 @@ def init_model(self): if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_manager = FSDPCheckpointManager(model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - tokenizer=self.tokenizer) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer) torch.cuda.empty_cache() @@ -634,6 +641,7 @@ def _build_critic_model_optimizer(self, config): tokenizer_path = copy_to_local(config.model.tokenizer_path) self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) + self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) from omegaconf import OmegaConf override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) @@ -757,10 +765,11 @@ def init_model(self): critic_optimizer=self.critic_optimizer) self.flops_counter = FlopsCounter(self.critic_model_config) - self.checkpoint_manager = FSDPCheckpointManager(model=self.critic_module, - optimizer=self.critic_optimizer, - lr_scheduler=self.critic_lr_scheduler, - tokenizer=self.tokenizer) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.critic_module, + optimizer=self.critic_optimizer, + lr_scheduler=self.critic_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer) torch.cuda.empty_cache() diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index bcee3544..3f15bdb0 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -24,6 +24,7 @@ - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. 
""" +import numpy as np from typing import List from contextlib import contextmanager from omegaconf import DictConfig @@ -31,7 +32,7 @@ import torch.distributed from tensordict import TensorDict from torch import nn - +from typing import Any, Union from verl import DataProto from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout @@ -54,6 +55,13 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in return token_ids +def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: + if isinstance(value, torch.Tensor): + return value.repeat_interleave(repeats, dim=0) + else: + return np.repeat(value, repeats, axis=0) + + class vLLMRollout(BaseRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): @@ -110,7 +118,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf kwargs = dict( n=1, - logprobs=1, # can be set to 0 and let actor to recompute + logprobs=0, # can be set to 0 and let actor to recompute max_tokens=config.response_length, ) @@ -160,10 +168,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: batch_size = idx.size(0) - idx_list = [] - # parse idx from torch.Tensor to List[List[str]] - for i in range(batch_size): - idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) + non_tensor_batch = prompts.non_tensor_batch + if batch_size != len(non_tensor_batch['raw_prompt_ids']): + raise RuntimeError('vllm sharding manager is not work properly.') + + if 'multi_modal_data' in non_tensor_batch: + vllm_inputs = [] + for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'), + non_tensor_batch.pop('multi_modal_data')): + vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data}) + else: + vllm_inputs = [{ + 'prompt_token_ids': raw_prompt_ids + } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')] do_sample = prompts.meta_info.get('do_sample', True) if not do_sample: @@ -179,9 +196,8 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): outputs = self.inference_engine.generate( - prompts=None, # because we have already convert it to prompt token id + prompts=vllm_inputs, # because we have already convert it to prompt token id sampling_params=self.sampling_params, - prompt_token_ids=idx_list, use_tqdm=False) # TODO(sgm): disable logprob when recompute_log_prob is enable @@ -196,15 +212,21 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: max_length=self.config.response_length).to(idx.device) if self.config.n > 1 and do_sample: - idx = idx.repeat_interleave(self.config.n, dim=0) - attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0) - position_ids = position_ids.repeat_interleave(self.config.n, dim=0) + idx = _repeat_interleave(idx, self.config.n) + attention_mask = _repeat_interleave(attention_mask, self.config.n) + position_ids = _repeat_interleave(position_ids, self.config.n) batch_size = batch_size * self.config.n + if 'multi_modal_inputs' in non_tensor_batch.keys(): + non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'], + self.config.n) + seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) delta_position_id = 
torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) + delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1) + if position_ids.dim() == 3: # qwen2vl mrope + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad @@ -231,4 +253,4 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: self.inference_engine.free_cache_engine() - return DataProto(batch=batch) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) diff --git a/verl/workers/sharding_manager/fsdp_ulysses.py b/verl/workers/sharding_manager/fsdp_ulysses.py index 85328cc8..dc3fce65 100644 --- a/verl/workers/sharding_manager/fsdp_ulysses.py +++ b/verl/workers/sharding_manager/fsdp_ulysses.py @@ -60,19 +60,9 @@ def preprocess_data(self, data: DataProto) -> DataProto: In Ulysses, we need to make sure the same data is used across a SP group """ if self.device_mesh is not None: - sp_size = self.device_mesh['sp'].size() group = self.device_mesh['sp'].get_group() - - prev_device = data.batch.device - data.batch = data.batch.cuda(device=torch.cuda.current_device()) - data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0) - data.batch = data.batch.to(prev_device) - # all gather non_tensor_batch - all_non_tensor_batch = [None for _ in range(sp_size)] - torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group) - data.non_tensor_batch = { - k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch - } + data = data.to("cuda") + data.all_gather(group) return data def postprocess_data(self, data: DataProto) -> DataProto: diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index c79d3031..f5e4e2d8 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -127,15 +127,9 @@ def __exit__(self, exc_type, exc_value, traceback): def preprocess_data(self, data: DataProto) -> DataProto: # TODO: Current impl doesn't consider FSDP with torch micro-dp if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): - data.batch = allgather_dict_tensors(data.batch.contiguous(), - size=vllm_ps.get_tensor_model_parallel_world_size(), - group=vllm_ps.get_tensor_model_parallel_group(), - dim=0) + data.all_gather(vllm_ps.get_tensor_model_parallel_group()) else: - data.batch = allgather_dict_tensors(data.batch.contiguous(), - size=vllm_ps.get_tensor_model_parallel_world_size(), - group=vllm_ps.get_tensor_model_parallel_group().device_group, - dim=0) + data.all_gather(vllm_ps.get_tensor_model_parallel_group().device_group) return data
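Usage sketch for the pieces added above; the paths and the single-node, 8-GPU layout simply mirror the defaults in examples/data_preprocess/geo3k.py and examples/grpo_trainer/run_qwen2_5_vl-7b.sh, so adjust them for other setups.

    # 1. Download Geometry3k and write train.parquet / test.parquet (defaults to ~/data/geo3k).
    python3 examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k

    # 2. Run GRPO on Qwen2.5-VL-7B-Instruct through the new multimodal data/rollout path.
    #    Additional key=value overrides may be appended; the script forwards them via its trailing $@.
    bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh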