Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions slime/utils/flops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num


def calculate_attention_flops(seqlen, num_attention_heads, head_dim):
    """Return the FLOPs of the attention score and value matmuls.

    Only two terms are counted: the QK^T score matmul and the A*V
    value-weighted matmul.

    Args:
        seqlen: Sequence length; both terms scale with ``seqlen * seqlen``.
        num_attention_heads: Number of attention heads.
        head_dim: Per-head dimension.

    Returns:
        The summed FLOP count of the two matmuls.
    """
    # One full seqlen x seqlen x head_dim matmul across all heads.
    full_matmul_flops = 2 * num_attention_heads * seqlen * seqlen * head_dim
    # QK^T with causal: count half of the full matmul.
    flops = full_matmul_flops // 2
    # A*V
    # NOTE(review): this term is not halved, unlike QK^T above -- confirm
    # whether the causal discount is meant to apply here as well.
    flops += full_matmul_flops
    return flops
Expand All @@ -31,8 +31,9 @@ def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size):
return 2 * seqlen * hidden_size * ffn_hidden_size * 3


def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size):
head_dim = hidden_size // num_attention_heads
def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size, head_dim):
if head_dim is None:
head_dim = hidden_size // num_attention_heads
return (
calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups)
+ calculate_attention_flops(seqlen, num_attention_heads, head_dim)
Expand All @@ -49,6 +50,7 @@ def calculate_fwd_flops(
num_attention_heads = args.num_attention_heads
num_query_groups = args.num_query_groups
vocab_size = args.vocab_size
kv_channels = args.kv_channels

total_flops = 0

Expand Down Expand Up @@ -82,6 +84,7 @@ def calculate_fwd_flops(
num_attention_heads,
num_query_groups,
dense_ffn,
kv_channels,
)
* num_dense_layers
)
Expand All @@ -94,6 +97,7 @@ def calculate_fwd_flops(
num_attention_heads,
num_query_groups,
moe_ffn,
kv_channels,
)
* num_moe_layers
)
Expand Down