diff --git a/tensorrt_llm/models/falcon/model.py b/tensorrt_llm/models/falcon/model.py
index 57c3d0b97..6a085016c 100644
--- a/tensorrt_llm/models/falcon/model.py
+++ b/tensorrt_llm/models/falcon/model.py
@@ -42,6 +42,9 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         self.new_decoder_architecture = config.new_decoder_architecture
         self.parallel_attn = config.parallel_attention
+        self.num_ln_in_parallel_attn = config.num_ln_in_parallel_attn
+        if self.num_ln_in_parallel_attn is None:
+            self.num_ln_in_parallel_attn = 2
         if self.is_parallel_attention:
             # Not to apply allreduce inside the Attention/MLP layers.
             # allreduce applies after those layer.
@@ -66,7 +69,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         mlp_hidden_size = hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
 
-        if self.new_decoder_architecture:
+        if self.new_decoder_architecture and self.num_ln_in_parallel_attn == 2:
             # Layernorm before MLP.
             self.mlp_layernorm = LayerNorm(normalized_shape=hidden_size,
                                            eps=layernorm_epsilon,
@@ -103,7 +106,7 @@ def forward(self,
         residual = hidden_states
 
-        if self.new_decoder_architecture:
+        if self.new_decoder_architecture and self.num_ln_in_parallel_attn == 2:
             mlp_ln_output = self.mlp_layernorm(hidden_states)
         hidden_states = self.input_layernorm(hidden_states)
         input_ln_output = hidden_states
@@ -123,7 +126,7 @@ def forward(self,
                 hidden_states = residual + attention_output
                 residual = hidden_states
                 hidden_states = self.post_layernorm(hidden_states)
-        else:
+        elif self.num_ln_in_parallel_attn == 2:
             hidden_states = mlp_ln_output
 
         hidden_states = self.mlp(hidden_states)
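
For reference, here is a small, self-contained sketch (plain Python, not the TensorRT-LLM classes; `parallel_block_sketch` and its arguments are hypothetical stand-ins) of what the new `num_ln_in_parallel_attn` switch changes on the new-decoder-architecture path: with the value 2, attention and MLP each normalize the block input with their own layernorm, while with 1 both consume the same `input_layernorm` output, which is what the fall-through added to `forward()` produces.

```python
# Toy illustration only -- not TensorRT-LLM code. All names below
# (parallel_block_sketch, attn, mlp, input_ln, mlp_ln) are hypothetical.

def parallel_block_sketch(x, attn, mlp, input_ln, mlp_ln,
                          num_ln_in_parallel_attn=2):
    """Parallel-attention block, new-decoder-architecture path only."""
    residual = x

    if num_ln_in_parallel_attn == 2:
        # Two layernorms: attention and the MLP each get their own normed input.
        attn_in = input_ln(x)
        mlp_in = mlp_ln(x)
    else:
        # Single layernorm (num_ln_in_parallel_attn == 1): the MLP reuses the
        # attention's normalized input.
        attn_in = input_ln(x)
        mlp_in = attn_in

    # Attention and MLP run "in parallel" on their inputs and both results are
    # added back onto the residual stream.
    return residual + attn(attn_in) + mlp(mlp_in)


if __name__ == "__main__":
    identity = lambda t: t
    # With identity sub-modules both settings reduce to 3 * x.
    print(parallel_block_sketch(1.0, identity, identity, identity, identity))
    print(parallel_block_sketch(1.0, identity, identity, identity, identity,
                                num_ln_in_parallel_attn=1))
```

Because the patch falls back to 2 when `config.num_ln_in_parallel_attn` is `None`, existing configs keep the previous two-layernorm behavior; only configs that explicitly set the field to 1 switch to the shared-layernorm variant.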