diff --git a/slime/utils/flops_utils.py b/slime/utils/flops_utils.py index 57cf9de90..71cdd4c65 100644 --- a/slime/utils/flops_utils.py +++ b/slime/utils/flops_utils.py @@ -16,8 +16,8 @@ def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num def calculate_attention_flops(seqlen, num_attention_heads, head_dim): - # QK^T - flops = 2 * num_attention_heads * seqlen * seqlen * head_dim + # QK^T with causal + flops = 2 * num_attention_heads * seqlen * seqlen * head_dim // 2 # A*V flops += 2 * num_attention_heads * seqlen * seqlen * head_dim return flops @@ -31,8 +31,9 @@ def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size): return 2 * seqlen * hidden_size * ffn_hidden_size * 3 -def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size): - head_dim = hidden_size // num_attention_heads +def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size, head_dim): + if head_dim is None: + head_dim = hidden_size // num_attention_heads return ( calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups) + calculate_attention_flops(seqlen, num_attention_heads, head_dim) @@ -49,6 +50,7 @@ def calculate_fwd_flops( num_attention_heads = args.num_attention_heads num_query_groups = args.num_query_groups vocab_size = args.vocab_size + kv_channels = args.kv_channels total_flops = 0 @@ -82,6 +84,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, dense_ffn, + kv_channels, ) * num_dense_layers ) @@ -94,6 +97,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, moe_ffn, + kv_channels, ) * num_moe_layers )