Add support for falcon2 #1926

Closed · wants to merge 1 commit
9 changes: 6 additions & 3 deletions tensorrt_llm/models/falcon/model.py
@@ -42,6 +42,9 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
 
         self.new_decoder_architecture = config.new_decoder_architecture
         self.parallel_attn = config.parallel_attention
+        self.num_ln_in_parallel_attn = config.num_ln_in_parallel_attn
+        if self.num_ln_in_parallel_attn is None:
+            self.num_ln_in_parallel_attn = 2
         if self.is_parallel_attention:
             # Not to apply allreduce inside the Attention/MLP layers.
             # allreduce applies after those layer.
@@ -66,7 +69,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
 
         mlp_hidden_size = hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
 
-        if self.new_decoder_architecture:
+        if self.new_decoder_architecture and self.num_ln_in_parallel_attn==2:
             # Layernorm before MLP.
             self.mlp_layernorm = LayerNorm(normalized_shape=hidden_size,
                                            eps=layernorm_epsilon,
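
Taken together, the two __init__ hunks read the new num_ln_in_parallel_attn field from the config, fall back to 2 when it is absent (the behavior of existing Falcon checkpoints), and only build the dedicated pre-MLP layernorm when two layernorms are actually used. Below is a small standalone sketch of that config handling; FalconBlockConfig and the helper functions are hypothetical illustrations, not TensorRT-LLM classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FalconBlockConfig:
    """Toy stand-in for the relevant PretrainedConfig fields."""
    new_decoder_architecture: bool = True
    parallel_attention: bool = True
    num_ln_in_parallel_attn: Optional[int] = None  # Falcon2 configs set this to 1


def resolve_num_ln(config: FalconBlockConfig) -> int:
    # Older Falcon configs omit the field; keep the previous two-layernorm default.
    return 2 if config.num_ln_in_parallel_attn is None else config.num_ln_in_parallel_attn


def needs_mlp_layernorm(config: FalconBlockConfig) -> bool:
    # The separate pre-MLP layernorm only exists in the two-norm layout.
    return config.new_decoder_architecture and resolve_num_ln(config) == 2


assert needs_mlp_layernorm(FalconBlockConfig()) is True
assert needs_mlp_layernorm(FalconBlockConfig(num_ln_in_parallel_attn=1)) is False
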
@@ -103,7 +106,7 @@ def forward(self,
 
         residual = hidden_states
 
-        if self.new_decoder_architecture:
+        if self.new_decoder_architecture and self.num_ln_in_parallel_attn == 2:
             mlp_ln_output = self.mlp_layernorm(hidden_states)
         hidden_states = self.input_layernorm(hidden_states)
         input_ln_output = hidden_states
@@ -123,7 +126,7 @@
                 hidden_states = residual + attention_output
                 residual = hidden_states
                 hidden_states = self.post_layernorm(hidden_states)
-        else:
+        elif self.num_ln_in_parallel_attn == 2:
             hidden_states = mlp_ln_output
 
         hidden_states = self.mlp(hidden_states)
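
In the forward path, the net effect of the hunks above is that a Falcon2-style layer (new decoder architecture with num_ln_in_parallel_attn == 1) feeds the MLP the same input_layernorm output that the attention branch sees, instead of a separate mlp_layernorm output. A minimal standalone sketch of that selection, assuming the parallel-attention layout and using plain callables as placeholders for the layernorm graph ops (the non-parallel path with post_layernorm is omitted):

from typing import Callable


def parallel_block_mlp_input(hidden_states: float,
                             input_layernorm: Callable[[float], float],
                             mlp_layernorm: Callable[[float], float],
                             new_decoder_architecture: bool,
                             num_ln_in_parallel_attn: int) -> float:
    """Pick what the MLP consumes in the parallel-attention block."""
    input_ln_output = input_layernorm(hidden_states)  # shared with attention
    if new_decoder_architecture and num_ln_in_parallel_attn == 2:
        # Falcon-40B/180B style: dedicated pre-MLP layernorm.
        return mlp_layernorm(hidden_states)
    # Falcon2 style: reuse the single layernorm output.
    return input_ln_output


# Toy check: with one layernorm, attention and MLP read the same normalized input.
assert parallel_block_mlp_input(3.0, lambda x: x * 0.5, lambda x: x * 0.25, True, 1) == 1.5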