diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index cffb524ef..0f8c2827b 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -223,6 +223,49 @@ def calculate_gemma3_tflops_training_per_device(config, total_ffn_flops, qkv_flo
   return attention_tflops, learnable_weight_tflops
 
 
+def _calculate_chunked_attention_flops_per_layer(config, seq_len, chunk_size):
+  """Calculates the non-causal FLOPs for a single layer of chunked attention."""
+  num_chunks = seq_len // chunk_size
+  rem_chunk_size = seq_len % chunk_size
+  # The complexity of chunked attention is the sum of squares of the chunk lengths.
+  chunked_complexity = (num_chunks * chunk_size**2) + (rem_chunk_size**2)
+  # Non-causal attention FLOPs are 4 * B * complexity * H * D,
+  # where B = batch size, H = num query heads, D = head dim.
+  return 4 * config.per_device_batch_size * chunked_complexity * config.num_query_heads * config.head_dim
+
+
+def calculate_llama4_attention_tflops(config):
+  """
+  Calculates attention-only training TFLOPs for Llama4's specific architecture,
+  which has an alternating pattern of global and chunked attention layers.
+  """
+  num_layers = config.num_decoder_layers
+  seq_len = config.max_target_length
+  chunk_size = config.chunk_attn_window_size
+
+  # Determine the number of global vs. chunked layers based on the NoPE interval;
+  # a "NoPE" layer uses global attention.
+  num_global_layers = num_layers // config.nope_layer_interval
+  num_chunked_layers = num_layers - num_global_layers
+
+  # FLOPs for a single global attention layer (full attention, non-causal).
+  global_attention_flops_per_layer = 4 * config.per_device_batch_size * seq_len**2 * config.num_query_heads * config.head_dim
+
+  # FLOPs for a single chunked attention layer (non-causal).
+  chunked_attention_flops_per_layer = _calculate_chunked_attention_flops_per_layer(config, seq_len, chunk_size)
+
+  # Total non-causal attention FLOPs, summed over all global and all chunked layers.
+  noncausal_attention_flops = (num_global_layers * global_attention_flops_per_layer) + (
+      num_chunked_layers * chunked_attention_flops_per_layer
+  )
+
+  # Apply the causal mask (halves the FLOPs), multiply by 3 for the fwd/bwd passes, and convert to TFLOPs.
+  causal_attention_flops = noncausal_attention_flops / 2
+  attention_tflops = causal_attention_flops * 3 / 10**12
+
+  return attention_tflops
+
+
 def calculate_mla_tflops_per_device(config):
   """Calculate Multi-Head Latent Attention TFLOP"""
   batch_len = config.per_device_batch_size * config.max_target_length
@@ -351,7 +394,14 @@ def calculate_tflops_training_per_device(config, log=True):
     attention_tflops, learnable_weight_tflops = calculate_gemma3_tflops_training_per_device(
         config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
     )
-  elif config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
+  elif config.decoder_block == DecoderBlockType.LLAMA4:
+    # Use the Llama4-specific helper so global and chunked attention layers are costed separately.
+    attention_tflops = calculate_llama4_attention_tflops(config)
+    # The learnable-weight calculation is unchanged, as it already handles Llama4's MoE structure.
+    learnable_weight_tflops = (
+        (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
+    )
+  elif config.decoder_block == DecoderBlockType.DEEPSEEK:
     learnable_weight_tflops = (
         (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
     )
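
For intuition, here is a minimal standalone sketch of the chunked-complexity math above (not part of the patch). The config is a hypothetical `SimpleNamespace` stand-in exposing only the three fields the helper reads, and the shapes are made up. It checks that a chunk size covering the whole sequence reduces to the full-attention formula `4 * B * S^2 * H * D`, and that splitting 8192 tokens into 2048-token chunks cuts the non-causal attention FLOPs by 4x.

```python
from types import SimpleNamespace

def chunked_attention_flops(cfg, seq_len, chunk_size):
  # Sum of squares of chunk lengths: num_chunks full chunks plus one remainder chunk.
  num_chunks = seq_len // chunk_size
  rem = seq_len % chunk_size
  complexity = num_chunks * chunk_size**2 + rem**2
  return 4 * cfg.per_device_batch_size * complexity * cfg.num_query_heads * cfg.head_dim

# Hypothetical shapes, chosen only to make the arithmetic easy to follow.
cfg = SimpleNamespace(per_device_batch_size=1, num_query_heads=32, head_dim=128)
S = 8192

# One chunk spanning the whole sequence must equal full attention: 4 * B * S^2 * H * D.
full = 4 * cfg.per_device_batch_size * S**2 * cfg.num_query_heads * cfg.head_dim
assert chunked_attention_flops(cfg, S, S) == full

# 2048-token chunks give 4 blocks of 2048^2 scores instead of one 8192^2 block,
# i.e. a 4x reduction in non-causal attention FLOPs at these shapes.
assert chunked_attention_flops(cfg, S, 2048) == full // 4
```

The sum-of-squares form falls out of each chunk attending only within itself: a chunk of length c contributes c^2 query-key pairs, and the trailing partial chunk contributes rem^2.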
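
A hedged end-to-end sketch of the per-device number the new `calculate_llama4_attention_tflops` path would produce, again with a made-up config: 48 layers with `nope_layer_interval=4` split into 48 // 4 = 12 global and 36 chunked layers, the total is halved for the causal mask, and the x3 covers the forward plus backward passes.

```python
from types import SimpleNamespace

# Made-up config; field names mirror the MaxText config keys used in the patch.
cfg = SimpleNamespace(
    per_device_batch_size=1,
    num_query_heads=32,
    head_dim=128,
    num_decoder_layers=48,
    max_target_length=8192,
    chunk_attn_window_size=2048,
    nope_layer_interval=4,
)

num_global = cfg.num_decoder_layers // cfg.nope_layer_interval  # 12 NoPE/global layers
num_chunked = cfg.num_decoder_layers - num_global               # 36 chunked layers

global_flops = 4 * cfg.per_device_batch_size * cfg.max_target_length**2 * cfg.num_query_heads * cfg.head_dim
chunked_flops = global_flops // 4  # 4 full 2048-token chunks, per the sketch above

noncausal = num_global * global_flops + num_chunked * chunked_flops
attention_tflops = noncausal / 2 * 3 / 10**12
print(f"{attention_tflops:.2f} attention TFLOPs/device/step")  # ~34.63 for these shapes
```

Note that the split on `DecoderBlockType` changes only the attention term for Llama4; the learnable-weight expression is identical in both branches, as the second hunk shows.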