BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT8_E8M0 = 1
+BYTES_PER_EL_FLOAT32 = 4

gpu_name_to_specs = {
    "NVIDIA H100": {
@@ -241,7 +243,7 @@ def get_individual_gemm_time_sympy(
    elif dtype is torch.float4_e2m1fn_x2:
        peak_tops = specs["fp4_peak_tops"]
    else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
    compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

    # memory bound
@@ -274,7 +276,7 @@ def get_individual_gemm_time_sympy(
    elif dtype is torch.float4_e2m1fn_x2:
        bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
    else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
    mem_gemm_time_s = (
        bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
    )
@@ -376,27 +378,56 @@ def get_inference_tensor_memory_traffic_ovhd_s(
    dim1,
    tensor_role: str,
    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
    fuse_with_prev=False,
) -> List[Union[sympy.Symbol, float]]:
    """
    Inference version of `get_tensor_memory_traffic_ovhd_s`.
    The only thing happening here is we quantize the activation.
    """
-    assert float8_recipe_name == "rowwise", "unsupported"
    assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"

    # assumes input bf16, output f8
    numel = dim0 * dim1

    res_bytes = None

-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1:               x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1:               x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3:               x_bf16, max_abs -> to_float8 -> x_fp8
+        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        # kernel 3: read in bf16, write in float8
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw, kernel_3_rw]
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1:               x_bf16 -> x_fp8 (with per-row scaling)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+        # x_bf16 = ...
+        # kernel 1:               x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+        res_bytes = [kernel_1_rw]
+
+    else:
+        # For NVFP4, assume minimal overhead since it's primarily a compute format
+        # x_bf16 = ...
+        # kernel 1:               x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+        # add minimal scaling overhead (per-tensor scale)
+        kernel_1_rw += BYTES_PER_EL_FLOAT32  # single scale factor
+        res_bytes = [kernel_1_rw]

    # convert from bytes to seconds
    res_s = [
@@ -415,6 +446,8 @@ def get_inference_float8_mem_sympy(
    K,
    N,
    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
    gpu_name: Optional[str] = None,
):
    specs = get_specs(gpu_name)
@@ -426,6 +459,7 @@ def get_inference_float8_mem_sympy(
        K,
        tensor_role="input",
        float8_recipe_name=float8_recipe_name,
+        mx_recipe_name=mx_recipe_name,
        fuse_with_prev=False,
    )
    res = sum([*fwd_fp8_input_mem])
@@ -438,9 +472,9 @@ def get_inference_gemm_time_sympy(
    N: sympy.Symbol,
    dtype,
    float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    nvfp4_recipe_name: Optional[str] = None,
+    gpu_name: Optional[str] = None,
):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
    # note: this function is currently not super accurate for small shapes:
    # when M,K,N <= 1k,1k,1k it undercounts by around 2x
    gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
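As a quick sanity check of the quantization-overhead modeling in `get_inference_tensor_memory_traffic_ovhd_s`, here is a minimal standalone sketch that reproduces the per-recipe byte counts for a bf16 activation of shape `(dim0, dim1)`. The function name `quantize_traffic_bytes` and the short recipe labels are illustrative only; the constants mirror the `BYTES_PER_EL_*` values defined at the top of the file, and the mxfp8 branch assumes a 32-element scaling block, as in the change above.

```python
# Sketch: quantization memory traffic (bytes), mirroring the branches above.
# Assumes a bf16 input of shape (dim0, dim1).
BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8_E8M0 = 1
BYTES_PER_EL_FLOAT32 = 4


def quantize_traffic_bytes(dim0: int, dim1: int, recipe: str) -> float:
    numel = dim0 * dim1
    if recipe == "tensorwise":
        # kernel 1: max-abs pass reads bf16; kernel 3: bf16 read + fp8 write
        kernel_1 = BYTES_PER_EL_BF16 * numel
        kernel_3 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        return kernel_1 + kernel_3
    if recipe == "rowwise":
        # bf16 read + fp8 write, plus one fp32 scale per row
        return (
            BYTES_PER_EL_BF16 * numel
            + BYTES_PER_EL_FLOAT8 * numel
            + BYTES_PER_EL_FLOAT32 * dim0
        )
    if recipe == "mxfp8":
        # bf16 read + fp8 write, plus one e8m0 scale per 32-element block
        return (
            BYTES_PER_EL_BF16 * numel
            + BYTES_PER_EL_FLOAT8 * numel
            + BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
        )
    # nvfp4 (as modeled here): bf16 read + fp4 write, plus one per-tensor scale
    return (
        BYTES_PER_EL_BF16 * numel
        + BYTES_PER_EL_FLOAT4 * numel
        + BYTES_PER_EL_FLOAT32
    )


for recipe in ("tensorwise", "rowwise", "mxfp8", "nvfp4"):
    print(recipe, quantize_traffic_bytes(4096, 4096, recipe))
```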