extracted cum seq len computation to collator
achew010 committed Aug 8, 2024
1 parent 09afb42 commit 4e42420
Showing 3 changed files with 59 additions and 42 deletions.
@@ -15,7 +15,7 @@
from dataclasses import dataclass
import warnings
from transformers import DefaultDataCollator, default_data_collator

import torch

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
@@ -51,4 +51,24 @@ def __call__(self, features, return_tensors=None):
ret["labels"] += [-100] + feature["labels"][1:]
else:
ret["labels"] += [-100] + feature["input_ids"][1:]
return default_data_collator([ret], return_tensors)

position_ids = torch.tensor(ret["position_ids"]).flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), dtype=torch.int32
),
)
)
max_length = position_ids.max() + 1

# return default_data_collator([ret], return_tensors)
return {
**default_data_collator([ret], return_tensors),
"cu_seq_lens": cu_seq_lens,
"max_length": max_length,
}
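The cumulative-sequence-length computation added to the collator above follows the pattern flash-attention's variable-length interface expects: a boundary wherever position_ids restarts at 0, closed off by the total token count. A minimal standalone sketch of the same arithmetic on toy values (not taken from this commit):

```python
import torch

# Three packed sequences of lengths 3, 2 and 4, flattened into one row.
position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

indices = torch.arange(position_ids.size(0), dtype=torch.int32)
# A new sequence starts wherever position_ids resets to 0; appending the
# total token count yields the cumulative boundaries [0, 3, 5, 9].
cu_seq_lens = torch.cat(
    (indices[position_ids == 0],
     torch.tensor(position_ids.size(), dtype=torch.int32))
)
max_length = int(position_ids.max()) + 1  # longest packed sequence, here 4

print(cu_seq_lens.tolist(), max_length)  # [0, 3, 5, 9] 4
```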
@@ -32,36 +32,27 @@
inspect.signature(flash_attn_func).parameters
)

# model id -> position_ids
POSITION_IDS_CACHE = {}
CU_SEQ_LENS_CACHE = {}
MAX_SEQ_LENS_CACHE = {}

def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
max_length = position_ids.max() + 1
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
# Used to patch the top-level model so it accepts cu_seq_lens
# and max_length as additional kwargs, which are cached for the
# attention computation
def build_toplevel_model_forward(
model: torch.nn.Module,
model_id: str,
):
# keep a handle to the original forward
old_forward = model.forward

def forward(self, *args, cu_seq_lens, max_length, **kwargs):
CU_SEQ_LENS_CACHE[model_id] = (cu_seq_lens, cu_seq_lens)
MAX_SEQ_LENS_CACHE[model_id] = (max_length, max_length)
return old_forward(*args, **kwargs)

# model id -> position_ids
POSITION_IDS_CACHE = {}
return forward


# - needed to store position ids when they first come into the model
@@ -123,6 +114,8 @@ def _flash_attention_forward_with_posids(
):
# get the position ids out here
position_ids = POSITION_IDS_CACHE[model_id]
cu_seqlens_q, cu_seqlens_k = CU_SEQ_LENS_CACHE[model_id]
max_seqlen_in_batch_q, max_seqlen_in_batch_k = MAX_SEQ_LENS_CACHE[model_id]

if not use_top_left_mask:
causal = is_causal
@@ -161,19 +154,10 @@ def _flash_attention_forward_with_posids(
assert attention_mask is None, "should not be using attention mask"
assert position_ids is not None, "should be expecting position ids"
batch_size = query_states.size(0)
(
query_states,
key_states,
value_states,
_,
cu_seq_lens,
max_seq_lens,
) = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1))

attn_output = flash_attn_varlen_func(
query_states,
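The flash_attn_varlen_func call above is truncated in the diff. As a rough sketch of how the cached values are consumed, assuming only the public flash-attn API (the helper name here is illustrative, not part of the commit), the reshaped q/k/v and the cached cumulative lengths feed the variable-length kernel roughly like this:

```python
# Illustrative helper: flatten q/k/v to (total_tokens, heads, head_dim),
# mirroring the reshapes above, and pass the cached cumulative sequence
# lengths and max lengths to flash-attn's variable-length kernel.
from flash_attn import flash_attn_varlen_func

def varlen_attention(query_states, key_states, value_states,
                     cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, causal=True):
    q = query_states.view(-1, query_states.size(-2), query_states.size(-1))
    k = key_states.view(-1, key_states.size(-2), key_states.size(-1))
    v = value_states.view(-1, value_states.size(-2), value_states.size(-1))
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        causal=causal,
    )
```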
@@ -93,7 +93,7 @@ def _is_backbone(module: torch.nn.Module):
# - patch backbone
model_type = model.config.model_type
# pylint: disable=import-outside-toplevel
from .flash_attn import build_backbone_forward
from .flash_attn import build_backbone_forward, build_toplevel_model_forward

ModelPatcher.register(
ModelPatcherRule(
@@ -106,6 +106,19 @@ def _is_backbone(module: torch.nn.Module):
),
)

# Need to patch the top-level model to accept and cache the additional
# kwargs, cu_seq_lens and max_length, from the data collator
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{model_type}-cumseqlen-cache",
trigger=ModelPatcherTrigger(check=model.__class__),
forward_builder=partial(
build_toplevel_model_forward,
model_id=id(model),
),
),
)

# Next, the flash attention function needs to be patched
# how it is patched depends on the transformers version
try:
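Putting the three files together, the intended flow is: the collator emits the extra keys, the patched top-level forward caches them under id(model), and the patched flash-attention path reads the caches instead of recomputing from position_ids. A hypothetical end-to-end sketch (the model name is a placeholder, the ModelPatcher rules above are assumed to already be applied, and the import of the project-local DataCollatorWithFlattening is omitted because its module path is not shown in the diff):

```python
from transformers import AutoModelForCausalLM

# Placeholder model id; in practice this is whatever model the plugin patches.
model = AutoModelForCausalLM.from_pretrained("some/causal-lm")

collator = DataCollatorWithFlattening(return_tensors="pt")
batch = collator([
    {"input_ids": [1, 2, 3]},
    {"input_ids": [4, 5]},
])
# batch now carries input_ids / position_ids / labels plus the new
# "cu_seq_lens" and "max_length" entries computed by the collator.

# With the patches applied, the top-level forward pops cu_seq_lens and
# max_length into the caches keyed by id(model); the patched
# flash-attention path later reads them rather than re-deriving them
# from position_ids in every layer.
outputs = model(**batch)
```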