From 118c349219de82bca0bc5c7d51d0c70854a2474d Mon Sep 17 00:00:00 2001
From: Qinghao Hu
Date: Fri, 16 Aug 2024 03:02:03 +0800
Subject: [PATCH] Add ZigzagRing Support (#157)

---
 llava/data/dataset.py                         | 50 ++++++++++-----
 llava/model/llava_arch.py                     | 61 ++++++++++++++++---
 llava/train/args.py                           |  4 +-
 llava/train/sequence_parallel/globals.py      | 16 +++--
 llava/train/sequence_parallel/hybrid_attn.py  | 19 +-----
 llava/train/sequence_parallel/input_utils.py  | 61 +++++++------
 llava/train/sequence_parallel/monkey_patch.py | 40 ++++++++----
 .../ring/zigzag_ring_flash_attn_varlen.py     |  1 +
 llava/train/train.py                          |  2 +-
 9 files changed, 161 insertions(+), 93 deletions(-)

diff --git a/llava/data/dataset.py b/llava/data/dataset.py
index 3a7b32f9..aa26081c 100755
--- a/llava/data/dataset.py
+++ b/llava/data/dataset.py
@@ -2590,6 +2590,7 @@ class DataCollatorForSupervisedDatasetSeqParallel:
     sp_degree: int
     sp_rank: int
     ring_degree: int
+    ring_type: str
 
     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         input_ids, labels, images = [], [], []
@@ -2689,20 +2690,39 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
             # Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree`
             if self.ring_degree > 1:
                 RING_PAD_TOKEN_INDEX = 2
-                if num_incoming_tokens % self.sp_degree != 0:
-                    pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
-                    num_incoming_tokens += pad_len
-                    # pad `input_ids`
-                    pad_tensor = torch.full(
-                        (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
-                    )
-                    sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
-
-                    # pad `label`
-                    pad_label_tensor = torch.full(
-                        (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
-                    )
-                    sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+                if self.ring_type == "ring_varlen":
+                    if num_incoming_tokens % self.sp_degree != 0:
+                        pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
+                        num_incoming_tokens += pad_len
+                        # pad `input_ids`
+                        pad_tensor = torch.full(
+                            (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+                        )
+                        sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+                        # pad `label`
+                        pad_label_tensor = torch.full(
+                            (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+                        )
+                        sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+                elif self.ring_type == "zigzag_ring_varlen":
+                    self.zigzag_sp_degree = self.sp_degree * 2
+                    if num_incoming_tokens % self.zigzag_sp_degree != 0:
+                        pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree
+                        num_incoming_tokens += pad_len
+                        # pad `input_ids`
+                        pad_tensor = torch.full(
+                            (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+                        )
+                        sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+                        # pad `label`
+                        pad_label_tensor = torch.full(
+                            (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+                        )
+                        sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+                else:
+                    raise ValueError(f"Invalid ring_type: {self.ring_type}")
 
             if num_incoming_tokens > max_seq_length:
                 print(
@@ -2855,6 +2875,7 @@ def make_supervised_data_module(
         sp_degree = training_args.seq_parallel_size
         sp_rank = PROCESS_GROUP_MANAGER.sp_rank
         ring_degree = PROCESS_GROUP_MANAGER.ring_degree
+        ring_type = PROCESS_GROUP_MANAGER.ring_type
        data_collator = DataCollatorForSupervisedDatasetSeqParallel(
             tokenizer=tokenizer,
             data_args=data_args,
@@ -2862,6 +2883,7 @@
             sp_degree=sp_degree,
             sp_rank=sp_rank,
             ring_degree=ring_degree,
+            ring_type=ring_type,
         )
 
     return dict(
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
index 74211dc6..1a7b7b47 100755
--- a/llava/model/llava_arch.py
+++ b/llava/model/llava_arch.py
@@ -538,6 +538,10 @@ def repack_multimodal_data(
         sp_rank = PROCESS_GROUP_MANAGER.sp_rank
         sp_group = PROCESS_GROUP_MANAGER.sp_pg
         ring_degree = PROCESS_GROUP_MANAGER.ring_degree
+        ring_rank = PROCESS_GROUP_MANAGER.ring_rank
+        ring_type = PROCESS_GROUP_MANAGER.ring_type
+        ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
+        ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
 
         bs, shard_seqlen = position_ids.shape
         sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
@@ -642,13 +646,56 @@
                     dtype=global_inputs_embeds.dtype,
                     device=global_inputs_embeds.device,
                 )
-                for i in range(bs):
-                    start_idx = new_seqlen_per_rank[i] * sp_rank
-                    end_idx = start_idx + new_seqlen_per_rank[i]
-                    new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
-                    new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
-                    new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
-                    new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[i, start_idx:end_idx, :]
+
+                if ring_type == "ring_varlen":
+                    for i in range(bs):
+                        start_idx = new_seqlen_per_rank[i] * sp_rank
+                        end_idx = start_idx + new_seqlen_per_rank[i]
+                        new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
+                        new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
+                        new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
+                        new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
+                            i, start_idx:end_idx, :
+                        ]
+                elif ring_type == "zigzag_ring_varlen":
+                    chunk_size = total_effective_seqlen // (2 * sp_degree)
+                    for i in range(bs):
+                        # Zigzag pattern indices
+                        if sp_degree == ring_degree:
+                            forward_rank_idx = sp_rank
+                            backward_rank_idx = 2 * sp_degree - sp_rank - 1
+                        else:
+                            ulysses_offset = ulysses_rank * ring_degree * 2
+                            forward_rank_idx = ring_rank + ulysses_offset
+                            backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
+
+                        # Calculate start and end indices for the forward and backward zigzag
+                        start_idx_fwd = forward_rank_idx * chunk_size[i]
+                        end_idx_fwd = start_idx_fwd + chunk_size[i]
+
+                        start_idx_bwd = backward_rank_idx * chunk_size[i]
+                        end_idx_bwd = start_idx_bwd + chunk_size[i]
+
+                        # Fill new tensors with zigzag data
+                        new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
+                        new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
+                            i, start_idx_bwd:end_idx_bwd
+                        ]
+
+                        new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
+                        new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
+                            i, start_idx_bwd:end_idx_bwd
+                        ]
+
+                        new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
+                        new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
+
+                        new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
+                        new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
+                            i, start_idx_bwd:end_idx_bwd, :
+                        ]
+                else:
+                    raise ValueError(f"Invalid ring_type: {ring_type}")
             else:
                 global_seq_len = global_attention_mask.shape[-1]
                 seq_len_sharded = global_seq_len // sp_degree
diff --git a/llava/train/args.py b/llava/train/args.py
index 8b888da8..e604b6e3 100755
--- a/llava/train/args.py
+++ b/llava/train/args.py
@@ -111,7 +111,9 @@ class TrainingArguments(transformers.TrainingArguments):
     )
     seq_parallel_ring_type: str = field(
         default="ring_varlen",
-        metadata={"help": "Ring Attention implementation."},
+        metadata={
+            "help": "Ring Attention implementation. Support ['ring_varlen', 'zigzag_ring_varlen'] in 2D attention. Only works when `seq_parallel_ring_size` > 1."
+        },
     )
     debug_e2e: bool = field(
         default=False,
diff --git a/llava/train/sequence_parallel/globals.py b/llava/train/sequence_parallel/globals.py
index 738c01b4..44ccbbf0 100755
--- a/llava/train/sequence_parallel/globals.py
+++ b/llava/train/sequence_parallel/globals.py
@@ -39,10 +39,11 @@ class ProcessGroupManager(Singleton):
     sp_degree = sp_ring_degree x sp_ulysses_degree
     """
 
-    def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low):
+    def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
         if not hasattr(self, "__initialized"):
             super().__init__()
             self.ulysses_degree = ulysses_degree
+            self.ring_type = ring_type
             self.ulysses_seq_len = None
 
             self.ring_degree = ring_degree
@@ -148,7 +149,7 @@ def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low):
 PROCESS_GROUP_MANAGER = None
 
 
-def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True):
+def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
     """
     Set the process group manager for sequence parallelism.
     sp_degree = sp_ring_degree x sp_ulysses_degree
@@ -185,7 +186,9 @@ def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True):
 
     # Init the process group manager
     global PROCESS_GROUP_MANAGER
-    PROCESS_GROUP_MANAGER = ProcessGroupManager(sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low)
+    PROCESS_GROUP_MANAGER = ProcessGroupManager(
+        sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
+    )
 
 
 def get_pg_manager():
@@ -243,10 +246,15 @@ def get_ring_sp_rank():
 
 
 def get_ring_sp_pg():
-    """Get the Ulysses sequence parallel process group."""
+    """Get the RingAttn sequence parallel process group."""
     return PROCESS_GROUP_MANAGER.ring_pg
 
 
+def get_ring_type():
+    """Get the RingAttn implementation type."""
+    return PROCESS_GROUP_MANAGER.ring_type
+
+
 def get_data_parallel_size():
     """Get the size of the data parallel group."""
     return PROCESS_GROUP_MANAGER.dp_degree
diff --git a/llava/train/sequence_parallel/hybrid_attn.py b/llava/train/sequence_parallel/hybrid_attn.py
index e7ebb870..b89c7de1 100755
--- a/llava/train/sequence_parallel/hybrid_attn.py
+++ b/llava/train/sequence_parallel/hybrid_attn.py
@@ -23,19 +23,7 @@
 from torch.nn import Module
 
 from .all_to_all import SeqAllToAll4D, SeqAllToAll5D
-from .globals import (
-    get_pg_manager,
-    get_ring_sp_pg,
-    get_ring_sp_rank,
-    get_ring_sp_size,
-    get_sequence_parallel_pg,
-    get_sequence_parallel_rank,
-    get_sequence_parallel_size,
-    get_ulysses_seq_len,
-    get_ulysses_sp_pg,
-    get_ulysses_sp_rank,
-    get_ulysses_sp_size,
-)
+from .globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
 from .ring import (
     ring_flash_attn_func,
     ring_flash_attn_qkvpacked_func,
@@ -54,7 +42,7 @@
     "zigzag": zigzag_ring_flash_attn_func,
     "strip": stripe_flash_attn_func,
     "ring_varlen": ring_flash_attn_varlen_func,
-    "zigzag_varlen": zigzag_ring_flash_attn_varlen_func,
+    "zigzag_ring_varlen": zigzag_ring_flash_attn_varlen_func,
 }
 
 RING_IMPL_QKVPACKED_DICT = {
@@ -80,7 +68,6 @@ def __init__(
         self,
         scatter_idx: int = 2,
         gather_idx: int = 1,
-        ring_impl_type: str = "ring_varlen",
         use_pack_qkv: bool = False,
         attention_warper: Module = None,
     ) -> None:
@@ -96,7 +83,7 @@ def __init__(
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
         if attention_warper is None:
-            self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type]
+            self.ring_attn_fn = RING_IMPL_DICT[get_ring_type()]
         else:
             self.ring_attn_fn = attention_warper
 
diff --git a/llava/train/sequence_parallel/input_utils.py b/llava/train/sequence_parallel/input_utils.py
index 5baa58b6..294e837d 100755
--- a/llava/train/sequence_parallel/input_utils.py
+++ b/llava/train/sequence_parallel/input_utils.py
@@ -17,11 +17,31 @@
 import torch
 
 
-def extract_local_from_list(vaule_list, sp_rank, sp_size):
-    quotient, remainder = divmod(len(vaule_list), sp_size)
+def extract_local_zigzag(value, rank, world_size, device, dim=1):
+    value_chunks = value.chunk(2 * world_size, dim=dim)
+    local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
+    return local_value.to(device)
+
+
+def extract_local_from_list(value_list, sp_rank, sp_size):
+    quotient, remainder = divmod(len(value_list), sp_size)
     start_idx = sp_rank * quotient + min(sp_rank, remainder)
     end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
-    return vaule_list[start_idx:end_idx]
+    return value_list[start_idx:end_idx]
+
+
+def extract_local_from_list_zigzag(value_list, sp_rank, sp_size):
+    chunk_size, remainder = divmod(len(value_list), (2 * sp_size))
+    value_chunks = []
+    start_idx = 0
+    for i in range(2 * sp_size):
+        extra = 1 if i < remainder else 0
+        end_idx = start_idx + chunk_size + extra
+        value_chunks.append(value_list[start_idx:end_idx])
+        start_idx = end_idx
+
+    local_value = value_chunks[sp_rank] + value_chunks[2 * sp_size - sp_rank - 1]
+    return local_value
 
 
 def extract_local_input_ids(input_ids, image_positions, sp_rank, sp_size, bos_token_id=1, image_token_len=3):
@@ -58,38 +78,3 @@ def extract_local_position_ids(input_ids, image_positions, image_ids, sp_rank, s
         return input_ids[start_position_idx:]
     else:
         return input_ids[start_position_idx:end_position_idx]
-
-
-def extract_local(value, rank, world_size, dim=1):
-    value_chunks = value.chunk(2 * world_size, dim=dim)
-    local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
-    return local_value
-
-
-def prepare_hybrid_attn_inputs(input_ids, position_ids, target_ids, rank, world_size, device):
-    local_input_ids = extract_local(
-        input_ids,
-        rank,
-        world_size,
-        device,
-    )
-    local_position_ids = extract_local(
-        position_ids,
-        rank,
-        world_size,
-        device,
-    )
-    if target_ids is not None:
-        local_target_ids = extract_local(
-            target_ids,
-            rank,
-            world_size,
-            device,
-        )
-    else:
-        local_target_ids = None
-    return {
-        "local_input_ids": local_input_ids,
-        "local_position_ids": local_position_ids,
-        "local_target_ids": local_target_ids,
-    }
diff --git a/llava/train/sequence_parallel/monkey_patch.py b/llava/train/sequence_parallel/monkey_patch.py
index 44ee3771..240a8f11 100755
--- a/llava/train/sequence_parallel/monkey_patch.py
+++ b/llava/train/sequence_parallel/monkey_patch.py
@@ -27,7 +27,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, _get_unpad_data, apply_rotary_pos_emb
 
-from llava.train.sequence_parallel.globals import get_pg_manager, get_ring_sp_pg, get_ulysses_sp_pg
+from llava.train.sequence_parallel.globals import get_pg_manager, get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
 
 from .hybrid_attn import HybridAttention
 from .ring import (
@@ -167,17 +167,33 @@ def hybrid_attn_varlen_func_helper(
     # print("rank", dist.get_rank(), "cu_seq_lens", cu_seq_lens)
     # exit()
 
-    attn_output_unpad = ring_flash_attn_varlen_func(
-        query_states,
-        key_states,
-        value_states,
-        cu_seq_lens,
-        max_seq_lens[0],
-        dropout_p=dropout_p,
-        softmax_scale=softmax_scale,
-        causal=self.is_causal,
-        group=group,
-    )
+    ring_type = get_ring_type()
+    if ring_type == "ring_varlen":
+        attn_output_unpad = ring_flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seq_lens,
+            max_seq_lens[0],
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=self.is_causal,
+            group=group,
+        )
+    elif ring_type == "zigzag_ring_varlen":
+        attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seq_lens,
+            max_seq_lens[0],
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=self.is_causal,
+            group=group,
+        )
+    else:
+        raise ValueError(f"Invalid ring_type: {ring_type}")
 
     # print(dist.get_rank(), "finish ring_flash_attn_varlen_func")
 
diff --git a/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py
index ed5c7000..f20d37cc 100755
--- a/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py
+++ b/llava/train/sequence_parallel/ring/zigzag_ring_flash_attn_varlen.py
@@ -113,6 +113,7 @@ def forward(q, k, v, causal):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=True and dropout_p > 0,
+            block_table=None,
         )
         return block_out, block_lse
 
diff --git a/llava/train/train.py b/llava/train/train.py
index 3ae22f22..66269e5f 100755
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -426,7 +426,7 @@ def train():
     sp_degree = training_args.seq_parallel_size
     ring_degree = training_args.seq_parallel_ring_size
     if sp_degree > 1:
-        set_pg_manager(sp_degree, ring_degree)
+        set_pg_manager(sp_degree, ring_degree, ring_type=training_args.seq_parallel_ring_type)
         print(f"Sequence parallelism is enabled, SP = {sp_degree}")
 
     resume_path, continue_training = get_checkpoint_path(training_args.output_dir)
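
Note: as a standalone illustration (not part of the patch) of the zigzag layout that `extract_local_zigzag` and the
`zigzag_ring_varlen` branches above rely on, the sketch below splits a padded sequence into `2 * world_size` chunks
and gives rank `r` the pair of chunks `r` and `2 * world_size - 1 - r`; pairing an early chunk with a late one is
what keeps the causal-attention work roughly balanced across the ring ranks. The helper name `zigzag_shard` is
illustrative only and mirrors `extract_local_zigzag` from the patch.

# Standalone sketch of the zigzag sharding used above; assumes the sequence
# length is already padded to a multiple of 2 * world_size (see the
# `zigzag_ring_varlen` padding branch in dataset.py).
import torch


def zigzag_shard(value: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    """Return the two chunks owned by `rank` under the zigzag layout."""
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim)


if __name__ == "__main__":
    world_size = 4
    seq = torch.arange(2 * world_size).unsqueeze(0)  # shape (1, 8): chunk i holds token i
    for rank in range(world_size):
        print(rank, zigzag_shard(seq, rank, world_size).tolist())
    # rank 0 -> [[0, 7]], rank 1 -> [[1, 6]], rank 2 -> [[2, 5]], rank 3 -> [[3, 4]]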