From 75d369d427a6502ff745b8cb9d7ccb711755cc49 Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Thu, 13 Feb 2025 00:10:04 +0800 Subject: [PATCH 1/6] [feat] support qwen2 megatron backend --- verl/models/qwen2/__init__.py | 13 + verl/models/qwen2/megatron/__init__.py | 24 + .../megatron/checkpoint_utils/__init__.py | 13 + .../megatron/checkpoint_utils/qwen2_loader.py | 462 ++++++++++++ .../megatron/checkpoint_utils/qwen2_saver.py | 449 ++++++++++++ verl/models/qwen2/megatron/layers/__init__.py | 18 + .../megatron/layers/parallel_attention.py | 401 +++++++++++ .../qwen2/megatron/layers/parallel_decoder.py | 146 ++++ .../qwen2/megatron/layers/parallel_linear.py | 74 ++ .../qwen2/megatron/layers/parallel_mlp.py | 74 ++ .../qwen2/megatron/layers/parallel_rmsnorm.py | 46 ++ .../qwen2/megatron/modeling_qwen2_megatron.py | 663 ++++++++++++++++++ verl/models/registry.py | 2 + verl/models/weight_loader_registry.py | 6 +- .../vllm_v_0_4_2/megatron_weight_loaders.py | 1 + .../vllm_v_0_5_4/megatron_weight_loaders.py | 1 + .../vllm_v_0_6_3/megatron_weight_loaders.py | 1 + 17 files changed, 2393 insertions(+), 1 deletion(-) create mode 100644 verl/models/qwen2/__init__.py create mode 100644 verl/models/qwen2/megatron/__init__.py create mode 100644 verl/models/qwen2/megatron/checkpoint_utils/__init__.py create mode 100644 verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py create mode 100644 verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py create mode 100644 verl/models/qwen2/megatron/layers/__init__.py create mode 100644 verl/models/qwen2/megatron/layers/parallel_attention.py create mode 100644 verl/models/qwen2/megatron/layers/parallel_decoder.py create mode 100644 verl/models/qwen2/megatron/layers/parallel_linear.py create mode 100644 verl/models/qwen2/megatron/layers/parallel_mlp.py create mode 100644 verl/models/qwen2/megatron/layers/parallel_rmsnorm.py create mode 100644 verl/models/qwen2/megatron/modeling_qwen2_megatron.py diff --git a/verl/models/qwen2/__init__.py b/verl/models/qwen2/__init__.py new file mode 100644 index 00000000..1ce90c5e --- /dev/null +++ b/verl/models/qwen2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/qwen2/megatron/__init__.py b/verl/models/qwen2/megatron/__init__.py new file mode 100644 index 00000000..26ff0137 --- /dev/null +++ b/verl/models/qwen2/megatron/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_qwen2_megatron import ( + # original model with megatron + ParallelQwen2Model, + ParallelQwen2ForCausalLM, + # rmpad with megatron + ParallelQwen2ForCausalLMRmPad, + ParallelQwen2ForValueRmPad, + # rmpad with megatron and pipeline parallelism + ParallelQwen2ForCausalLMRmPadPP, + ParallelQwen2ForValueRmPadPP) diff --git a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py new file mode 100644 index 00000000..1ce90c5e --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py new file mode 100644 index 00000000..b9b98845 --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -0,0 +1,462 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import time +from typing import Dict, Any, Callable, Optional +import torch.distributed as dist + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + import megatron + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False): + """Load merged state_dict to sharded Megatron module in training. 
+ """ + import megatron + from megatron.core import mpu + from megatron.utils import print_rank_0, unwrap_model + from megatron.core.transformer.module import Float16Module + from megatron.core import DistributedDataParallel as LocalDDP + from torch.nn.parallel import DistributedDataParallel as torchDDP + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast(param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = 
torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert (tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert (tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_( + torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor 
is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert (q_name in state_dict and k_name in state_dict and v_name in state_dict) + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, + dtype=params_dtype, + device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp] + new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], + dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, + dtype=params_dtype, + device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], + dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert (tensor.shape == chunk_shape + ), f"rank 
#{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight") + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + # if torch.distributed.get_rank() == 0: + if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0('load lm_head from value_head weight') + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0('fail to match lm_head in value_model') + # else: + + # _broadcast_tensor(lm_head_weight, "lm_head.weight") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + 
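
To make the sharded layout explicit: the per-layer loop above packs the HuggingFace q_proj/k_proj/v_proj weights into one fused tensor laid out rank by rank before chunking it across tensor-parallel ranks, and does the analogous packing for gate_proj/up_proj. Below is a minimal standalone sketch of that fused QKV convention, separate from the patched files, using toy dimensions and assuming num_key_value_heads >= tp_size (the first branch of _broadcast_tp_shard_tensor_qkv); all sizes here are illustrative only.

# Illustrative sketch (not part of the patch): reproduce, on toy shapes, the per-TP-rank
# fused [q | k | v] row layout that _broadcast_tp_shard_tensor_qkv builds before chunking.
import torch

hidden_size, num_heads, num_kv_heads, tp_size = 16, 8, 4, 2
head_dim = hidden_size // num_heads

q = torch.randn(num_heads * head_dim, hidden_size)     # e.g. model.layers.N.self_attn.q_proj.weight
k = torch.randn(num_kv_heads * head_dim, hidden_size)  # k_proj.weight
v = torch.randn(num_kv_heads * head_dim, hidden_size)  # v_proj.weight

q_size_tp = hidden_size // tp_size
kv_size_tp = num_kv_heads * head_dim // tp_size
total_size = q_size_tp + 2 * kv_size_tp

fused = torch.empty(total_size * tp_size, hidden_size)
for i in range(tp_size):
    # rank i's rows: its slice of Q, then its slice of K, then its slice of V
    fused[i * total_size:(i + 1) * total_size].copy_(torch.cat([
        q[i * q_size_tp:(i + 1) * q_size_tp],
        k[i * kv_size_tp:(i + 1) * kv_size_tp],
        v[i * kv_size_tp:(i + 1) * kv_size_tp],
    ], dim=0))

# Each TP rank receives one contiguous chunk; splitting it by [q_size_tp, kv_size_tp, kv_size_tp]
# recovers that rank's q/k/v shards.
rank0 = torch.chunk(fused, tp_size, dim=0)[0]
q0, k0, v0 = rank0.split([q_size_tp, kv_size_tp, kv_size_tp], dim=0)
assert torch.equal(q0, q[:q_size_tp]) and torch.equal(k0, k[:kv_size_tp]) and torch.equal(v0, v[:kv_size_tp])

Each rank's chunk is later split back into q/k/v along the output dimension inside the attention layer, which is why the packing has to interleave per rank rather than simply concatenating the full Q, K and V matrices.
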
dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py new file mode 100644 index 00000000..0764b6fe --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -0,0 +1,449 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import megatron +from megatron.core import mpu +from megatron.utils import print_rank_0, unwrap_model +from megatron.model import Float16Module +from megatron.model import DistributedDataParallel as LocalDDP +from torch.nn.parallel import DistributedDataParallel as torchDDP +import torch +import time +from typing import Optional +import torch.distributed as dist +from megatron import get_args + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + args = get_args() + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size() + ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + if args.switch_dp_and_pp_grouping: + # TP-PP-DP grouping + return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank + else: + # TP-DP-PP grouping + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + import megatron + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + args = megatron.get_args() + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model) + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'): + """Merge sharded parameters of a Megatron module into a merged checkpoint. 
+ + Args: + wrapped_modelss (list of megatron.model.DistributedDataParallel): + The local DDP wrapped megatron modules. + dtype (str or None): + The data type of state_dict. if None, the data type of the original parameters + is used. + gpt_model_key: key to access model + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + args = megatron.get_args() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers + ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format( + len(models[i].model.layers), num_layers_per_model) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=args.params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + chunk_shape = tensor.shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + 
# all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=args.params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + chunk_shape = tensor.shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=args.params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + chunk_shape = tensor.shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + 
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=args.params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size:(i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp:total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + torch.cuda.empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + 
src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + print_rank_0("collecting lm_head...") + + if is_value_model: + _broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "reward_head.weight", + src_pp_rank=pp_size - 1) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + torch.cuda.empty_cache() + if torch.distributed.get_rank() == 0: + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + elif dtype is None or dtype == "fp32": + dtype = torch.float32 + else: + print(f'Unknown/unsupported dtype to save: {dtype}"') + exit(1) + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/verl/models/qwen2/megatron/layers/__init__.py b/verl/models/qwen2/megatron/layers/__init__.py new file mode 100644 index 00000000..2d19bef2 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelQwen2Attention +from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py new file mode 100644 index 00000000..a4f65e89 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -0,0 +1,401 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch +from megatron.core import parallel_state as mpu +from megatron.core import tensor_parallel +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config +from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear + +from verl.utils.megatron import tensor_parallel as tp_utils + + +class Qwen2RotaryEmbedding(nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelQwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}' + assert self.num_key_value_heads % tp_size == 0, \ + f'num_key_value_heads must be divisible by tp_size. 
Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}' + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + # bias=config.attention_bias, + bias=True, + gather_output=False, + skip_bias_add=False, + **column_kwargs) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + # bias=config.attention_bias, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs) + + self._init_rope() + + def _init_rope(self): + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + 
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + +from transformers.utils import is_flash_attn_2_available +import torch.nn.functional as F + +from einops import rearrange + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... 
-> (b s) ..."), indices) + + return q_embed, k_embed + + +from flash_attn.layers.rotary import apply_rotary_emb + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb(q, + cos, + sin, + interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen) + k_embed = apply_rotary_emb(k, + cos, + sin, + interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen) + return q_embed, k_embed + + +class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): + + def forward(self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], + dim=-1) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, + key_states, + cos, + sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(Qwen2RMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py b/verl/models/qwen2/megatron/layers/parallel_decoder.py new file mode 100644 index 00000000..68cefc84 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -0,0 +1,146 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
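
The rmpad attention path above operates on packed, padding-free tensors and relies on indices, cu_seqlens and max_seqlen_in_batch to describe where each sequence lives in the packed dimension. The sketch below, which is not part of the patch, shows how these quantities are conventionally derived from a padded attention mask using plain torch and made-up data; it mirrors what flash_attn's bert_padding helpers (imported above) compute for the callers of these layers.

# Illustrative sketch (not part of the patch): derive the unpadded inputs that
# ParallelQwen2AttentionRmPad expects from an ordinary (batch, seq_len) attention mask.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],   # sequence lengths 3 and 2, padded to length 4
                               [1, 1, 0, 0]])
batch_size, seq_len = attention_mask.shape

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
max_seqlen_in_batch = int(seqlens_in_batch.max())                            # 3
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 5]): flash_attn_varlen_func uses these offsets to delimit
# each sequence inside the packed (total_nnz, num_heads, head_dim) tensors.

hidden = torch.randn(batch_size, seq_len, 8)
packed = hidden.reshape(-1, 8)[indices]   # (total_nnz, hidden): the packed view the rmpad layers see
assert packed.shape[0] == int(cu_seqlens[-1]) == 5

Scattering packed rows back to (batch, seq_len, hidden) with these indices is the inverse operation, which is what the pad_input/index_first_axis round trip in apply_rotary_pos_emb_rmpad relies on.
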
+ +from typing import Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config +from megatron.core import ModelParallelConfig + +from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + + +class ParallelQwen2DecoderLayer(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelQwen2DecoderLayerRmPad(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, 
torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn(hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/verl/models/qwen2/megatron/layers/parallel_linear.py b/verl/models/qwen2/megatron/layers/parallel_linear.py new file mode 100644 index 00000000..bfe5cf4e --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_linear.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + +from typing import Optional, Tuple + +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + + def __init__(self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + + def __init__(self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__(input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs) diff --git a/verl/models/qwen2/megatron/layers/parallel_mlp.py b/verl/models/qwen2/megatron/layers/parallel_mlp.py new file mode 100644 index 00000000..48b97711 --- /dev/null +++ 
b/verl/models/qwen2/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import parallel_state as mpu +from megatron.core import tensor_parallel +from megatron.core import ModelParallelConfig +from torch import nn +from transformers.activations import ACT2FN +from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear + +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelQwen2MLP(nn.Module): + + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + assert row_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 00000000..726eb7f8 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,46 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelQwen2RMSNorm(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine(input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True) \ No newline at end of file diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py new file mode 100644 index 00000000..3d24082e --- /dev/null +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -0,0 +1,663 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Qwen2 model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from megatron.core import tensor_parallel +from megatron.core import ModelParallelConfig +from torch import nn +from torch.nn import init +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast + +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from .layers import ParallelQwen2DecoderLayer, ParallelQwen2RMSNorm, ParallelQwen2DecoderLayerRmPad +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. 
Add sequence parallel +3. Load checkpoint from Qwen2 pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelQwen2Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. 
shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLM(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.model = ParallelQwen2Model(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +class ParallelQwen2ModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs) + + self.layers = nn.ModuleList( + [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward(self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer(hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPad(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head() + + def _init_head(self): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size, + output_size=self.config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of 
shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1), + attention_mask) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model(input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): + + def _init_head(self): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelQwen2ModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. 
+ Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs) + else: + self.embed_tokens = None + + # pp_rank = megatron_config.pipeline_model_parallel_rank + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + + if vpp_size is not None: + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + # vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank + # self.offset = vpp_rank * ( + # config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \ + # (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + # self.offset = pp_rank * self.num_layer_per_pp + + layers = [] + for i in range(self.num_layer_this_model): + layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config) + # setattr(layer, 'hidden_layer_index', self.offset + i) + layers.append(layer) + + self.layers = nn.ModuleList(layers) + + if post_process: + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. 
shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + # if torch.cuda.current_device() == 0: + # print(f'rank {torch.cuda.current_device()}: input_ids shape before embedding: {input_ids.shape}') + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + # if torch.cuda.current_device() == 0: + # print(f'rank {torch.cuda.current_device()}: input_embeds shape after embedding: {inputs_embeds.shape}') + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer(hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPadPP(nn.Module): + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPadPP(config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process) + self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head() + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. 
This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size, + output_size=self.config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = {self.config.vocab_size}') # [4, 32, 4096] + logits = self.lm_head(hidden_states)[0] + # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1), + attention_mask) # (total_nnz, 1) + # print(f'input_ids.shape = {input_ids.shape}, input_ids_rmpad.shape = {input_ids_rmpad.shape}, indices.shape = {indices.shape}, cu_seqlens[-1] = {cu_seqlens[-1]}') + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model(input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch) + + if self.post_process: + hidden_states = outputs + # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) + logits = self._forward_head(hidden_states) + # print(f'logits.shape = {logits.shape}') + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. 
If input is already rmpad, we let the caller pad_input + # print(f'logits.shape = {logits.shape}, indices.shape = {indices.shape}, batch_size = {batch_size}, seq_len = {sequence_length}') + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): + + def _init_head(self): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get('config', False), 'must have ModelParallelConfig' + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output \ No newline at end of file diff --git a/verl/models/registry.py b/verl/models/registry.py index 55ddbd44..a363cc8b 100644 --- a/verl/models/registry.py +++ b/verl/models/registry.py @@ -37,6 +37,8 @@ def check_model_support_rmpad(model_type: str): _MODELS = { "LlamaForCausalLM": ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")), + "Qwen2ForCausalLM": + ("qwen2", ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad")), "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad")) } diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py index 17f0c5ca..e412877f 100644 --- a/verl/models/weight_loader_registry.py +++ b/verl/models/weight_loader_registry.py @@ -15,7 +15,11 @@ def get_weight_loader(arch: str): from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama - _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {'LlamaForCausalLM': load_state_dict_to_megatron_llama} + from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2 + _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { + 'LlamaForCausalLM': load_state_dict_to_megatron_llama, + 'Qwen2ForCausalLM': load_state_dict_to_megatron_qwen2, + } if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] diff --git a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py index 1a7c2e2c..281de83a 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py +++ 
b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py @@ -282,6 +282,7 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - 'LlamaForCausalLM': llama_megatron_core_te_weight_loader, # use te backend for open-source megatron 'LLaMAForCausalLM': llama_megatron_core_te_weight_loader, 'MistralForCausalLM': mistral_megatron_weight_loader, + 'Qwen2ForCausalLM': llama_megatron_weight_loader, } diff --git a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py index 4f2b19a9..b5441f36 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py @@ -282,6 +282,7 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - 'LlamaForCausalLM': llama_megatron_weight_loader, # use te backend for open-source megatron 'LLaMAForCausalLM': llama_megatron_weight_loader, 'MistralForCausalLM': mistral_megatron_weight_loader, + 'Qwen2ForCausalLM': llama_megatron_weight_loader, } diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py index 7fd6c0e6..ae8588f8 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -283,6 +283,7 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron "LLaMAForCausalLM": llama_megatron_weight_loader, "MistralForCausalLM": mistral_megatron_weight_loader, + 'Qwen2ForCausalLM': llama_megatron_weight_loader, } From b33325ce8dc12efe367114376b46a60ff4d2a0de Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Thu, 13 Feb 2025 10:33:38 +0800 Subject: [PATCH 2/6] add example --- .../run_qwen2-7b_math_gsm8k_megatron.sh | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh diff --git a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh new file mode 100644 index 00000000..0a66e83f --- /dev/null +++ b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh @@ -0,0 +1,42 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.val_batch_size=6312 \ + data.max_prompt_length=1024 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + critic.optim.lr=1e-5 \ + 
critic.model.path=Qwen/Qwen2-7B-Instruct \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_megatron_math_gsm8k_examples' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=100 $@ From 8e36b3e3e73f6f92c256cb83c66e990c82230541 Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Sat, 15 Feb 2025 11:56:57 +0800 Subject: [PATCH 3/6] misc --- .../qwen2/megatron/modeling_qwen2_megatron.py | 19 ++++++------------- .../vllm_v_0_4_2/megatron_weight_loaders.py | 1 - .../vllm_v_0_5_4/megatron_weight_loaders.py | 1 - 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index 3d24082e..ae3372c8 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -26,7 +26,6 @@ from megatron.core import tensor_parallel from megatron.core import ModelParallelConfig from torch import nn -from torch.nn import init from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast @@ -324,8 +323,8 @@ def forward( batch_size, sequence_length = input_ids.shape # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), + attention_mask) # (total_nnz, 1) # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap @@ -482,8 +481,6 @@ def forward(self, """ if self.pre_process: - # if torch.cuda.current_device() == 0: - # print(f'rank {torch.cuda.current_device()}: input_ids shape before embedding: {input_ids.shape}') inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron @@ -493,8 +490,6 @@ def forward(self, if self.megatron_config.sequence_parallel: inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - # if torch.cuda.current_device() == 0: - # print(f'rank {torch.cuda.current_device()}: input_embeds shape after embedding: {inputs_embeds.shape}') hidden_states = inputs_embeds else: # self.hidden_states should be passed by Megatron @@ -586,9 +581,9 @@ def forward( # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model batch_size, sequence_length = input_ids.shape # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1), - attention_mask) # (total_nnz, 1) - # print(f'input_ids.shape = {input_ids.shape}, input_ids_rmpad.shape = {input_ids_rmpad.shape}, indices.shape = {indices.shape}, cu_seqlens[-1] = {cu_seqlens[-1]}') + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), + attention_mask) # (total_nnz, 1) + # pad input_ids to multiple of tp for all tp ranks # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap if self.megatron_config.sequence_parallel: @@ -607,7 +602,6 @@ def forward( hidden_states = outputs # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) logits = self._forward_head(hidden_states) - # print(f'logits.shape = {logits.shape}') logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) # remove padding from sequence parallel @@ -615,7 +609,6 @@ def forward( totol_nnz = cu_seqlens[-1] logits = logits[:totol_nnz] # (total_nnz_padded) # add removed padding back. 
If input is already rmpad, we let the caller pad_input - # print(f'logits.shape = {logits.shape}, indices.shape = {indices.shape}, batch_size = {batch_size}, seq_len = {sequence_length}') logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) @@ -660,4 +653,4 @@ def forward( output.logits = torch.squeeze(output.logits, dim=-1) return output else: - return output \ No newline at end of file + return output diff --git a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py index 281de83a..1a7c2e2c 100644 --- a/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py @@ -282,7 +282,6 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - 'LlamaForCausalLM': llama_megatron_core_te_weight_loader, # use te backend for open-source megatron 'LLaMAForCausalLM': llama_megatron_core_te_weight_loader, 'MistralForCausalLM': mistral_megatron_weight_loader, - 'Qwen2ForCausalLM': llama_megatron_weight_loader, } diff --git a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py index b5441f36..4f2b19a9 100644 --- a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py @@ -282,7 +282,6 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - 'LlamaForCausalLM': llama_megatron_weight_loader, # use te backend for open-source megatron 'LLaMAForCausalLM': llama_megatron_weight_loader, 'MistralForCausalLM': mistral_megatron_weight_loader, - 'Qwen2ForCausalLM': llama_megatron_weight_loader, } From 2e2ab194f4730911b0e225491ff09ea3e02fa787 Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Mon, 17 Feb 2025 13:23:46 +0800 Subject: [PATCH 4/6] update ci --- .github/workflows/e2e_gsm8k_megatron.yml | 8 +++-- tests/e2e/run_qwen_megatron.sh | 41 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 tests/e2e/run_qwen_megatron.sh diff --git a/.github/workflows/e2e_gsm8k_megatron.yml b/.github/workflows/e2e_gsm8k_megatron.yml index 305d1724..feef98aa 100644 --- a/.github/workflows/e2e_gsm8k_megatron.yml +++ b/.github/workflows/e2e_gsm8k_megatron.yml @@ -41,9 +41,13 @@ jobs: - name: Prepare gsm8k dataset run: | python3 examples/data_preprocess/gsm8k.py - - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron + - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Deepseek) run: | ray stop --force [ ! 
-d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM - bash tests/e2e/run_deepseek_megatron.sh \ No newline at end of file + bash tests/e2e/run_deepseek_megatron.sh + - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + bash tests/e2e/run_qwen_megatron.sh \ No newline at end of file diff --git a/tests/e2e/run_qwen_megatron.sh b/tests/e2e/run_qwen_megatron.sh new file mode 100644 index 00000000..daf78a03 --- /dev/null +++ b/tests/e2e/run_qwen_megatron.sh @@ -0,0 +1,41 @@ +set -x + +# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml + +huggingface-cli download Qwen/Qwen2.5-0.5B + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ + actor_rollout_ref.actor.optim.lr=2e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + critic.optim.lr=2e-5 \ + critic.model.path=Qwen/Qwen2.5-0.5B \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size_per_gpu=4 \ + critic.megatron.tensor_model_parallel_size=2 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_megatron_gsm8k_examples' \ + trainer.experiment_name='qwen2_5_0b5_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=1 \ + trainer.total_epochs=15 \ + trainer.total_training_steps=3 $@ From e7e9e569deed4ec144cbfb3e6386165b99f42483 Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Mon, 17 Feb 2025 14:04:59 +0800 Subject: [PATCH 5/6] update ci --- .github/workflows/e2e_gsm8k_megatron.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/e2e_gsm8k_megatron.yml b/.github/workflows/e2e_gsm8k_megatron.yml index feef98aa..4932c3f4 100644 --- a/.github/workflows/e2e_gsm8k_megatron.yml +++ b/.github/workflows/e2e_gsm8k_megatron.yml @@ -50,4 +50,5 @@ jobs: - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Qwen) run: | ray stop --force + export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM bash tests/e2e/run_qwen_megatron.sh \ No newline at end of file From 45daa8f6f57c95e67887996c09a7afe1d870d812 Mon Sep 17 00:00:00 2001 From: kinman0224 Date: Wed, 19 Feb 2025 00:26:54 +0800 Subject: [PATCH 6/6] fix share_embeddings_and_output_weights problem --- .../llama/megatron/modeling_llama_megatron.py | 5 +- .../qwen2/megatron/modeling_qwen2_megatron.py | 64 +++++++++++++++++-- .../vllm_v_0_6_3/megatron_weight_loaders.py | 14 +++- verl/utils/model.py | 13 +++- verl/workers/actor/megatron_actor.py | 5 ++ verl/workers/megatron_workers.py | 15 +++-- 6 files changed, 101 insertions(+), 15 deletions(-) 
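PATCH 6/6 below addresses tied word embeddings: the smaller Qwen2 checkpoints (for example the Qwen2.5-0.5B model trained by the CI script above) set `tie_word_embeddings=True`, so the output projection has no weight of its own and must reuse the embedding matrix. The diff threads a `share_embeddings_and_output_weights` flag through the model providers for exactly this case. A minimal, self-contained sketch of what the flag means (toy sizes; the `Qwen2Config` values are illustrative, not taken from a real checkpoint):

    import torch
    from torch import nn
    from transformers import Qwen2Config

    # Stand-in for a checkpoint config that ties its embeddings.
    config = Qwen2Config(vocab_size=1000, hidden_size=64, tie_word_embeddings=True)
    share_embeddings_and_output_weights = getattr(config, "tie_word_embeddings", False)

    embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
    if share_embeddings_and_output_weights:
        # Tied case: logits are computed against the embedding matrix itself,
        # so the checkpoint contains no separate lm_head.weight to load.
        def lm_head(hidden_states):
            return hidden_states @ embed_tokens.weight.t()
    else:
        lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    hidden_states = torch.randn(2, 5, config.hidden_size)
    logits = lm_head(hidden_states)  # (2, 5, vocab_size) either way

This is also why the vLLM weight loader added in this patch skips `lm_head.weight` whenever `config.tie_word_embeddings` is set.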
diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index c693f33c..a951660a 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -513,7 +513,8 @@ def forward(self, class ParallelLlamaForCausalLMRmPadPP(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process, + share_embeddings_and_output_weights): super().__init__() self.config = config self.megatron_config = megatron_config @@ -521,7 +522,7 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) - self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process self.post_process = post_process diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index ae3372c8..fefa6535 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -23,7 +23,7 @@ import torch import torch.utils.checkpoint -from megatron.core import tensor_parallel +from megatron.core import tensor_parallel, parallel_state from megatron.core import ModelParallelConfig from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast @@ -513,7 +513,8 @@ def forward(self, class ParallelQwen2ForCausalLMRmPadPP(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process, + share_embeddings_and_output_weights): super().__init__() self.config = config self.megatron_config = megatron_config @@ -521,12 +522,14 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pr megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) - self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process self.post_process = post_process if post_process: self._init_head() + if pre_process or post_process: + self.setup_embeddings_and_output_layer() def set_input_tensor(self, input_tensor): """Set input tensor to be used instead of forward()'s input. @@ -549,12 +552,64 @@ def _init_head(self): bias=False, gather_output=False, skip_bias_add=False, + skip_weight_param_allocation=self.pre_process and + self.share_embeddings_and_output_weights, **column_kwargs) + def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initalizes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + # Set `is_embedding_or_output_parameter` attribute. 
+ if self.pre_process: + self.model.embed_tokens.weight.is_embedding_or_output_parameter = True + if self.post_process and self.lm_head.weight is not None: + self.lm_head.weight.is_embedding_or_output_parameter = True + + if not self.share_embeddings_and_output_weights: + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). + self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.lm_head.weight.data.fill_(0) + self.lm_head.weight.shared = True + self.lm_head.weight.shared_embedding = True + + if torch.distributed.is_initialized(): + if parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.cuda() + torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) + + def shared_embedding_or_output_weight(self) -> torch.Tensor: + if self.pre_process: + return self.model.embed_tokens.weight + elif self.post_process: + return self.lm_head.weight + return None + def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = {self.config.vocab_size}') # [4, 32, 4096] - logits = self.lm_head(hidden_states)[0] + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits = self.lm_head(hidden_states, weight=output_weight)[0] # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) return logits @@ -600,7 +655,6 @@ def forward( if self.post_process: hidden_states = outputs - # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) logits = self._forward_head(hidden_states) logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py index ae8588f8..4a8a5933 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -83,6 +83,18 @@ def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> weight_loader(param, loaded_weight) +def qwen2_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: params_mapping = [ # 
(megatron core gpt model name, vllm model name) @@ -283,7 +295,7 @@ def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) - "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron "LLaMAForCausalLM": llama_megatron_weight_loader, "MistralForCausalLM": mistral_megatron_weight_loader, - 'Qwen2ForCausalLM': llama_megatron_weight_loader, + 'Qwen2ForCausalLM': qwen2_megatron_weight_loader, } diff --git a/verl/utils/model.py b/verl/utils/model.py index f319e400..9ad30ba0 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -248,12 +248,21 @@ def normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_layers) return normalized_name_to_param -def get_parallel_model_from_config(config, megatron_config, pre_process=None, post_process=None, value=False): +def get_parallel_model_from_config(config, + megatron_config, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False): from megatron.core import ModelParallelConfig assert isinstance(megatron_config, ModelParallelConfig) model_class = _get_parallel_model_architecture_from_config(config, value) - model = model_class(config, megatron_config, pre_process=pre_process, post_process=post_process) + model = model_class(config, + megatron_config, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights) return model diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 348c1576..ce0924e2 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -30,7 +30,9 @@ from verl.utils.megatron.optimizer_config import OptimizerConfig from megatron.core import parallel_state as mpu from megatron.core import ModelParallelConfig +from megatron.core.utils import get_model_config from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.distributed import finalize_model_grads # from megatron.core.optimizer import DistributedOptimizer from omegaconf import OmegaConf @@ -127,6 +129,9 @@ def __init__(self, config, model_config, megatron_config: ModelParallelConfig, a 'reduce_grads_use_alltoall': False }) + config = get_model_config(self.actor_module[0]) + config.finalize_model_grads_func = finalize_model_grads + def _validate_config(self, config) -> None: """Validate config options not implemented for Megatron backend""" assert config.get('ulysses_sequence_parallel_size', 1) == 1 diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 894d64cf..6aded997 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -165,11 +165,14 @@ def megatron_actor_model_provider(pre_process, post_process): vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() # this will be set inside get_model # this_megatron_config = copy.deepcopy(megatron_config) # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank - parallel_model = get_parallel_model_from_config(config=actor_model_config, - megatron_config=megatron_config, - pre_process=pre_process, - post_process=post_process, - value=False) + share_embeddings_and_output_weights = getattr(actor_model_config, "tie_word_embeddings", False) + parallel_model = get_parallel_model_from_config( + config=actor_model_config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, + 
share_embeddings_and_output_weights=share_embeddings_and_output_weights, + value=False) parallel_model.cuda() return parallel_model @@ -509,6 +512,7 @@ def megatron_critic_model_provider(pre_process, post_process): megatron_config=megatron_config, pre_process=pre_process, post_process=post_process, + share_embeddings_and_output_weights=False, value=True) parallel_model.cuda() return parallel_model @@ -673,6 +677,7 @@ def megatron_rm_model_provider(pre_process, post_process): megatron_config=megatron_config, pre_process=pre_process, post_process=post_process, + share_embeddings_and_output_weights=False, value=True) parallel_model.cuda() return parallel_model
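With the providers above in place, the series registers Qwen2 in three places: the Megatron model registry, the Megatron weight-loader registry, and the vLLM megatron weight loaders. A short sketch of how a worker would resolve the new loader entry for the checkpoint used in the example script (illustrative wiring, not verl's exact call site):

    from transformers import AutoConfig

    from verl.models.weight_loader_registry import get_weight_loader

    hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
    arch = hf_config.architectures[0]  # "Qwen2ForCausalLM"

    # Resolves to load_state_dict_to_megatron_qwen2 after this series.
    weight_loader = get_weight_loader(arch)

    # The returned loader copies a merged HuggingFace state_dict into the
    # sharded Megatron modules built by the providers above.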