diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..a5b3446ead 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -25,6 +25,8 @@ Int8WeightOnlyConfig, LinearActivationQuantizedTensor, quantize_, + PerRow, + PerTensor, ) from torchao.quantization.utils import compute_error from torchao.utils import ( @@ -32,13 +34,17 @@ TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_90, ) +from torchao.quantization.utils import compute_error if torch.version.hip is not None: pytest.skip( "ROCm support for MoE quantization is under development", allow_module_level=True, ) +from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op +from torchao.quantization.utils import _fbgemm_available +torch.manual_seed(0) class TestMoEQuantCompile(unittest.TestCase): DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k @@ -68,7 +74,6 @@ def _test_impl_moe_quant( .to(device) ) input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) - out = model(input) quantize_(model, config, cond_ffn_filter) @@ -363,6 +368,113 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph): fullgraph=fullgraph, ) +class TestFusedMoEQuant(unittest.TestCase): + DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k + + @parameterized.expand( + [ + ("multiple_tokens", 8), + ] + ) + def test_pytorch_scaled_grouped_gemm(self, name, num_tokens): + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + device = "cuda" + dtype = torch.bfloat16 + + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + + model_params = self.DEFAULT_PARAMS + + input_shape = (num_tokens, model_params[0]) + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + + model = ( + MOEFeedForwardAOQuantizable(*model_params, empty_init=False) + ) + model = model.to(dtype).to(device) + + out_orig = model(input) + + quantize_(model, config, cond_ffn_filter) + + w1 = model.experts.w1 + w2 = model.experts.w2 + w3 = model.experts.w3 + + router = model.router + top_k = model.top_k + + # preprocess + scores = router(input) # [T, E] + scores = torch.nn.functional.softmax(scores, dim=-1) + scores, expert_indices = torch.topk( + scores, top_k, dim=-1 + ) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype) # [T, A] + + out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores) + out2 = model(input) + + self.assertTrue(compute_error(out_orig, out) > 20) + self.assertTrue(compute_error(out_orig, out2) > 20) + + + @parameterized.expand( + [ + ("multiple_tokens", 8), + ] + ) + def test_fbgemm_scaled_grouped_gemm(self, name, num_tokens): + if not _fbgemm_available: + self.skipTest("Need FBGEMM available") + if not torch.cuda.is_available(): + self.skipTest("Need CUDA available") + if not is_sm_at_least_90(): + self.skipTest("Requires CUDA capability >= 9.0") + + device = "cuda" + dtype = torch.bfloat16 + + config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + + model_params = self.DEFAULT_PARAMS + + input_shape = (num_tokens, model_params[0]) + input = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + + model = ( + MOEFeedForwardAOQuantizable(*model_params, empty_init=False, use_fbgemm_kernel=True) + ) + model = model.to(dtype).to(device) + + out_orig = model(input) + + quantize_(model, config, 
cond_ffn_filter) + + w1 = model.experts.w1 + w2 = model.experts.w2 + w3 = model.experts.w3 + + router = model.router + top_k = model.top_k + + # preprocess + scores = router(input) # [T, E] + scores = torch.nn.functional.softmax(scores, dim=-1) + scores, expert_indices = torch.topk( + scores, top_k, dim=-1 + ) # [T, A], [T, A] + scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype) # [T, A] + + out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, use_fbgemm_kernel=True) + out2 = model(input) + + self.assertTrue(compute_error(out_orig, out) > 20) + self.assertTrue(compute_error(out_orig, out2) > 20) if __name__ == "__main__": unittest.main() diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py index 685323843d..17be122ee0 100644 --- a/torchao/_models/mixtral-moe/model.py +++ b/torchao/_models/mixtral-moe/model.py @@ -7,12 +7,14 @@ from typing import Optional import torch +import torchao import torch.nn as nn from torch import Tensor from torch.nn import functional as F from torchao.prototype.moe_quant.utils import FakeExtraDimTensor - +from torchao.quantization.utils import _torchtitan_available +from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op def find_multiple(n: int, k: int) -> int: if n % k == 0: @@ -34,6 +36,7 @@ class ModelArgs: norm_eps: float = 1e-5 num_experts: int = 8 num_activated_experts: int = 2 + use_fbgemm_kernel: bool = False def __post_init__(self): if self.n_local_heads == -1: @@ -225,43 +228,6 @@ def forward( y = self.wo(y) return y - -# class ConditionalFeedForward(nn.Module): -# def __init__(self, config): -# super().__init__() -# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) -# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) -# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) - -# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: -# w1_weights = self.w1[expert_indices] # [T, A, D, D] -# w3_weights = self.w3[expert_indices] # [T, A, D, D] -# w2_weights = self.w2[expert_indices] # [T, A, D, D] -# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) -# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) -# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) -# return expert_outs - - -# class MOEFeedForward(nn.Module): -# def __init__(self, config) -> None: -# super().__init__() -# self.gate = nn.Linear(config.dim, config.num_experts, bias=False) -# self.cond_ffn = ConditionalFeedForward(config) -# self.dim = config.dim -# self.num_activated_experts = config.num_activated_experts -# def forward(self, x: Tensor) -> Tensor: -# x = x.view(-1, self.dim) -# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts -# # x: [T, D] -# scores = self.gate(x) # [T, E] -# expert_weights = F.softmax(scores, dim=-1) -# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] -# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] -# expert_outs = self.cond_ffn(x, expert_indices) -# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) - - class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() @@ -347,7 +313,9 @@ def __init__(self, config): torch.empty(config.num_experts, config.intermediate_size, config.dim) ) # E, I, D self.num_experts = config.num_experts + self.use_fbgemm_kernel = config.use_fbgemm_kernel + # TODO 
move this into kernels, single token decomposed kernel, multi token...etc def forward( self, x: Tensor, # T, D @@ -382,6 +350,14 @@ def forward( .unsqueeze(-1) ) return final_out + # fp8 dq moe + elif ( + isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor) and + isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout) + ): + + final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, expert_weights, use_fbgemm_kernel=self.use_fbgemm_kernel) + return final_out else: expert_list = [x for x in range(self.num_experts)] diff --git a/torchao/prototype/moe_quant/kernels.py b/torchao/prototype/moe_quant/kernels.py new file mode 100644 index 0000000000..755f16e8a6 --- /dev/null +++ b/torchao/prototype/moe_quant/kernels.py @@ -0,0 +1,167 @@ +import torch +import torch.nn.functional as F +import warnings +from torchao.quantization.utils import _torchtitan_available, _fbgemm_available + +grouped_gemm_fp8_rowwise = None +if _fbgemm_available: + try: + from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm_fp8_rowwise + except: + pass + +__all__ = ["fp8_dq_moe_op", + "manual_pad", + "torchtitan_pad", + ] + + +def fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, fast_accum=True, use_fbgemm_kernel=True): + # parameters + orig_in_shape = input.shape + input.reshape(-1, orig_in_shape[-1]) + num_tokens, dim = input.shape + num_experts, expert_dim, _ = w1.shape + scores = scores.view(-1, scores.shape[-1]) + top_k = scores.shape[-1] + total_activations = num_tokens*top_k + + # preprocess indices + expert_indices = expert_indices.view(-1) + activation_shuffle = expert_indices.argsort(stable=True) + token_shuffle = activation_shuffle.div(top_k).floor().to(torch.int64) + num_tokens_per_expert = torch.histc(expert_indices, bins=num_experts, min=0, max=num_experts).to(torch.int32) + + # get data for weights + w1_fp8 = w1.original_weight_tensor.tensor_impl.float8_data + w1_scale = w1.original_weight_tensor.tensor_impl.scale.squeeze() + w1_qfunc = w1.input_quant_func + w1_quant_kwargs = w1.quant_kwargs + + w3_fp8 = w3.original_weight_tensor.tensor_impl.float8_data + w3_scale = w3.original_weight_tensor.tensor_impl.scale.squeeze() + + w2_fp8 = w2.original_weight_tensor.tensor_impl.float8_data + w2_scale = w2.original_weight_tensor.tensor_impl.scale.squeeze() + w2_qfunc = w2.input_quant_func + w2_quant_kwargs = w2.quant_kwargs + + # quantize input + q_input = w1_qfunc(input, **w1_quant_kwargs) + q_input_data = q_input.tensor_impl.float8_data + q_input_scale = q_input.tensor_impl.scale.squeeze() + + + if use_fbgemm_kernel: + # quant without padding + input_fp8 = q_input_data[token_shuffle] + input_scale = q_input_scale[token_shuffle] if q_input_scale.numel()>1 else q_input_scale + + @torch._dynamo.disable() + def do_group_gemms(input_fp8, input_scale, w1_fp8, w1_scale, w2_fp8, w2_scale, w3_fp8, w3_scale, num_tokens_per_expert, w2_qfunc, w2_quant_kwargs): + assert grouped_gemm_fp8_rowwise is not None, "fbgemm kernel requires fbgemm-gpu-genai to be installed: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/README.md" + y1 = grouped_gemm_fp8_rowwise(input_fp8, w1_fp8.reshape(-1, w1_fp8.shape[-1]), num_tokens_per_expert, input_scale, w1_scale.reshape(-1), use_fast_accum=True, _use_warp_specialization=False) + y3 = grouped_gemm_fp8_rowwise(input_fp8, w3_fp8.reshape(-1, w3_fp8.shape[-1]), num_tokens_per_expert, input_scale, w3_scale.reshape(-1), 
use_fast_accum=True, _use_warp_specialization=False) + y = F.silu(y1)*y3 + y_q = w2_qfunc(y, **w2_quant_kwargs) + y_fp8 = y_q.tensor_impl.float8_data + y_scale = y_q.tensor_impl.scale.squeeze() + + # TODO use _scatter_add_indices to combine the last group gemm with the final_out calculation + out = grouped_gemm_fp8_rowwise(y_fp8, w2_fp8.view(-1, w2_fp8.shape[-1]), num_tokens_per_expert, y_scale, w2_scale.view(-1), use_fast_accum=fast_accum, _use_warp_specialization=False) + return out + + out = do_group_gemms(input_fp8, input_scale, w1_fp8, w1_scale, w2_fp8, w2_scale, w3_fp8, w3_scale, num_tokens_per_expert, w2_qfunc, w2_quant_kwargs) + + # unpad and combine output with weights + sorted_scores = scores.reshape(-1,1)[activation_shuffle] + out = out*sorted_scores + + # sum weighted outputs + final_out = torch.zeros_like(input) + final_out = final_out.scatter_add( + dim=0, + index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64), + src=out + ) + final_out = final_out.reshape(orig_in_shape) + return final_out + + else: + # padding + alignment = 16 + if _torchtitan_available: + num_ranks = 1 + padded_indices, m_sizes, m_offsets = torchtitan_pad(num_tokens_per_expert, alignment, num_ranks) + else: + padded_indices, m_sizes, m_offsets = manual_pad(num_tokens_per_expert, alignment) + + pad_len = padded_indices.shape[0] + valid_values = padded_indices >= 0 + + # shuffle/pad input + input_fp8 = torch.zeros((pad_len, q_input_data.shape[-1]), dtype=q_input_data.dtype, device=q_input_data.device) + input_scale = torch.zeros(pad_len, dtype=q_input_scale.dtype, device=q_input_scale.device) + input_fp8[valid_values] = q_input_data[token_shuffle] + input_scale[valid_values] = q_input_scale[token_shuffle] if q_input_scale.numel()>1 else q_input_scale + + + y1 = torch._scaled_grouped_mm(input_fp8, w1_fp8.transpose(-2, -1), input_scale, w1_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) + y3 = torch._scaled_grouped_mm(input_fp8, w3_fp8.transpose(-2, -1), input_scale, w3_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) + y = F.silu(y1)*y3 + y_q = w2_qfunc(y, **w2_quant_kwargs) + + y_fp8 = y_q.tensor_impl.float8_data + y_scale = y_q.tensor_impl.scale.squeeze() + out = torch._scaled_grouped_mm(y_fp8, w2_fp8.transpose(-2, -1), y_scale, w2_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) + + # unpad and combine output with weights + out = out[valid_values] + sorted_scores = scores.reshape(-1,1)[activation_shuffle] + out = out*sorted_scores + + # sum weighted outputs + final_out = torch.zeros_like(input) + final_out = final_out.scatter_add( + dim=0, + index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64), + src=out + ) + final_out = final_out.reshape(orig_in_shape) + return final_out + +def torchtitan_pad(num_tokens_per_expert, alignment, num_ranks): + from torchtitan.experiments.kernels.moe.indices import generate_permute_indices + num_experts = num_tokens_per_expert.shape[0] + + # pad to nearest multiple of alignment that's greater than 0 + padded_sizes = (((num_tokens_per_expert + (num_tokens_per_expert==0))/alignment).ceil() * alignment) + pad_len = int(padded_sizes.sum().item()) + + padded_indices, m_sizes, m_offsets = generate_permute_indices( + num_tokens_per_expert, + num_experts, + num_ranks, + pad_len, + alignment, + use_cpu=False + ) + return padded_indices, m_sizes, m_offsets + +def manual_pad(num_tokens_per_expert, alignment): + num_experts = 
num_tokens_per_expert.shape[0] + + m_sizes = ((((num_tokens_per_expert + (num_tokens_per_expert==0))/alignment).ceil() * alignment)).to(torch.int32) + pad_len = int(m_sizes.sum().item()) + + padded_indices = torch.zeros(pad_len, dtype=torch.int32, device=num_tokens_per_expert.device)-1 + start_tok_index = 0 + start_pad_index = 0 + for i in range(num_experts): + end_tok_index = int(start_tok_index+num_tokens_per_expert[i].item()) + end_pad_index = int(start_pad_index+num_tokens_per_expert[i].item()) + padded_indices[start_pad_index:end_pad_index] = torch.arange(start_tok_index, end_tok_index, dtype=torch.int32, device=num_tokens_per_expert.device) + start_tok_index = end_tok_index + start_pad_index = start_pad_index + int(m_sizes[i].item()) + m_offsets = m_sizes.cumsum(0).to(torch.int32) + return padded_indices, m_sizes, m_offsets diff --git a/torchao/prototype/moe_quant/quantizable_moe_modules.py b/torchao/prototype/moe_quant/quantizable_moe_modules.py index d806f50b4f..a83c0538de 100644 --- a/torchao/prototype/moe_quant/quantizable_moe_modules.py +++ b/torchao/prototype/moe_quant/quantizable_moe_modules.py @@ -1,9 +1,11 @@ import torch +import torchao import torch.nn.functional as F from torch import Tensor, nn from torchao.prototype.moe_quant.utils import FakeExtraDimTensor - +from torchao.quantization.utils import _torchtitan_available +from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op class MOEFeedForwardAOQuantizable(nn.Module): def __init__( @@ -16,11 +18,12 @@ def __init__( shared_expert=None, return_scores=False, empty_init=True, + use_fbgemm_kernel=False, ) -> None: super().__init__() self.router = nn.Linear(hidden_dim, num_experts, bias=False) self.experts = ConditionalFeedForwardAOQuantizable( - num_experts, hidden_dim, expert_dim, act_fn, empty_init + num_experts, hidden_dim, expert_dim, act_fn, empty_init, use_fbgemm_kernel, ) self.hidden_dim = hidden_dim self.top_k = top_k @@ -28,7 +31,7 @@ def __init__( self.return_scores = return_scores def forward(self, x: Tensor) -> Tensor: - batch_size = x.shape[0] + shape_no_dim = x.shape[:-1] x = x.view(-1, self.hidden_dim) # x: [T, D] scores = self.router(x) # [T, E] scores = F.softmax(scores, dim=-1) @@ -40,15 +43,16 @@ def forward(self, x: Tensor) -> Tensor: out = self.experts(x, expert_indices, scores, self.top_k) if self.shared_expert: out += self.shared_expert(x) - + out = out.reshape(*shape_no_dim, -1) + if self.return_scores: - return out.reshape(batch_size, -1, self.hidden_dim), scores + return out, scores else: - return out.reshape(batch_size, -1, self.hidden_dim) + return out class ConditionalFeedForwardAOQuantizable(nn.Module): - def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True): + def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True, use_fbgemm_kernel=False): super().__init__() if empty_init: self.w1 = nn.Parameter( @@ -74,12 +78,13 @@ def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True) self.act_fn = act_fn self.hidden_dim = hidden_dim self.expert_dim = expert_dim + self.use_fbgemm_kernel = use_fbgemm_kernel def forward( self, x: Tensor, # T, D expert_indices: Tensor, # T, A - expert_weights: Tensor, # T, A + scores: Tensor, # T, A top_k: int, ) -> Tensor: num_tokens, _hidden_dim = x.shape @@ -105,11 +110,20 @@ def forward( # combine outputs final_out = ( - (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) + (torch.cat(outs, dim=0) * scores.view(-1, 1)) .sum(dim=0) .reshape(x.shape) ) return final_out + + # fp8 dq moe + 
elif ( + isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor) and + isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout) + ): + final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, scores, use_fbgemm_kernel=self.use_fbgemm_kernel) + return final_out + else: expert_list = [x for x in range(self.num_experts)] @@ -172,7 +186,7 @@ def group_tokens_by_expert( # weigh outputs ordered_outs = torch.cat(outs, dim=0) # [T*A, D] - ordered_token_activation_weights = expert_weights.view(-1, 1)[ + ordered_token_activation_weights = scores.view(-1, 1)[ ordered_token_activations ].view(-1, 1) # [T*A, 1] weighted_ordered_outs = ( diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 8f2554849c..12217d2eb7 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -54,6 +54,9 @@ _lm_eval_available = importlib.util.find_spec("lm_eval") is not None +_torchtitan_available = importlib.util.find_spec("torchtitan") is not None + +_fbgemm_available = importlib.util.find_spec("fbgemm_gpu") is not None # basic SQNR def compute_error(x, y):
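
For reference, the bookkeeping that fp8_dq_moe_op wraps around the grouped GEMMs (flattening the top-k routing decisions, sorting activations by expert so each expert's rows are contiguous, weighting the expert outputs by the routing scores, and scatter-adding them back onto their source tokens) can be written as a plain bf16 sketch with no quantization and no padding. This is illustrative only and assumes the weight layout used by ConditionalFeedForwardAOQuantizable (w1, w3: [E, I, D]; w2: [E, D, I]); moe_shuffle_combine_reference is a hypothetical helper name, not part of this patch.

import torch
import torch.nn.functional as F

def moe_shuffle_combine_reference(x, w1, w2, w3, expert_indices, scores):
    # Dense bf16 reference (no quantization, no padding) for the shuffle /
    # per-expert compute / scatter_add pattern used by fp8_dq_moe_op.
    # x: [T, D]; w1, w3: [E, I, D]; w2: [E, D, I];
    # expert_indices, scores: [T, A] from top-k routing.
    num_tokens, dim = x.shape
    num_experts = w1.shape[0]
    top_k = scores.shape[-1]

    # Flatten the T*A routing decisions and sort them by expert id so that
    # every expert's rows are contiguous.
    flat_experts = expert_indices.reshape(-1)                      # [T*A]
    shuffle = flat_experts.argsort(stable=True)                    # [T*A]
    token_of_activation = shuffle // top_k                         # [T*A]
    counts = torch.bincount(flat_experts, minlength=num_experts)   # [E]

    xs = x[token_of_activation]                                    # [T*A, D]
    out = torch.empty_like(xs)
    start = 0
    for e in range(num_experts):
        end = start + int(counts[e])
        xe = xs[start:end]
        h = F.silu(xe @ w1[e].t()) * (xe @ w3[e].t())
        out[start:end] = h @ w2[e].t()
        start = end

    # Weight each activation by its routing score, then sum the activations
    # that belong to the same source token.
    out = out * scores.reshape(-1, 1)[shuffle]
    final = torch.zeros_like(x)
    final.scatter_add_(
        0, token_of_activation.unsqueeze(-1).expand(-1, dim), out
    )
    return final

The quantized path in kernels.py applies the same shuffle to the fp8 data and row-wise scales and replaces the per-expert loop with a single grouped GEMM call.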
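Similarly, the alignment padding that manual_pad builds for the torch._scaled_grouped_mm fallback (each expert's row group rounded up to a non-zero multiple of the 16-row alignment, with -1 marking pure padding rows) can be sketched as below. pad_groups_reference is a hypothetical name used only for illustration; it mirrors manual_pad's bookkeeping rather than replacing it.

import torch

def pad_groups_reference(num_tokens_per_expert, alignment=16):
    # Round every expert's row count up to a non-zero multiple of `alignment`
    # and build an index map from padded rows back to source rows, with -1
    # marking rows that are pure padding.
    counts = num_tokens_per_expert.to(torch.int64)
    # Empty groups still get one alignment block so every expert has a
    # non-zero group in the grouped GEMM.
    nonzero = counts + (counts == 0).to(torch.int64)
    m_sizes = ((nonzero + alignment - 1) // alignment) * alignment
    m_offsets = m_sizes.cumsum(0).to(torch.int32)
    pad_len = int(m_sizes.sum())

    padded_indices = torch.full(
        (pad_len,), -1, dtype=torch.int64, device=counts.device
    )
    src_row, pad_row = 0, 0
    for e in range(counts.numel()):
        n = int(counts[e])
        padded_indices[pad_row:pad_row + n] = torch.arange(
            src_row, src_row + n, device=counts.device
        )
        src_row += n
        pad_row += int(m_sizes[e])
    return padded_indices, m_sizes.to(torch.int32), m_offsets

In the fallback path the padding rows stay zero in input_fp8 and are dropped afterwards via out[valid_values], so they never contribute to the final scatter_add.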