 from transformers import PretrainedConfig, PreTrainedModel

 from tensorrt_llm._torch import model_config as model_config_lib
+from tensorrt_llm._torch.attention_backend import AttentionMetadata
 from tensorrt_llm._torch.attention_backend import \
     interface as attention_interface
 from tensorrt_llm._torch.attention_backend import utils as attention_utils
@@ -540,9 +541,8 @@ def __init__(
         act_layer = nn.GELU

         self.model_config = model_config
-        if self.model_config is not None:
-            self.config = model_config.pretrained_config
-            self.config.num_key_value_heads = num_heads
+        self.config = model_config.pretrained_config
+        self.config.num_key_value_heads = num_heads

         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -622,28 +622,31 @@ def __init__(
         self.patch_size = patch_size
         self.num_cls_tokens = num_cls_tokens
         self.num_registers = self.patch_generator.num_registers
-        if self.model_config is not None:
-            self.metadata_cls = attention_utils.get_attention_backend(
-                model_config.attn_backend).Metadata
-        else:
-            self.metadata_cls = None

-    def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int]):
+        self.metadata_cls = attention_utils.get_attention_backend(
+            model_config.attn_backend).Metadata
+        self.attn_metadata = self.metadata_cls(
+            max_num_requests=8192,  # TODO: Make this dynamic
+            max_num_tokens=model_config.max_num_tokens,
+            kv_cache_manager=None,
+        )
+
+    def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int],
+                              attn_metadata: AttentionMetadata):
         """
         To simplify the usage of the model, this function aims to fill the metadata for Attention
         Call this function before forward pass
         """
+        prompt_lens = seq_lengths
+        seq_lens = torch.tensor(seq_lengths, dtype=torch.int, pin_memory=True)
         request_ids = list(range(1, batch_size + 1))
-        attn_metadata = self.metadata_cls(
-            seq_lens=torch.tensor(seq_lengths, dtype=torch.int),
-            num_contexts=batch_size,
-            max_num_requests=batch_size,
-            max_num_tokens=sum(seq_lengths),
-            kv_cache_manager=None,
-            request_ids=request_ids,
-            prompt_lens=seq_lengths,
-        )
-        attn_metadata.max_seq_len = max(seq_lengths)
+
+        attn_metadata.seq_lens = seq_lens
+        attn_metadata.num_contexts = batch_size
+        attn_metadata.request_ids = request_ids
+        attn_metadata.prompt_lens = prompt_lens
+        attn_metadata.max_seq_len = seq_lens.max().item()
+
         attn_metadata.prepare()
         return attn_metadata

@@ -652,13 +655,11 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_generator(x)

         batch_size, seq_len, hidden_size = x.shape
-        if self.model_config is not None:
-            seq_lengths = [seq_len] * batch_size
-            attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths)
-            # Need flatten batch/seq_len for trtllm attention.
-            x = x.reshape(batch_size * seq_len, hidden_size)
-        else:
-            attn_metadata = None
+        seq_lengths = [seq_len] * batch_size
+        attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths,
+                                                   self.attn_metadata)
+        # Need flatten batch/seq_len for trtllm attention.
+        x = x.reshape(batch_size * seq_len, hidden_size)
         for block in self.blocks:
             x = block(x, attn_metadata=attn_metadata)
         x = x.reshape(batch_size, seq_len, hidden_size)
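For readers following the new calling convention: the attention metadata is now allocated once in `__init__`, and only its per-batch fields are refilled before each forward pass. Below is a minimal usage sketch of that flow; it is not part of the diff, the `vit` instance and the input shape are illustrative assumptions, and only `prepare_attn_metadata`, `attn_metadata`, and `blocks` are taken from the code above.

```python
import torch

# `vit` is a hypothetical instance of the vision transformer patched above,
# built with a TRT-LLM model_config; the shapes here are placeholders.
batch_size, seq_len, hidden_size = 2, 1024, 1280
x = torch.randn(batch_size, seq_len, hidden_size)

# Refill the per-batch fields of the metadata object created in __init__.
seq_lengths = [seq_len] * batch_size
attn_metadata = vit.prepare_attn_metadata(batch_size, seq_lengths,
                                          vit.attn_metadata)

# The TRT-LLM attention backend consumes a flattened (batch * seq_len, hidden) layout.
x = x.reshape(batch_size * seq_len, hidden_size)
for block in vit.blocks:
    x = block(x, attn_metadata=attn_metadata)
x = x.reshape(batch_size, seq_len, hidden_size)
```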