From 5c022be5ed9f2329c07f66eb1e311fbfd7bce886 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Jun 2025 07:53:43 -0700
Subject: [PATCH 1/5] improve moe training benchmarking

---
 .../benchmarks/benchmark_scaled_grouped_mm.py | 80 +++++++++++++------
 .../moe_training/conversion_utils.py          |  9 ++-
 .../moe_training/scaled_grouped_mm.py         | 51 ++++++++++--
 torchao/prototype/moe_training/tensor.py      | 13 ++-
 4 files changed, 115 insertions(+), 38 deletions(-)

diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
index af1a652fc0..2c77dda11d 100644
--- a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
-
+import argparse
 import itertools
 import time
 from dataclasses import dataclass
@@ -28,11 +28,11 @@ class ExperimentConfig:
     A_shape: tuple[int]
     B_shape: tuple[int]
 
-
 @dataclass(frozen=True)
 class ExperimentResult:
-    time_us: float
-
+    torch_time_us: float
+    triton_time_us: float
+    triton_speedup: float
 
 @dataclass(frozen=True)
 class Experiment:
@@ -41,12 +41,12 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
+    A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
+    B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, high_precision_dtype in itertools.product(
-        A_shapes, B_shapes, high_precision_dtypes
+        A_shapes, B_shapes, high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+def run_experiment(config: ExperimentConfig, args: argparse.Namespace) -> ExperimentResult:
     # define test inputs
     A = torch.randn(
         *config.A_shape,
@@ -92,26 +92,54 @@ def warmup(func, *args, **kwargs):
         for _ in range(10):
             func(*args, **kwargs)
 
-    def forward_backward(A, B_t, offs):
-        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
+    def forward_backward(A, B_t, offs, use_triton=True):
+        out = _scaled_grouped_mm(
+            A,
+            B_t,
+            offs=offs,
+            out_dtype=torch.bfloat16,
+            use_triton_for_per_group_scales=use_triton,
+        )
         out.sum().backward()
-
-    # bench triton
-    warmup(forward_backward, A, B_t, offs)
+        torch.cuda.synchronize()
+
+    # benchmark torch
+    if args.compile:
+        compiled_fwd_bwd = torch.compile(forward_backward)
+        warmup(compiled_fwd_bwd, A, B_t, offs, use_triton=False)
+        start_time_ns = time.perf_counter_ns()
+        compiled_fwd_bwd(A, B_t, offs, use_triton=False)
+        torch_time_ns = time.perf_counter_ns() - start_time_ns
+        torch_time_us = torch_time_ns / 1e3
+    else:
+        warmup(forward_backward, A, B_t, offs, use_triton=False)
+        start_time_ns = time.perf_counter_ns()
+        forward_backward(A, B_t, offs, use_triton=False)
+        torch_time_ns = time.perf_counter_ns() - start_time_ns
+        torch_time_us = torch_time_ns / 1e3
+
+    # benchmark triton
+    warmup(forward_backward, A, B_t, offs, use_triton=True)
     start_time_ns = time.perf_counter_ns()
-    forward_backward(A, B_t, offs)
-    time_ns = time.perf_counter_ns() - start_time_ns
-    time_us = time_ns / 1e3
+    forward_backward(A, B_t, offs, use_triton=True)
+    triton_time_ns = time.perf_counter_ns() - start_time_ns
+    triton_time_us = triton_time_ns / 1e3 
 
-    return ExperimentResult(time_us=time_us)
+
+    return ExperimentResult(
+        torch_time_us=round(torch_time_us, 3),
+        triton_time_us=round(triton_time_us, 3),
+        triton_speedup=round(torch_time_us / triton_time_us, 3),
+    )
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
-        "high_precision_dtype",
-        "time_us",
+        "torch_time_us",
+        "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -121,19 +149,20 @@ def print_results(experiments: List[Experiment]):
             [
                 A_shape,
                 B_shape,
-                experiment.config.high_precision_dtype,
-                experiment.result.time_us,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                experiment.result.triton_speedup,
             ]
         )
     print(tabulate(rows, headers=headers))
 
 
-def main():
+def main(args: argparse.Namespace):
     torch.random.manual_seed(123)
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, args)
         results.append(Experiment(config=config, result=result))
 
     # Use Tabulate to print results
@@ -141,4 +170,7 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--compile", action="store_true")
+    args = arg_parser.parse_args()
+    main(args)
diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
index 928af1cf2e..fc74d56f1e 100644
--- a/torchao/prototype/moe_training/conversion_utils.py
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -27,8 +27,7 @@ class MoETrainingConfig(AOBaseConfig):
 
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
-
-    pass
+    use_triton_for_per_group_scales: bool = True
 
 
 @register_quantize_module_handler(MoETrainingConfig)
@@ -46,7 +45,7 @@ def _moe_training_transform(
     Returns:
         nn.Module: The modified module with swapped parameters.
     """
-    out = _swap_params(module)
+    out = _swap_params(module, config=config)
     return out
 
 
@@ -54,6 +53,7 @@ def _swap_params(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+    config: Optional[MoETrainingConfig] = None,
 ) -> nn.Module:
     """
     Recurses through the nn.Module, recursively swapping the data tensor of
@@ -69,6 +69,7 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
""" + use_triton = config.use_triton_for_per_group_scales if config is not None else False if isinstance(module, nn.Parameter) and ( module_filter_fn is None or module_filter_fn(module, "") ): @@ -77,7 +78,7 @@ def _swap_params( f"Does not support a root nn.Parameter with children: {module}" ) if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data) + new_data = ScaledGroupedMMTensor(module.data, use_triton_for_per_group_scales=use_triton) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index d3aaf615db..4e08e0a991 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -14,6 +14,10 @@ triton_fp8_col_major_jagged_colwise_scales, triton_fp8_row_major_jagged_rowwise_scales, ) +from torchao.prototype.moe_training.utils import ( + _to_2d_jagged_float8_tensor_colwise, + _to_2d_jagged_float8_tensor_rowwise, +) from torchao.prototype.moe_training.utils import _is_column_major @@ -22,6 +26,7 @@ def _scaled_grouped_mm( B_t: torch.Tensor, offs: torch.Tensor, out_dtype: Optional[torch.dtype] = torch.bfloat16, + use_triton_for_per_group_scales: bool = True, ) -> torch.Tensor: """ This function performs dynamic float8 quantization with row-wise scaling @@ -34,6 +39,7 @@ def _scaled_grouped_mm( and in column-major memory layout. offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. + use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True. """ return _Float8GroupedMM.apply( A, @@ -53,6 +59,7 @@ def forward( B_t: torch.Tensor, offs: torch.Tensor, out_dtype: Optional[torch.dtype] = torch.bfloat16, + use_triton_for_per_group_scales: bool = True, ) -> torch.Tensor: # torchao _scaled_grouped_mm only supports A=2D, B=3D. assert A.ndim == 2, "A must be 2D" @@ -136,9 +143,12 @@ def forward( # Store what we need for backward. ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) ctx.out_dtype = out_dtype + ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales # Perform scaled grouped GEMM and return result. 
# output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N) + assert not _is_column_major(A_fp8_row_major), "A must be row-major for output = A @ B" + assert _is_column_major(B_t_fp8_col_major), "B must be column-major for output = A @ B" return torch._scaled_grouped_mm( A_fp8_row_major, B_t_fp8_col_major, @@ -153,6 +163,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor): A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors out_dtype = ctx.out_dtype + use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales # Convert grad_output to float8, row-major for left operand of grouped GEMM # needed for grad_A: grad_output @ B @@ -175,6 +186,8 @@ def backward(ctx, grad_output: torch.Tensor): # # grad_A = grad_output @ B # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) + assert not _is_column_major(grad_output_fp8_row_major), "grad_output must be row-major for grad_A = grad_output @ B" + assert _is_column_major(B_fp8_col_major), "B must be column-major for grad_A = grad_output @ B" grad_A = torch._scaled_grouped_mm( grad_output_fp8_row_major, B_fp8_col_major, @@ -195,25 +208,47 @@ def backward(ctx, grad_output: torch.Tensor): # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups." # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups. - grad_output_t_fp8_row_major, grad_output_t_scales = ( - triton_fp8_row_major_jagged_rowwise_scales( + if use_triton_for_per_group_scales: + grad_output_t_fp8_row_major, grad_output_t_scales = triton_fp8_row_major_jagged_rowwise_scales( grad_output_t_row_major, offs, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) - ) + else: + grad_output_t_fp8_row_major, grad_output_t_scales = ( + _to_2d_jagged_float8_tensor_rowwise( + grad_output_t_row_major, + offs, + torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) - A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales( - A_col_major, - offs, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, ) + if use_triton_for_per_group_scales: + A_fp8_col_major, A_scales = ( + triton_fp8_col_major_jagged_colwise_scales( + A_col_major, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + ) + else: + A_fp8_col_major, A_scales = ( + _to_2d_jagged_float8_tensor_colwise( + A_col_major, + offs, + torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + ) # Compute grad_B = grad_output_t @ A. 
         # grad_B = grad_output_t @ A
         # grad_B = (N,M) @ (M,K) = (N,K)
+        assert not _is_column_major(grad_output_t_fp8_row_major), "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        assert _is_column_major(A_fp8_col_major), "A must be column-major for grad_B = grad_output_t @ A"
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 2a929d3b76..1c15a22978 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -12,9 +12,18 @@ class ScaledGroupedMMTensor(torch.Tensor):
 
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
+    use_triton_for_per_group_scales = True
 
-    def __init__(self, data: torch.Tensor):
+    def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
+        cls.use_triton_for_per_group_scales = use_triton_for_per_group_scales
+        return cls
+
+    def __init__(self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
         self._data = data
+        self.use_triton_for_per_group_scales = use_triton_for_per_group_scales
+
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self.use_triton_for_per_group_scales}, {self._data})"
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -31,5 +40,5 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         B_is_3d = B.dim() == 3
         has_offs = kwargs.get(cls.offs_arg_name) is not None
         if A_is_2d and B_is_3d and has_offs:
-            return _scaled_grouped_mm(*args, **kwargs)
+            return _scaled_grouped_mm(*args, use_triton_for_per_group_scales=self.use_triton_for_per_group_scales, **kwargs)
         return super().__torch_function__(func, types, args, kwargs)

From 911936352cfb0dcc9ff66e7feb6e0251fb6b7b55 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Jun 2025 07:55:31 -0700
Subject: [PATCH 2/5] lint

---
 .../benchmarks/benchmark_scaled_grouped_mm.py | 15 +++--
 .../moe_training/conversion_utils.py          |  5 +-
 .../moe_training/scaled_grouped_mm.py         | 65 +++++++++++--------
 torchao/prototype/moe_training/tensor.py      | 10 ++-
 4 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
index 2c77dda11d..8d1c1949b8 100644
--- a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
@@ -28,12 +28,14 @@ class ExperimentConfig:
     A_shape: tuple[int]
     B_shape: tuple[int]
 
+
 @dataclass(frozen=True)
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
     triton_speedup: float
 
+
 @dataclass(frozen=True)
 class Experiment:
     config: ExperimentConfig
@@ -46,7 +48,9 @@ def get_configs() -> List[ExperimentConfig]:
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, high_precision_dtype in itertools.product(
-        A_shapes, B_shapes, high_precision_dtypes,
+        A_shapes,
+        B_shapes,
+        high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
@@ -58,7 +62,9 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig, args: argparse.Namespace) -> ExperimentResult:
+def run_experiment(
+    config: ExperimentConfig, args: argparse.Namespace
+) -> ExperimentResult:
     # define test inputs
     A = torch.randn(
         *config.A_shape,
@@ -123,16 +129,15 @@ def forward_backward(A, B_t, offs, use_triton=True):
     start_time_ns = time.perf_counter_ns()
     forward_backward(A, B_t, offs, use_triton=True)
     triton_time_ns = time.perf_counter_ns() - start_time_ns
-    triton_time_us = triton_time_ns / 1e3 
+    triton_time_us = triton_time_ns / 1e3
 
-
     return ExperimentResult(
         torch_time_us=round(torch_time_us, 3),
         triton_time_us=round(triton_time_us, 3),
         triton_speedup=round(torch_time_us / triton_time_us, 3),
     )
 
+
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
index fc74d56f1e..6eab7e2b73 100644
--- a/torchao/prototype/moe_training/conversion_utils.py
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -27,6 +27,7 @@ class MoETrainingConfig(AOBaseConfig):
 
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
+
     use_triton_for_per_group_scales: bool = True
 
 
@@ -78,7 +79,9 @@ def _swap_params(
             f"Does not support a root nn.Parameter with children: {module}"
         )
     if not isinstance(module.data, ScaledGroupedMMTensor):
-        new_data = ScaledGroupedMMTensor(module.data, use_triton_for_per_group_scales=use_triton)
+        new_data = ScaledGroupedMMTensor(
+            module.data, use_triton_for_per_group_scales=use_triton
+        )
         return nn.Parameter(new_data, requires_grad=module.requires_grad)
     return module
 
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 4e08e0a991..90c2e2e4aa 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -15,10 +15,10 @@
     triton_fp8_row_major_jagged_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
+    _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )
-from torchao.prototype.moe_training.utils import _is_column_major
 
 
 def _scaled_grouped_mm(
@@ -147,8 +147,12 @@ def forward(
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), "A must be row-major for output = A @ B"
-        assert _is_column_major(B_t_fp8_col_major), "B must be column-major for output = A @ B"
+        assert not _is_column_major(A_fp8_row_major), (
+            "A must be row-major for output = A @ B"
+        )
+        assert _is_column_major(B_t_fp8_col_major), (
+            "B must be column-major for output = A @ B"
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
@@ -186,8 +190,12 @@ def backward(ctx, grad_output: torch.Tensor):
         #
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), "grad_output must be row-major for grad_A = grad_output @ B"
-        assert _is_column_major(B_fp8_col_major), "B must be column-major for grad_A = grad_output @ B"
+        assert not _is_column_major(grad_output_fp8_row_major), (
+            "grad_output must be row-major for grad_A = grad_output @ B"
+        )
+        assert _is_column_major(B_fp8_col_major), (
+            "B must be column-major for grad_A = grad_output @ B"
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
@@ -209,11 +217,13 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
         if use_triton_for_per_group_scales:
-            grad_output_t_fp8_row_major, grad_output_t_scales = triton_fp8_row_major_jagged_rowwise_scales(
-                grad_output_t_row_major,
-                offs,
-                output_dtype=torch.float8_e4m3fn,
-                round_scales_to_power_of_2=True,
+            grad_output_t_fp8_row_major, grad_output_t_scales = (
+                triton_fp8_row_major_jagged_rowwise_scales(
+                    grad_output_t_row_major,
+                    offs,
+                    output_dtype=torch.float8_e4m3fn,
+                    round_scales_to_power_of_2=True,
+                )
             )
         else:
             grad_output_t_fp8_row_major, grad_output_t_scales = (
@@ -223,32 +233,31 @@ def backward(ctx, grad_output: torch.Tensor):
                 torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
-
-            )
+        )
 
         if use_triton_for_per_group_scales:
-            A_fp8_col_major, A_scales = (
-                triton_fp8_col_major_jagged_colwise_scales(
-                    A_col_major,
-                    offs,
-                    output_dtype=torch.float8_e4m3fn,
-                    round_scales_to_power_of_2=True,
-                )
+            A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
+                A_col_major,
+                offs,
+                output_dtype=torch.float8_e4m3fn,
+                round_scales_to_power_of_2=True,
             )
         else:
-            A_fp8_col_major, A_scales = (
-                _to_2d_jagged_float8_tensor_colwise(
-                    A_col_major,
-                    offs,
-                    torch.float8_e4m3fn,
-                    round_scales_to_power_of_2=True,
-                )
+            A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
+                A_col_major,
+                offs,
+                torch.float8_e4m3fn,
+                round_scales_to_power_of_2=True,
             )
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
         # grad_B = (N,M) @ (M,K) = (N,K)
-        assert not _is_column_major(grad_output_t_fp8_row_major), "grad_output_t must be row-major for grad_B = grad_output_t @ A"
-        assert _is_column_major(A_fp8_col_major), "A must be column-major for grad_B = grad_output_t @ A"
+        assert not _is_column_major(grad_output_t_fp8_row_major), (
+            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        )
+        assert _is_column_major(A_fp8_col_major), (
+            "A must be column-major for grad_B = grad_output_t @ A"
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 1c15a22978..d3b6e96609 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -18,7 +18,9 @@ def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = Tru
         return cls
 
-    def __init__(self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
+    def __init__(
+        self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
+    ):
         self._data = data
         self.use_triton_for_per_group_scales = use_triton_for_per_group_scales
 
@@ -40,5 +42,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         B_is_3d = B.dim() == 3
         has_offs = kwargs.get(cls.offs_arg_name) is not None
         if A_is_2d and B_is_3d and has_offs:
-            return _scaled_grouped_mm(*args, use_triton_for_per_group_scales=self.use_triton_for_per_group_scales, **kwargs)
+            return _scaled_grouped_mm(
+                *args,
+                use_triton_for_per_group_scales=cls.use_triton_for_per_group_scales,
+                **kwargs,
+            )
         return super().__torch_function__(func, types, args, kwargs)

From 91a2787c0c6c8144e802d06d8f46d60983e01158 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Jun 2025 09:39:26 -0700
Subject: [PATCH 3/5] readability improvements

---
 .../benchmarks/benchmark_scaled_grouped_mm.py | 19 +++----
 .../moe_training/scaled_grouped_mm.py         | 50 ++++++++-----------
 2 files changed, 28 insertions(+), 41 deletions(-)

diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
index 8d1c1949b8..a347763fe6 100644
--- a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
@@ -110,19 +110,12 @@ def forward_backward(A, B_t, offs, use_triton=True):
         torch.cuda.synchronize()
 
     # benchmark torch
-    if args.compile:
-        compiled_fwd_bwd = torch.compile(forward_backward)
-        warmup(compiled_fwd_bwd, A, B_t, offs, use_triton=False)
-        start_time_ns = time.perf_counter_ns()
-        compiled_fwd_bwd(A, B_t, offs, use_triton=False)
-        torch_time_ns = time.perf_counter_ns() - start_time_ns
-        torch_time_us = torch_time_ns / 1e3
-    else:
-        warmup(forward_backward, A, B_t, offs, use_triton=False)
-        start_time_ns = time.perf_counter_ns()
-        forward_backward(A, B_t, offs, use_triton=False)
-        torch_time_ns = time.perf_counter_ns() - start_time_ns
-        torch_time_us = torch_time_ns / 1e3
+    torch_func = torch.compile(forward_backward) if args.compile else forward_backward
+    warmup(torch_func, A, B_t, offs, use_triton=False)
+    start_time_ns = time.perf_counter_ns()
+    torch_func(A, B_t, offs, use_triton=False)
+    torch_time_ns = time.perf_counter_ns() - start_time_ns
+    torch_time_us = torch_time_ns / 1e3
 
     # benchmark triton
     warmup(forward_backward, A, B_t, offs, use_triton=True)
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 90c2e2e4aa..f7d470e556 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -216,39 +216,33 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
-        if use_triton_for_per_group_scales:
-            grad_output_t_fp8_row_major, grad_output_t_scales = (
-                triton_fp8_row_major_jagged_rowwise_scales(
-                    grad_output_t_row_major,
-                    offs,
-                    output_dtype=torch.float8_e4m3fn,
-                    round_scales_to_power_of_2=True,
-                )
-            )
-        else:
-            grad_output_t_fp8_row_major, grad_output_t_scales = (
-                _to_2d_jagged_float8_tensor_rowwise(
-                    grad_output_t_row_major,
-                    offs,
-                    torch.float8_e4m3fn,
-                    round_scales_to_power_of_2=True,
-                )
-            )
+        per_group_rowwise_scale_func = (
+            triton_fp8_row_major_jagged_rowwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_rowwise
+        )
+        per_group_colwise_scale_func = (
+            triton_fp8_col_major_jagged_colwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_colwise
+        )
 
-        if use_triton_for_per_group_scales:
-            A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
-                A_col_major,
-                offs,
-                output_dtype=torch.float8_e4m3fn,
-                round_scales_to_power_of_2=True,
-            )
-        else:
-            A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
-                A_col_major,
-                offs,
-                torch.float8_e4m3fn,
-                round_scales_to_power_of_2=True,
-            )
+        grad_output_t_fp8_row_major, grad_output_t_scales = (
+            per_group_rowwise_scale_func(
+                grad_output_t_row_major,
+                offs,
+                torch.float8_e4m3fn,
+                round_scales_to_power_of_2=True,
+            )
+        )
+
+        A_fp8_col_major, A_scales = per_group_colwise_scale_func(
+            A_col_major,
+            offs,
+            torch.float8_e4m3fn,
+            round_scales_to_power_of_2=True,
+        )
+
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
         # grad_B = (N,M) @ (M,K) = (N,K)

From 40526cb5fa4129fd36fea710972726a39585ba84 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Jun 2025 09:51:43 -0700
Subject: [PATCH 4/5] grab use_triton from args instead of class attribute

---
 .../prototype/moe_training/conversion_utils.py |  1 +
 torchao/prototype/moe_training/tensor.py       | 15 ++++++++-------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
index 6eab7e2b73..4d65303b89 100644
--- a/torchao/prototype/moe_training/conversion_utils.py
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -28,6 +28,7 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
 
+    # temporary config flag for testing/benchmarking, will remove before graduating out of prototype
     use_triton_for_per_group_scales: bool = True
 
 
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index d3b6e96609..84a160630c 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -14,18 +14,14 @@ class ScaledGroupedMMTensor(torch.Tensor):
     offs_arg_name = "offs"
     use_triton_for_per_group_scales = True
 
-    def __new__(cls, data: torch.Tensor, use_triton_for_per_group_scales: bool = True):
-        cls.use_triton_for_per_group_scales = use_triton_for_per_group_scales
-        return cls
-
     def __init__(
         self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
     ):
         self._data = data
-        self.use_triton_for_per_group_scales = use_triton_for_per_group_scales
+        self._use_triton_for_per_group_scales = use_triton_for_per_group_scales
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self.use_triton_for_per_group_scales}, {self._data})"
+        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -42,9 +38,14 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         B_is_3d = B.dim() == 3
         has_offs = kwargs.get(cls.offs_arg_name) is not None
         if A_is_2d and B_is_3d and has_offs:
+            use_triton = (
+                A._use_triton_for_per_group_scales
+                if isinstance(A, cls)
+                else B._use_triton_for_per_group_scales
+            )
             return _scaled_grouped_mm(
                 *args,
-                use_triton_for_per_group_scales=cls.use_triton_for_per_group_scales,
+                use_triton_for_per_group_scales=use_triton,
                 **kwargs,
             )
         return super().__torch_function__(func, types, args, kwargs)

From a285fc8d68851a140a030eddce96f90e38123813 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 18 Jun 2025 09:55:32 -0700
Subject: [PATCH 5/5] add comment

---
 torchao/prototype/moe_training/tensor.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 84a160630c..8d7a8f815b 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -38,10 +38,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         B_is_3d = B.dim() == 3
         has_offs = kwargs.get(cls.offs_arg_name) is not None
         if A_is_2d and B_is_3d and has_offs:
+            # prefer to use B to check use_triton, as that will be the weight/nn.Parameter
+            # that is converted to ScaledGroupedMMTensor
            use_triton = (
-                A._use_triton_for_per_group_scales
-                if isinstance(A, cls)
-                else B._use_triton_for_per_group_scales
+                B._use_triton_for_per_group_scales
+                if isinstance(B, cls)
+                else A._use_triton_for_per_group_scales
             )
             return _scaled_grouped_mm(
                 *args,
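
---

A minimal usage sketch for anyone trying this series locally (not part of the patches). The benchmark itself runs as `python benchmark_scaled_grouped_mm.py`, with `--compile` to time the torch.compile path; module-level conversion goes through `quantize_` with `MoETrainingConfig`, as registered in conversion_utils.py. In the sketch below, the shapes, the even four-way token split, the fp8-capable CUDA device (e.g. H100), and the cumulative-end-offset interpretation of `offs` are illustrative assumptions, not something the patches pin down:

    import torch

    from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

    device = "cuda"  # assumes fp8-capable hardware, e.g. H100
    n_groups, dim, n_tokens = 4, 256, 1024

    # A: 2D (M, K) row-major token activations.
    A = torch.randn(n_tokens, dim, dtype=torch.bfloat16, device=device, requires_grad=True)

    # B_t: 3D (B, K, N) expert weights, transposed so the underlying memory is
    # column-major, which the forward pass asserts.
    B = torch.randn(n_groups, dim, dim, dtype=torch.bfloat16, device=device, requires_grad=True)
    B_t = B.transpose(-2, -1)

    # offsets marking each expert's token group along dim0 of A
    # (assumed cumulative end offsets here: [256, 512, 768, 1024])
    group_size = n_tokens // n_groups
    offs = torch.arange(group_size, n_tokens + 1, group_size, dtype=torch.int32, device=device)

    # the flag only changes the backward grad_B path: True uses the jagged
    # triton scaling kernels, False the pure-PyTorch _to_2d_jagged_float8_tensor_* fallbacks
    out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16,
                             use_triton_for_per_group_scales=True)
    out.sum().backward()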