Environment

CPU architecture: x86_64
GPU name: NVIDIA A10
TensorRT branch: 9.0.0
TensorRT LLM: 0.1.3
Cuda: 12.1.66
Cudnn: 8.9.0
Container: registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1
NVIDIA driver version: 525.105.17
OS: Ubuntu 22.04.3 LTS x86_64
Kernel: 5.15.0-73-generic
Brief problem description

I implemented a WhisperDecoderAttention class that supports both self/cross attention and with/without kv_cache. The attention module passes its standalone unit test, but after assembling it into a WhisperDecoderLayer, the result is incorrect in the self-attention + kv_cache case. If an intermediate tensor is marked as an output via mark_output, however, the computation becomes correct. My guess is that adding mark_output breaks the original graph fusion.
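The workaround lives in the self-attention-with-cache branch of WhisperDecoderAttention.forward (full model code below); the snippet is extracted here for reference, with the comments added by me:

```python
# Self-attention with kv_cache: split the past cache into K/V, project the
# current step, then concatenate past and current along the sequence dim.
past_key_states, past_value_states = split(past_key_value, 1, dim=0)
curr_key_states = transpose_for_scores(self.k_proj(hidden_states))
curr_value_states = transpose_for_scores(self.v_proj(hidden_states))

# Workaround: marking this intermediate tensor as a network output makes the
# engine results match PyTorch; commenting it out reproduces the mismatch.
past_value_states.mark_output('hook0', str_dtype_to_trt('float32'))

key_states = concat([past_key_states, curr_key_states], dim=2)
value_states = concat([past_value_states, curr_value_states], dim=2)
```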
Reproduction code

Model code:
import enum import math from dataclasses import dataclass from typing import Optional import torch import numpy as np import tensorrt as trt from ..._common import default_net, precision from ..._utils import str_dtype_to_trt from ...functional import (Tensor, RaggedTensor, ACT2FN, unsqueeze, gelu, shape, gather, concat, view, permute, constant, split, matmul, softmax, cast, identity) from ...layers import Attention, LayerNorm, ColumnLinear, Conv2d from ...module import Module, ModuleList from ...parameter import Parameter from ...layers.linear import ColumnLinear, RowLinear def squeeze(input, axis): dims = input.ndim() input_shape = shape(input) out_shapes = [] for i in range(dims): if i == axis: continue out_shapes.append(gather(input_shape, 0, i)) out_shape = concat(out_shapes) input = view(input, out_shape) return input class WhisperEncoderLayer(Module): def __init__(self, d_model=512, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048): super().__init__() self.embed_dim = d_model self.self_attn = Attention(self.embed_dim, encoder_attention_heads, 1) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.activation_fn = ACT2FN[activation_function] self.fc1 = ColumnLinear(self.embed_dim, encoder_ffn_dim) self.fc2 = ColumnLinear(encoder_ffn_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) def forward(self, hidden_states: RaggedTensor): input_lengths = hidden_states.row_lengths max_input_length = hidden_states.max_row_length hidden_states = hidden_states.data residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn(RaggedTensor.from_row_lengths(hidden_states, input_lengths, max_input_length)) hidden_states = residual + hidden_states.data residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states return hidden_states class WhisperEncoder(Module): def __init__(self, d_model=512, num_mel_bins=80, max_source_positions=1500, encoder_layers=6, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048): super().__init__() embed_dim = d_model # 原本应该是Conv1d的,但trtllm还没实现,先用Conv2d替换 self.conv1 = Conv2d(num_mel_bins, embed_dim, kernel_size=(1,3), padding=(0,1)) self.conv2 = Conv2d(embed_dim, embed_dim, kernel_size=(1,3), stride=(1,2), padding=(0,1)) self.embed_positions_weight = torch.zeros(1,max_source_positions,embed_dim).numpy() self.layers = ModuleList([WhisperEncoderLayer(d_model=d_model, encoder_attention_heads=encoder_attention_heads, activation_function=activation_function, encoder_ffn_dim=encoder_ffn_dim) for _ in range(encoder_layers)]) self.layer_norm = LayerNorm(embed_dim) def forward(self, input_features: RaggedTensor): input_lengths = input_features.row_lengths max_input_length = input_features.max_row_length input_features = input_features.data input_features = unsqueeze(input_features,2) inputs_embeds = gelu(self.conv1(input_features)) inputs_embeds = gelu(self.conv2(inputs_embeds)) inputs_embeds = squeeze(inputs_embeds,2) inputs_embeds = permute(inputs_embeds,[0,2,1]) hidden_states = inputs_embeds + constant(self.embed_positions_weight) for layer in self.layers: hidden_states = layer(RaggedTensor.from_row_lengths(hidden_states, input_lengths, max_input_length)) hidden_states = self.layer_norm(hidden_states) return hidden_states # class SimpleConvTRTLLMNet(Module): # def __init__(self): # 
super().__init__() # self.encoder = WhisperEncoder() # def forward(self, input_features: RaggedTensor): # hidden_states = self.encoder(input_features) # hidden_states.mark_output('output', str_dtype_to_trt('float32')) # return hidden_states # def prepare_inputs(self): # input_features_data = Tensor(name='data', # dtype=trt.float32, # shape=[1, 80, 3000]) # input_features_length = Tensor(name='length', # dtype=trt.float32, # shape=[1]) # input_features = RaggedTensor.from_row_lengths(input_features_data, input_features_length) # return (input_features) class AttentionMaskType(enum.Enum): padding = 0 causal = 1 bidirectional = 2 class PositionEmbeddingType(enum.Enum): learned_absolute = enum.auto() rope = enum.auto() alibi = enum.auto() @dataclass class InflightBatchingParam: host_beam_widths: Tensor cache_indir_pointers: Tensor host_req_cache_max_seq_lengths: Tensor host_input_lengths: Tensor past_key_value_pointers: Tensor max_input_length: int max_beam_width: int kv_orig_quant_scale: Optional[Tensor] = None kv_quant_orig_scale: Optional[Tensor] = None use_int8_kv_cache: bool = False def __post_init__(self): assert self.max_input_length > 0, f"max_input_length must be positive, got {self.max_input_length}" assert self.max_beam_width > 0, f"max_beam_width must be positive, got {self.max_beam_width}" class WhisperDecoderAttention(Module): def __init__(self, hidden_size, num_attention_heads, max_position_embeddings=0, num_layers=1, apply_query_key_layer_scaling=False, bias=True, dtype=None, position_embedding_type=PositionEmbeddingType.learned_absolute, neox_rotary_style=False, use_int8_kv_cache=False, rotary_embedding_percentage=1.0, tp_group=None, tp_size=1, multi_block_mode=False, multi_query_mode=False): super().__init__() self.attention_head_size = hidden_size // num_attention_heads self.num_attention_heads = num_attention_heads // tp_size self.num_attention_kv_heads = 1 if multi_query_mode else self.num_attention_heads self.hidden_size = hidden_size // tp_size self.max_position_embeddings = max_position_embeddings self.num_layers = num_layers self.apply_query_key_layer_scaling = apply_query_key_layer_scaling self.norm_factor = math.sqrt(self.attention_head_size) self.q_scaling = 1 if self.apply_query_key_layer_scaling: self.norm_factor *= self.num_layers self.q_scaling *= self.num_layers self.position_embedding_type = position_embedding_type self.multi_block_mode = multi_block_mode self.multi_query_mode = multi_query_mode self.rotary_embedding_dim = 0 self.neox_rotary_style = neox_rotary_style if self.position_embedding_type == PositionEmbeddingType.rope: self.rotary_embedding_dim = int(self.attention_head_size * rotary_embedding_percentage) # TODO: Once we add RotaryEmbedding outside GPTAttention plugin, # we need to set it up here self.dtype = dtype self.use_int8_kv_cache = use_int8_kv_cache if self.use_int8_kv_cache: self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32') self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32') else: self.register_parameter('kv_orig_quant_scale', None) self.register_parameter('kv_quant_orig_scale', None) # Note: in multi_query_mode, only query heads are split between multiple GPUs, # while key/value head are not split as there is only one head per key/value. # The output feature size is therefore (h/tp + 2) * d, where h is num_heads, # d is head_size, and tp is tensor_parallel_size. 
# In ColumnLinear op, the output dim is calculated by (h + 2*tp) * d / tp, # which matches the desired output size (h/tp + 2) * d after splitting self.q_proj = ColumnLinear(hidden_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) self.k_proj = ColumnLinear(hidden_size, hidden_size, bias=False, dtype=dtype, tp_group=tp_group, tp_size=tp_size) self.v_proj = ColumnLinear(hidden_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) self.dense = RowLinear(hidden_size, hidden_size, bias=bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) def forward(self, hidden_states: RaggedTensor, key_value_states: Optional[RaggedTensor] = None, past_key_value: Optional[Tensor] = None ): input_lengths = hidden_states.row_lengths max_input_length = hidden_states.max_row_length hidden_states = hidden_states.data def transpose_for_scores(x): new_x_shape = concat([ shape(x, 0), shape(x, 1), self.num_attention_heads, self.attention_head_size ]) return x.view(new_x_shape).permute([0, 2, 1, 3]) query_states = transpose_for_scores(self.q_proj(hidden_states)) is_cross_attention = key_value_states is not None is_reuse = past_key_value is not None if is_cross_attention and is_reuse: dumpy_key_value_states = constant(np.zeros((512),dtype=np.float32)) key_states = self.k_proj(dumpy_key_value_states) value_states = self.v_proj(dumpy_key_value_states) key_states, value_states = split(past_key_value,1,dim=0) elif is_cross_attention: key_states = transpose_for_scores(self.k_proj(key_value_states)) value_states = transpose_for_scores(self.v_proj(key_value_states)) elif is_reuse: # curr_key_states = transpose_for_scores(self.k_proj(hidden_states)) # curr_value_states = transpose_for_scores(self.v_proj(hidden_states)) # past_key_states, past_value_states = split(past_key_value,1,dim=0) # key_states = concat([past_key_states, curr_key_states], dim=2) # value_states = concat([past_value_states, curr_value_states], dim=2) past_key_states, past_value_states = split(past_key_value,1,dim=0) curr_key_states = transpose_for_scores(self.k_proj(hidden_states)) curr_value_states = transpose_for_scores(self.v_proj(hidden_states)) past_value_states.mark_output('hook0', str_dtype_to_trt('float32')) key_states = concat([past_key_states, curr_key_states], dim=2) value_states = concat([past_value_states, curr_value_states], dim=2) else: key_states = transpose_for_scores(self.k_proj(hidden_states)) value_states = transpose_for_scores(self.v_proj(hidden_states)) query = query_states key = key_states value = value_states past_key_value = concat([key, value], dim=0) key = key.permute([0, 1, 3, 2]) with precision('float32'): attention_scores = matmul(cast(query, 'float32'), cast(key, 'float32')) attention_scores = attention_scores / self.norm_factor attention_probs = softmax(attention_scores, dim=-1) context = matmul(attention_probs, value).permute([0, 2, 1, 3]) context = context.view(concat([shape(context, 0), shape(context, 1), self.hidden_size])) context = self.dense(context) context = RaggedTensor.from_row_lengths(context, input_lengths, max_input_length) return context, past_key_value class WhisperDecoderLayer(Module): def __init__(self, d_model=512, decoder_attention_heads=8, activation_function='gelu', decoder_ffn_dim=2048): super().__init__() self.embed_dim = d_model self.self_attn = WhisperDecoderAttention(self.embed_dim,decoder_attention_heads) self.activation_fn = ACT2FN[activation_function] self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.encoder_attn = 
WhisperDecoderAttention(self.embed_dim,decoder_attention_heads) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = ColumnLinear(self.embed_dim, decoder_ffn_dim) self.fc2 = ColumnLinear(decoder_ffn_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) def forward(self, hidden_states: RaggedTensor, encoder_hidden_states: Optional[Tensor] = None, self_attn_past_key_value: Optional[Tensor] = None, cross_attn_past_key_value: Optional[Tensor] = None, ): input_lengths = hidden_states.row_lengths max_input_length = hidden_states.max_row_length hidden_states = hidden_states.data residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, present_key_value = self.self_attn( hidden_states=RaggedTensor.from_row_lengths(hidden_states, input_lengths, max_input_length), key_value_states=None, past_key_value=self_attn_past_key_value ) hidden_states = residual + hidden_states.data residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # Cross Attention hidden_states, cross_attn_present_key_value = self.encoder_attn( hidden_states=RaggedTensor.from_row_lengths(hidden_states, input_lengths, max_input_length), key_value_states=encoder_hidden_states, past_key_value=cross_attn_past_key_value, ) hidden_states = residual + hidden_states.data # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states return hidden_states, present_key_value, cross_attn_present_key_value class SimpleConvTRTLLMNet(Module): def __init__(self): super().__init__() self.layer = WhisperDecoderLayer() def forward(self, hidden_states: RaggedTensor, encoder_hidden_states: Tensor, self_attn_past_key_value: Tensor, cross_attn_past_key_value: Tensor): hidden_states, present_key_value, cross_attn_present_key_value = self.layer(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, self_attn_past_key_value=self_attn_past_key_value, cross_attn_past_key_value=cross_attn_past_key_value) hidden_states.mark_output('output0', str_dtype_to_trt('float32')) present_key_value.mark_output('output1', str_dtype_to_trt('float32')) cross_attn_present_key_value.mark_output('output2', str_dtype_to_trt('float32')) return hidden_states def prepare_inputs(self): input_features_data = Tensor(name='data', dtype=trt.float32, shape=[1, 1, 512]) input_features_length = Tensor(name='length', dtype=trt.float32, shape=[1]) input_features = RaggedTensor.from_row_lengths(input_features_data, input_features_length) encoder_hidden_states = Tensor(name='encoder_hidden_states', dtype=trt.float32, shape=[1, 1500, 512]) self_attn_past_key_value = Tensor(name='self_attn_past_key_value', dtype=trt.float32, shape=[2, 8, 23, 64]) cross_attn_past_key_value = Tensor(name='cross_attn_past_key_value', dtype=trt.float32, shape=[2, 8, 1500, 64]) return (input_features, encoder_hidden_states, self_attn_past_key_value, cross_attn_past_key_value) if __name__ == '__main__': net = SimpleConvTRTLLMNet()
PyTorch reference model used for comparison:
import math import torch import torch.nn as nn from activations import ACT2FN from typing import Optional, Tuple class WhisperEncoderAttention(nn.Module): def __init__( self, embed_dim: int = 512, num_heads: int = 8, dropout: float = 0.0, bias: bool = True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, ): # bsz=1, tgt_len=1500, _取决于模型大小 bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_output = torch.bmm(attn_weights, value_states) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output class WhisperEncoderLayer(nn.Module): def __init__(self, d_model=512, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048): super().__init__() self.embed_dim = d_model self.self_attn = WhisperEncoderAttention( embed_dim=self.embed_dim, num_heads=encoder_attention_heads, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = ACT2FN[activation_function] self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim) self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states outputs = hidden_states return outputs class WhisperEncoder(nn.Module): def __init__(self, d_model=512, num_mel_bins=80, max_source_positions=1500, encoder_layers=6, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048): super().__init__() embed_dim = d_model self.conv1 = nn.Conv1d(num_mel_bins, embed_dim, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.embed_positions = nn.Embedding(max_source_positions, embed_dim) self.layers = nn.ModuleList([WhisperEncoderLayer(d_model=d_model, encoder_attention_heads=encoder_attention_heads, activation_function=activation_function, encoder_ffn_dim=encoder_ffn_dim) for _ in range(encoder_layers)]) 
self.layer_norm = nn.LayerNorm(d_model) def forward( self, input_features, # (1,80,3000) ): inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight hidden_states = inputs_embeds + embed_pos for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) return hidden_states class WhisperDecoderAttention(nn.Module): def __init__( self, embed_dim: int = 512, num_heads: int = 8, bias: bool = True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, ): is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) * self.scaling if ( is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_output = torch.bmm(attn_weights, value_states) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, past_key_value class WhisperDecoderLayer(nn.Module): def __init__(self, d_model = 512, decoder_attention_heads = 8, activation_function = 'gelu', decoder_ffn_dim = 2048): super().__init__() self.embed_dim = d_model self.self_attn = WhisperDecoderAttention(embed_dim=self.embed_dim,num_heads=decoder_attention_heads) self.activation_fn = ACT2FN[activation_function] 
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn = WhisperDecoderAttention(self.embed_dim,decoder_attention_heads,) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, decoder_ffn_dim) self.fc2 = nn.Linear(decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, present_key_value = self.self_attn( hidden_states=hidden_states, key_value_states=None, past_key_value=self_attn_past_key_value ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, past_key_value=cross_attn_past_key_value, ) hidden_states = residual + hidden_states # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) outputs += (present_key_value,) return outputs class SimpleConvTorchNet(nn.Module): def __init__(self): super().__init__() self.layer = WhisperDecoderLayer() def forward(self, hidden_states,encoder_hidden_states,past_key_value): output = self.layer(hidden_states,encoder_hidden_states,past_key_value) return output if __name__ == '__main__': torch_net = SimpleConvTorchNet() torch.save(torch_net.state_dict(),'weight.pth') output = torch_net(torch.rand(1,1,512),torch.rand(1,1500,512),None) print(len(output),output[0].shape,[i.shape for i in output[1]]) # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 1, 64]), torch.Size([1, 8, 1, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])] output = torch_net(torch.rand(1,1,512),torch.rand(1,1500,512),(torch.rand(1,8,23,64),torch.rand(1,8,23,64),torch.rand(1,8,1500,64),torch.rand(1,8,1500,64))) print(len(output),output[0].shape,[i.shape for i in output[1]]) # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 24, 64]), torch.Size([1, 8, 24, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])] output = torch_net(torch.rand(1,1,512),torch.rand(1,1500,512),(torch.rand(1,8,1,64),torch.rand(1,8,1,64),torch.rand(1,8,1500,64),torch.rand(1,8,1500,64))) print(len(output),output[0].shape,[i.shape for i in output[1]]) # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 2, 64]), torch.Size([1, 8, 2, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])]
Engine build script:
import time import torch import tensorrt_llm from tensorrt_llm.builder import Builder from tensorrt_llm.logger import logger from tensorrt_llm.network import net_guard def serialize_engine(engine, path): logger.info(f'Serializing engine to {path}...') tik = time.time() with open(path, 'wb') as f: f.write(bytearray(engine)) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) logger.info(f'Engine serialized. Total time: {t}') if __name__ == '__main__': logger.set_level('info') torch.cuda.set_device(0) tensorrt_llm.logger.set_level('info') # create builder builder = Builder() builder_config = builder.create_builder_config( name='SimpleWhisper', precision='float32', timing_cache='model.cache', tensor_parallel=1, parallel_build=False, int8=False, opt_level=None, ) # create tensort-llm model tensorrt_llm_test = tensorrt_llm.models.SimpleConvTRTLLMNet() ckpt = torch.load('weight.pth',map_location='cpu') print(ckpt.keys()) tensorrt_llm_test.layer.self_attn.q_proj.weight.value = ckpt['layer.self_attn.q_proj.weight'].numpy() tensorrt_llm_test.layer.self_attn.q_proj.bias.value = ckpt['layer.self_attn.q_proj.bias'].numpy() tensorrt_llm_test.layer.self_attn.k_proj.weight.value = ckpt['layer.self_attn.k_proj.weight'].numpy() tensorrt_llm_test.layer.self_attn.v_proj.weight.value = ckpt['layer.self_attn.v_proj.weight'].numpy() tensorrt_llm_test.layer.self_attn.v_proj.bias.value = ckpt['layer.self_attn.v_proj.bias'].numpy() tensorrt_llm_test.layer.self_attn.dense.weight.value = ckpt['layer.self_attn.out_proj.weight'].numpy() tensorrt_llm_test.layer.self_attn.dense.bias.value = ckpt['layer.self_attn.out_proj.bias'].numpy() tensorrt_llm_test.layer.self_attn_layer_norm.weight.value = ckpt['layer.self_attn_layer_norm.weight'].numpy() tensorrt_llm_test.layer.self_attn_layer_norm.bias.value = ckpt['layer.self_attn_layer_norm.bias'].numpy() tensorrt_llm_test.layer.encoder_attn.q_proj.weight.value = ckpt['layer.encoder_attn.q_proj.weight'].numpy() tensorrt_llm_test.layer.encoder_attn.q_proj.bias.value = ckpt['layer.encoder_attn.q_proj.bias'].numpy() tensorrt_llm_test.layer.encoder_attn.k_proj.weight.value = ckpt['layer.encoder_attn.k_proj.weight'].numpy() tensorrt_llm_test.layer.encoder_attn.v_proj.weight.value = ckpt['layer.encoder_attn.v_proj.weight'].numpy() tensorrt_llm_test.layer.encoder_attn.v_proj.bias.value = ckpt['layer.encoder_attn.v_proj.bias'].numpy() tensorrt_llm_test.layer.encoder_attn.dense.weight.value = ckpt['layer.encoder_attn.out_proj.weight'].numpy() tensorrt_llm_test.layer.encoder_attn.dense.bias.value = ckpt['layer.encoder_attn.out_proj.bias'].numpy() tensorrt_llm_test.layer.encoder_attn_layer_norm.weight.value = ckpt['layer.encoder_attn_layer_norm.weight'].numpy() tensorrt_llm_test.layer.encoder_attn_layer_norm.bias.value = ckpt['layer.encoder_attn_layer_norm.bias'].numpy() tensorrt_llm_test.layer.fc1.weight.value = ckpt['layer.fc1.weight'].numpy() tensorrt_llm_test.layer.fc1.bias.value = ckpt['layer.fc1.bias'].numpy() tensorrt_llm_test.layer.fc2.weight.value = ckpt['layer.fc2.weight'].numpy() tensorrt_llm_test.layer.fc2.bias.value = ckpt['layer.fc2.bias'].numpy() tensorrt_llm_test.layer.final_layer_norm.weight.value = ckpt['layer.final_layer_norm.weight'].numpy() tensorrt_llm_test.layer.final_layer_norm.bias.value = ckpt['layer.final_layer_norm.bias'].numpy() network = builder.create_network() network.trt_network.name = 'SimpleWhisper' with net_guard(network): network.set_named_parameters(tensorrt_llm_test.named_parameters()) inputs = tensorrt_llm_test.prepare_inputs() 
tensorrt_llm_test(*inputs) engine = builder.build_engine(network, builder_config) assert engine is not None, f'Failed to build engine' serialize_engine(engine, 'simplewhisper.engine')
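Since the weight mapping above is written out by hand, an optional sanity check (my own suggestion, not part of the original repro) is to print the TRT-LLM parameter names next to the checkpoint keys before assigning, so a mapping typo does not silently leave a tensor at its initial value:

```python
# Optional: list TRT-LLM parameter names alongside the checkpoint keys
# to verify the manual weight mapping covers everything it should.
trtllm_names = sorted(name for name, _ in tensorrt_llm_test.named_parameters())
print('trt-llm parameters:', trtllm_names)
print('checkpoint keys:   ', sorted(ckpt.keys()))
```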
Script to run and compare the results:
import argparse import csv import json from pathlib import Path import contextlib import numpy as np import torch import tensorrt as trt import tensorrt_llm from tensorrt_llm.runtime import Session, TensorInfo from create import SimpleConvTorchNet @contextlib.contextmanager def _scoped_stream(): '''Create a scoped cuda stream, and synchronize it when the context is destroyed ''' #TODO: delete torch, use cuda native python bindings import torch stream = torch.cuda.current_stream() try: # return a handle, trt and other lib does not recognize torch.cuda.Stream yield stream.cuda_stream finally: stream.synchronize() if __name__ == '__main__': tensorrt_llm.logger.set_level('info') runtime_rank = tensorrt_llm.mpi_rank() runtime_mapping = tensorrt_llm.Mapping(1, 0) torch.cuda.set_device(0) # load engine with open('simplewhisper.engine', 'rb') as f: engine_buffer = f.read() session = Session.from_serialized_engine(engine_buffer) # inference output shape inputs_shape = [ TensorInfo('data',trt.float32,(1,1,512)), TensorInfo('length',trt.float32,(1,)), TensorInfo('encoder_hidden_states',trt.float32,(1,1500,512)), TensorInfo('self_attn_past_key_value',trt.float32,(2,8,23,64)), TensorInfo('cross_attn_past_key_value',trt.float32,(2,8,1500,64)), ] outputs_shape = session.infer_shapes(inputs_shape) # malloc buffer inputs = { 'data': torch.rand(1,1,512).cuda(), 'length': torch.Tensor([1.0]).cuda(), 'encoder_hidden_states': torch.rand(1,1500,512).cuda(), 'self_attn_past_key_value': torch.rand(2,8,23,64).cuda(), 'cross_attn_past_key_value': torch.rand(2,8,1500,64).cuda(), } outputs = {} for output in outputs_shape: outputs[output.name] = torch.zeros(*output.shape).cuda() # execute with _scoped_stream() as stream: ok = session.run(inputs, outputs, stream) torch.cuda.synchronize() trtllm_out = outputs['output0'] trtllm_skv = outputs['output1'] trtllm_ckv = outputs['output2'] # print(trtllm_out.shape,trtllm_skv.shape,trtllm_ckv.shape) torch_net = SimpleConvTorchNet() torch_net.load_state_dict(torch.load('weight.pth',map_location='cpu')) torch_net.cuda() with torch.inference_mode(): torch_out, (torch_sk, torch_sv, torch_ck, torch_cv) = torch_net(inputs['data'],inputs['encoder_hidden_states'], (inputs['self_attn_past_key_value'][0:1],inputs['self_attn_past_key_value'][1:2], inputs['cross_attn_past_key_value'][0:1],inputs['cross_attn_past_key_value'][1:2])) torch_skv = torch.cat([torch_sk,torch_sv],dim=0) torch_ckv = torch.cat([torch_ck,torch_cv],dim=0) a = trtllm_skv[0].cpu().numpy() b = torch_skv[0].cpu().numpy() diff = np.abs(a-b) print(a.shape,a.min(),a.mean(),a.max(),a.var()) print(b.shape,b.min(),b.mean(),b.max(),b.var()) print(diff.shape,diff.min(),diff.mean(),diff.max(),diff.var()) a = trtllm_skv[1].cpu().numpy() b = torch_skv[1].cpu().numpy() diff = np.abs(a-b) print(a.shape,a.min(),a.mean(),a.max(),a.var()) print(b.shape,b.min(),b.mean(),b.max(),b.var()) print(diff.shape,diff.min(),diff.mean(),diff.max(),diff.var())
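When reproducing, it may also help to append a simple pass/fail summary to the script above; this is a minimal sketch using the tensors already defined there, and the tolerances are arbitrary values of my choosing rather than anything tuned for this model:

```python
# Optional pass/fail summary on top of the printed statistics above.
for name, trt_t, ref_t in [('self_attn_kv', trtllm_skv, torch_skv),
                           ('cross_attn_kv', trtllm_ckv, torch_ckv),
                           ('hidden_states', trtllm_out, torch_out)]:
    ok = np.allclose(trt_t.cpu().numpy(), ref_t.cpu().numpy(),
                     rtol=1e-3, atol=1e-3)
    print(f'{name}: {"MATCH" if ok else "MISMATCH"}')
```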
You can compare the two results by enabling or commenting out the mark_output call at line 299 of the model code. The setup has quite a few dependencies, so it is easiest to refer directly to the code in the development commit.

Running create.py -> build.py -> run.py produces the outputs.