From 4e424201192cf1fa0146ff3581b24d93d9cd015f Mon Sep 17 00:00:00 2001
From: 1000960000 user
Date: Wed, 7 Aug 2024 12:50:25 +0000
Subject: [PATCH] extracted cum seq len computation to collator

---
 .../src/fms_acceleration_aadp/aadp_utils.py   | 24 ++++++-
 .../src/fms_acceleration_aadp/flash_attn.py   | 62 +++++++------------
 .../framework_plugin_padding_free.py          | 15 ++++-
 3 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
index 330bf5eb..08314e68 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 import warnings
 from transformers import DefaultDataCollator, default_data_collator
-
+import torch
 
 @dataclass
 class DataCollatorWithFlattening(DefaultDataCollator):
@@ -51,4 +51,24 @@ def __call__(self, features, return_tensors=None):
                 ret["labels"] += [-100] + feature["labels"][1:]
             else:
                 ret["labels"] += [-100] + feature["input_ids"][1:]
-        return default_data_collator([ret], return_tensors)
+
+        position_ids = torch.tensor(ret["position_ids"]).flatten()
+        indices_q = torch.arange(
+            position_ids.size(0), device=position_ids.device, dtype=torch.int32
+        )
+        cu_seq_lens = torch.cat(
+            (
+                indices_q[position_ids == 0],
+                torch.tensor(
+                    position_ids.size(), dtype=torch.int32
+                ),
+            )
+        )
+        max_length = position_ids.max() + 1
+
+        # return default_data_collator([ret], return_tensors)
+        return {
+            **default_data_collator([ret], return_tensors),
+            "cu_seq_lens": cu_seq_lens,
+            "max_length": max_length,
+        }
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
index 26e26d01..1ecb978c 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
@@ -32,36 +32,27 @@
     inspect.signature(flash_attn_func).parameters
 )
 
+# model id -> position_ids
+POSITION_IDS_CACHE = {}
+CU_SEQ_LENS_CACHE = {}
+MAX_SEQ_LENS_CACHE = {}
 
-def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):
-    query = query.view(-1, query.size(-2), query.size(-1))
-    key = key.view(-1, key.size(-2), key.size(-1))
-    value = value.view(-1, value.size(-2), value.size(-1))
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(
-        position_ids.size(0), device=position_ids.device, dtype=torch.int32
-    )
-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(
-                position_ids.size(), device=position_ids.device, dtype=torch.int32
-            ),
-        )
-    )
-    max_length = position_ids.max() + 1
-    return (
-        query,
-        key,
-        value,
-        indices_q,
-        (cu_seq_lens, cu_seq_lens),
-        (max_length, max_length),
-    )
+# This is used to patch the top-level model to accept cuseqlen
+# and maxseqlen as additional args that are cached for attention
+# computation
+def build_toplevel_model_forward(
+    model: torch.nn.Module,
+    model_id: str,
+):
+    # forward
+    old_forward = model.forward
 
+    def forward(self, *args, cu_seq_lens, max_length, **kwargs):
+        CU_SEQ_LENS_CACHE[model_id] = (cu_seq_lens, cu_seq_lens)
+        MAX_SEQ_LENS_CACHE[model_id] = (max_length, max_length)
+        return old_forward(*args, **kwargs)
 
-# model id -> position_ids
-POSITION_IDS_CACHE = {}
+    return forward
 
 
 # - needed to store position ids when first come into model
@@ -123,6 +114,8 @@ def _flash_attention_forward_with_posids(
 ):
     # get the position ids out here
     position_ids = POSITION_IDS_CACHE[model_id]
+    cu_seqlens_q, cu_seqlens_k = CU_SEQ_LENS_CACHE[model_id]
+    max_seqlen_in_batch_q, max_seqlen_in_batch_k = MAX_SEQ_LENS_CACHE[model_id]
 
     if not use_top_left_mask:
         causal = is_causal
@@ -161,19 +154,10 @@
         assert attention_mask is None, "should not be using attention mask"
         assert position_ids is not None, "should be expecting position ids"
         batch_size = query_states.size(0)
-        (
-            query_states,
-            key_states,
-            value_states,
-            _,
-            cu_seq_lens,
-            max_seq_lens,
-        ) = prepare_fa2_from_position_ids(
-            query_states, key_states, value_states, position_ids, query_length
-        )
 
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+        query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1))
+        key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1))
+        value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1))
 
         attn_output = flash_attn_varlen_func(
             query_states,
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index 4680313e..70d6fcbf 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -93,7 +93,7 @@ def _is_backbone(module: torch.nn.Module):
         # - patch backbone
         model_type = model.config.model_type
         # pylint: disable=import-outside-toplevel
-        from .flash_attn import build_backbone_forward
+        from .flash_attn import build_backbone_forward, build_toplevel_model_forward
 
         ModelPatcher.register(
             ModelPatcherRule(
@@ -106,6 +106,19 @@ def _is_backbone(module: torch.nn.Module):
             ),
         )
 
+        # Need to patch the top-level model to accept and cache additional
+        # kwargs, cu_seq_lens and max_len from data collator
+        ModelPatcher.register(
+            ModelPatcherRule(
+                rule_id=f"{model_type}-cumseqlen-cache",
+                trigger=ModelPatcherTrigger(check=model.__class__),
+                forward_builder=partial(
+                    build_toplevel_model_forward,
+                    model_id=id(model),
+                ),
+            ),
+        )
+
         # Next, the flash attention function needs to be patched
        # how it is patched depends on the transformers version
         try:
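
Note (not part of the patch): a minimal, self-contained sketch of the hand-off this commit introduces, using a made-up ToyModel and example position_ids rather than the real plugin classes. The collator derives cu_seq_lens / max_length from the flattened position_ids, and a wrapped top-level forward caches them under the model's id so the flash-attention path can read them back. Binding through the plugin's ModelPatcher is skipped here; the wrapper is attached directly as an instance attribute, so it takes no self argument.

import torch

# --- collator side: cumulative sequence lengths from packed position_ids ---
position_ids = torch.tensor([0, 1, 2, 0, 1])  # two packed sequences, lengths 3 and 2
indices_q = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],                           # start index of each sequence
        torch.tensor(position_ids.size(), dtype=torch.int32),   # total token count
    )
)
max_length = position_ids.max() + 1
print(cu_seq_lens.tolist(), int(max_length))  # [0, 3, 5] 3

# --- model side: cache the collator-provided values per model id ---
CU_SEQ_LENS_CACHE, MAX_SEQ_LENS_CACHE = {}, {}

def build_toplevel_model_forward(model, model_id):
    old_forward = model.forward

    def forward(*args, cu_seq_lens, max_length, **kwargs):
        # stash the extra kwargs, then call the original forward without them
        CU_SEQ_LENS_CACHE[model_id] = (cu_seq_lens, cu_seq_lens)
        MAX_SEQ_LENS_CACHE[model_id] = (max_length, max_length)
        return old_forward(*args, **kwargs)

    return forward

class ToyModel(torch.nn.Module):  # hypothetical stand-in for the real HF model
    def forward(self, input_ids=None, **kwargs):
        return input_ids

model = ToyModel()
model.forward = build_toplevel_model_forward(model, id(model))
model.forward(input_ids=torch.arange(5), cu_seq_lens=cu_seq_lens, max_length=max_length)
assert CU_SEQ_LENS_CACHE[id(model)][0].tolist() == [0, 3, 5]

Computing cu_seq_lens once per batch in the collator, as this patch does, means the flash-attention wrapper no longer has to rebuild it from position_ids on every attention call; it only reshapes the query/key/value states and reads the cached values.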