|
64 | 64 | FP8_DTYPE = current_platform.fp8_dtype() |
65 | 65 | MiB = 1024 * 1024 |
66 | 66 |
|
67 | | -# FlashInfer max sizes per world size (from collective_fusion.py) |
| 67 | +# FlashInfer max sizes per world size |
| 68 | +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes |
| 69 | +# use --disable-oneshot to disable oneshot mode for very large input sizes |
68 | 70 | _FI_MAX_SIZES = { |
69 | 71 | 2: 64 * MiB, # 64MB |
70 | | - 4: 32 * MiB, # 32MB |
71 | | - 6: 32 * MiB, # 32MB |
72 | | - 8: 32 * MiB, # 32MB |
| 72 | + 4: 64 * MiB, # 64MB |
| 73 | + 8: 64 * MiB, # 64MB |
73 | 74 | } |
74 | 75 |
|
75 | 76 | # Global workspace tensor for FlashInfer |
@@ -186,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm( |
186 | 187 | allreduce_out=None, |
187 | 188 | quant_out=None, |
188 | 189 | scale_out=None, |
189 | | - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 190 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, |
190 | 191 | scale_factor=None, |
191 | 192 | use_oneshot=use_oneshot, |
192 | 193 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(), |
@@ -228,7 +229,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( |
228 | 229 | allreduce_out=None, |
229 | 230 | quant_out=quant_out, |
230 | 231 | scale_out=None, |
231 | | - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 232 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
232 | 233 | scale_factor=scale_factor, |
233 | 234 | use_oneshot=use_oneshot, |
234 | 235 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(), |
@@ -271,7 +272,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( |
271 | 272 | allreduce_out=None, |
272 | 273 | quant_out=quant_out, |
273 | 274 | scale_out=output_scale, |
274 | | - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 275 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
275 | 276 | scale_factor=input_global_scale, |
276 | 277 | use_oneshot=use_oneshot, |
277 | 278 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(), |
@@ -579,6 +580,7 @@ def run_benchmarks( |
579 | 580 | use_residual: bool, |
580 | 581 | allreduce_params: Optional[FlashInferFusedAllReduceParams], |
581 | 582 | quant_mode: str = "all", |
| 583 | + disable_oneshot: bool = False, |
582 | 584 | ): |
583 | 585 | """Run all benchmarks for given configuration. |
584 | 586 |
|
@@ -638,17 +640,18 @@ def run_benchmarks( |
638 | 640 | # FlashInfer Fused AllReduce + RMSNorm Oneshot |
639 | 641 | if flashinfer_comm is not None and allreduce_params is not None: |
640 | 642 | try: |
641 | | - time_ms = benchmark_operation( |
642 | | - flashinfer_fused_allreduce_rmsnorm, |
643 | | - input_tensor, |
644 | | - residual=residual, |
645 | | - norm_out=norm_out, |
646 | | - rms_gamma=rms_gamma, |
647 | | - rms_eps=rms_eps, |
648 | | - allreduce_params=allreduce_params, |
649 | | - use_oneshot=True, |
650 | | - ) |
651 | | - results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms |
| 643 | + if not disable_oneshot: |
| 644 | + time_ms = benchmark_operation( |
| 645 | + flashinfer_fused_allreduce_rmsnorm, |
| 646 | + input_tensor, |
| 647 | + residual=residual, |
| 648 | + norm_out=norm_out, |
| 649 | + rms_gamma=rms_gamma, |
| 650 | + rms_eps=rms_eps, |
| 651 | + allreduce_params=allreduce_params, |
| 652 | + use_oneshot=True, |
| 653 | + ) |
| 654 | + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms |
652 | 655 | except Exception as e: |
653 | 656 | logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) |
654 | 657 | results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") |
@@ -712,21 +715,22 @@ def run_benchmarks( |
712 | 715 | # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot |
713 | 716 | if flashinfer_comm is not None and allreduce_params is not None: |
714 | 717 | try: |
715 | | - time_ms = benchmark_operation( |
716 | | - flashinfer_fused_allreduce_rmsnorm_fp8_quant, |
717 | | - input_tensor, |
718 | | - norm_out=norm_out, |
719 | | - residual=residual, |
720 | | - rms_gamma=rms_gamma, |
721 | | - rms_eps=rms_eps, |
722 | | - scale_factor=scale_fp8, |
723 | | - quant_out=quant_out_fp8, |
724 | | - allreduce_params=allreduce_params, |
725 | | - use_oneshot=True, |
726 | | - ) |
727 | | - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( |
728 | | - time_ms |
729 | | - ) |
| 718 | + if not disable_oneshot: |
| 719 | + time_ms = benchmark_operation( |
| 720 | + flashinfer_fused_allreduce_rmsnorm_fp8_quant, |
| 721 | + input_tensor, |
| 722 | + norm_out=norm_out, |
| 723 | + residual=residual, |
| 724 | + rms_gamma=rms_gamma, |
| 725 | + rms_eps=rms_eps, |
| 726 | + scale_factor=scale_fp8, |
| 727 | + quant_out=quant_out_fp8, |
| 728 | + allreduce_params=allreduce_params, |
| 729 | + use_oneshot=True, |
| 730 | + ) |
| 731 | + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( |
| 732 | + time_ms |
| 733 | + ) |
730 | 734 | except Exception as e: |
731 | 735 | logger.error( |
732 | 736 | "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", |
@@ -802,22 +806,23 @@ def run_benchmarks( |
802 | 806 | # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot |
803 | 807 | if flashinfer_comm is not None and allreduce_params is not None: |
804 | 808 | try: |
805 | | - time_ms = benchmark_operation( |
806 | | - flashinfer_fused_allreduce_rmsnorm_fp4_quant, |
807 | | - input_tensor, |
808 | | - residual=residual, |
809 | | - norm_out=norm_out, |
810 | | - rms_gamma=rms_gamma, |
811 | | - rms_eps=rms_eps, |
812 | | - input_global_scale=scale_fp4, |
813 | | - allreduce_params=allreduce_params, |
814 | | - quant_out=fp4_quant_out, |
815 | | - output_scale=fp4_output_scale, |
816 | | - use_oneshot=True, |
817 | | - ) |
818 | | - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( |
819 | | - time_ms |
820 | | - ) |
| 809 | + if not disable_oneshot: |
| 810 | + time_ms = benchmark_operation( |
| 811 | + flashinfer_fused_allreduce_rmsnorm_fp4_quant, |
| 812 | + input_tensor, |
| 813 | + residual=residual, |
| 814 | + norm_out=norm_out, |
| 815 | + rms_gamma=rms_gamma, |
| 816 | + rms_eps=rms_eps, |
| 817 | + input_global_scale=scale_fp4, |
| 818 | + allreduce_params=allreduce_params, |
| 819 | + quant_out=fp4_quant_out, |
| 820 | + output_scale=fp4_output_scale, |
| 821 | + use_oneshot=True, |
| 822 | + ) |
| 823 | + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( |
| 824 | + time_ms |
| 825 | + ) |
821 | 826 | except Exception as e: |
822 | 827 | logger.error( |
823 | 828 | "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", |
@@ -1224,6 +1229,7 @@ def main(): |
1224 | 1229 | use_residual, |
1225 | 1230 | allreduce_params, |
1226 | 1231 | quant_mode=quant_mode, |
| 1232 | + disable_oneshot=args.disable_oneshot, |
1227 | 1233 | ) |
1228 | 1234 |
|
1229 | 1235 | # Store results for markdown export |
|
0 commit comments