diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..454cae2106 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -8,13 +8,12 @@ from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl from torchao.prototype.moe_quant.quantizable_moe_modules import ( - MOEFeedForwardAOQuantizable, + MoEFeedForwardAOQuantizable, ) from torchao.prototype.moe_quant.utils import ( FakeExtraDimTensor, MoEQuantConfig, UseFakeExtraDimTensor, - cond_ffn_filter, ) from torchao.quantization.quant_api import ( AffineQuantizedTensor, @@ -24,12 +23,14 @@ Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, LinearActivationQuantizedTensor, + PerRow, quantize_, ) from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90, ) @@ -40,8 +41,12 @@ ) +def _moe_filter(mod, fqn): + return isinstance(mod, MoEFeedForwardAOQuantizable) + + class TestMoEQuantCompile(unittest.TestCase): - DEFAULT_PARAMS = (512, 256, 8, 2) # hidden_dim, expert_dim, num_experts, top_k + DEFAULT_PARAMS = (8, 512, 256, 2) # num_experts, hidden_dim, expert_dim, top_k @torch.no_grad() def _test_impl_moe_quant( @@ -49,11 +54,12 @@ def _test_impl_moe_quant( config, num_tokens=1, model_params=None, - base_class=AffineQuantizedTensor, + base_class=None, tensor_impl_class=None, dtype=torch.bfloat16, device="cuda", fullgraph=False, + decompose_grouped_mm=True, ): """ Tests moe quant for techniques using fake extra dim @@ -61,9 +67,13 @@ def _test_impl_moe_quant( if model_params is None: model_params = self.DEFAULT_PARAMS - input_shape = (num_tokens, model_params[0]) + input_shape = (num_tokens, model_params[1]) model = ( - MOEFeedForwardAOQuantizable(*model_params, empty_init=False) + MoEFeedForwardAOQuantizable( + *model_params, + empty_init=False, + decompose_grouped_mm=decompose_grouped_mm, + ) .to(dtype) .to(device) ) @@ -71,24 +81,27 @@ def _test_impl_moe_quant( out = model(input) - quantize_(model, config, cond_ffn_filter) + if config is not None: + quantize_(model, config, _moe_filter) if ( isinstance(config, MoEQuantConfig) and config.use_fake_extra_dim_tensor == UseFakeExtraDimTensor.TRUE ): - self.assertIsInstance(model.experts.w1, FakeExtraDimTensor) + self.assertIsInstance(model.experts.up_proj, FakeExtraDimTensor) if base_class is not None: - self.assertIsInstance(model.experts.w1.head_tensor, base_class) + self.assertIsInstance(model.experts.up_proj.head_tensor, base_class) if tensor_impl_class is not None: self.assertIsInstance( - model.experts.w1.head_tensor.tensor_impl, tensor_impl_class + model.experts.up_proj.head_tensor.tensor_impl, tensor_impl_class ) else: if base_class is not None: - self.assertIsInstance(model.experts.w1, base_class) + self.assertIsInstance(model.experts.up_proj, base_class) if tensor_impl_class is not None: - self.assertIsInstance(model.experts.w1.tensor_impl, tensor_impl_class) + self.assertIsInstance( + model.experts.up_proj.tensor_impl, tensor_impl_class + ) out_q = model(input) @@ -109,251 +122,238 @@ def _test_impl_moe_quant( @parameterized.expand( [ - ("single_token", 1, False), - ("multiple_tokens", 8, False), - ] - ) - def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - 
self.skipTest("Test only enabled for 2.5+") - - config = MoEQuantConfig( - Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE - ) - tensor_impl_class = TensorCoreTiledAQTTensorImpl - - self._test_impl_moe_quant( - config=config, - num_tokens=num_tokens, - tensor_impl_class=tensor_impl_class, - fullgraph=fullgraph, - ) - - @parameterized.expand( - [ - ("single_token", 1, True), - ("multiple_tokens", 8, False), + ( + "single_token_grouped_mm_base", + 1, + True, + UseFakeExtraDimTensor.FALSE, + False, + ), + ( + "multiple_token_grouped_mm_base", + 8, + False, + UseFakeExtraDimTensor.FALSE, + False, + ), ] ) - def test_int4wo_base(self, name, num_tokens, fullgraph): + def test_noquant( + self, + name, + num_tokens, + fullgraph, + use_fake_extra_dim_tensor, + decompose_grouped_mm, + ): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") + if not (decompose_grouped_mm or TORCH_VERSION_AT_LEAST_2_8): + self.skipTest("Test only enabled for 2.8+ for grouped mm") - config = MoEQuantConfig(Int4WeightOnlyConfig()) - tensor_impl_class = TensorCoreTiledAQTTensorImpl + config = None self._test_impl_moe_quant( config=config, num_tokens=num_tokens, - tensor_impl_class=tensor_impl_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) @parameterized.expand( [ - ("single_token", 1, False), - ("multiple_tokens", 8, False), + ("single_token_base", 1, True, UseFakeExtraDimTensor.FALSE), + ("multiple_token_base", 8, False, UseFakeExtraDimTensor.FALSE), + ("single_token_fake", 1, False, UseFakeExtraDimTensor.TRUE), + ("multiple_token_fake", 8, False, UseFakeExtraDimTensor.TRUE), ] ) - def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): + def test_int4wo(self, name, num_tokens, fullgraph, use_fake_extra_dim_tensor): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( - Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + base_config=Int4WeightOnlyConfig(), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, ) - tensor_impl_class = PlainAQTTensorImpl + base_class = AffineQuantizedTensor + tensor_impl_class = TensorCoreTiledAQTTensorImpl + decompose_grouped_mm = True self._test_impl_moe_quant( config=config, num_tokens=num_tokens, tensor_impl_class=tensor_impl_class, + base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) @parameterized.expand( [ - ("single_token", 1, True), - ("multiple_tokens", 8, False), + ("single_token_base", 1, True, UseFakeExtraDimTensor.FALSE), + ("multiple_token_base", 8, False, UseFakeExtraDimTensor.FALSE), + ("single_token_fake", 1, False, UseFakeExtraDimTensor.TRUE), + ("multiple_token_fake", 8, False, UseFakeExtraDimTensor.TRUE), ] ) - def test_int8wo_base(self, name, num_tokens, fullgraph): + def test_int8wo(self, name, num_tokens, fullgraph, use_fake_extra_dim_tensor): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_6: self.skipTest("Test only enabled for 2.6+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) + config = MoEQuantConfig( + base_config=Int8WeightOnlyConfig(), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, + ) tensor_impl_class = PlainAQTTensorImpl + base_class = AffineQuantizedTensor + 
decompose_grouped_mm = True self._test_impl_moe_quant( config=config, num_tokens=num_tokens, tensor_impl_class=tensor_impl_class, + base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) @parameterized.expand( [ - ("single_token", 1, True), - ("multiple_tokens", 8, False), + ("single_token_base", 1, True, UseFakeExtraDimTensor.FALSE), + ("multiple_token_base", 8, False, UseFakeExtraDimTensor.FALSE), + ("single_token_fake", 1, False, UseFakeExtraDimTensor.TRUE), + ("multiple_token_fake", 8, False, UseFakeExtraDimTensor.TRUE), ] ) - def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): + def test_int8wo_cpu(self, name, num_tokens, fullgraph, use_fake_extra_dim_tensor): if not TORCH_VERSION_AT_LEAST_2_6: self.skipTest("Test only enabled for 2.6+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) + config = MoEQuantConfig( + base_config=Int8WeightOnlyConfig(), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, + ) tensor_impl_class = PlainAQTTensorImpl + base_class = AffineQuantizedTensor + decompose_grouped_mm = True self._test_impl_moe_quant( config=config, num_tokens=num_tokens, tensor_impl_class=tensor_impl_class, + base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, device="cpu", ) @parameterized.expand( [ - ("multiple_tokens", 32, False), + ("multiple_tokens_base", 32, False, UseFakeExtraDimTensor.FALSE), + ("multiple_tokens_fake", 32, False, UseFakeExtraDimTensor.TRUE), ] ) - def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): + def test_int8dq(self, name, num_tokens, fullgraph, use_fake_extra_dim_tensor): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not TORCH_VERSION_AT_LEAST_2_5: self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( - Int8DynamicActivationInt8WeightConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) - base_class = LinearActivationQuantizedTensor - - self._test_impl_moe_quant( - model_params=(512, 256, 2, 2), - config=config, - num_tokens=num_tokens, - base_class=base_class, - fullgraph=fullgraph, + base_config=Int8DynamicActivationInt8WeightConfig(), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, ) - - @parameterized.expand( - [ - ("multiple_tokens", 32, False), - ] - ) - def test_int8dq_base(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") - - config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor + decompose_grouped_mm = True self._test_impl_moe_quant( - model_params=(512, 256, 2, 2), + model_params=(2, 512, 256, 2), config=config, num_tokens=num_tokens, base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) @parameterized.expand( [ - ("single_token", 1, False), - ("multiple_tokens", 8, False), + ("single_token_base", 1, True, UseFakeExtraDimTensor.FALSE), + ("multiple_token_base", 8, False, UseFakeExtraDimTensor.FALSE), + ("single_token_fake", 1, False, UseFakeExtraDimTensor.TRUE), + ("multiple_token_fake", 8, False, UseFakeExtraDimTensor.TRUE), ] ) - def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph): + def test_fp8wo(self, name, num_tokens, fullgraph, use_fake_extra_dim_tensor): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") config = MoEQuantConfig( - 
Float8WeightOnlyConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) - tensor_impl_class = Float8AQTTensorImpl - - self._test_impl_moe_quant( - config=config, - num_tokens=num_tokens, - tensor_impl_class=tensor_impl_class, - fullgraph=fullgraph, + base_config=Float8WeightOnlyConfig(), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, ) - - @parameterized.expand( - [ - ("single_token", 1, True), - ("multiple_tokens", 8, False), - ] - ) - def test_fp8wo_base(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") - if not is_sm_at_least_90(): - self.skipTest("Requires CUDA capability >= 9.0") - - config = MoEQuantConfig(Float8WeightOnlyConfig()) tensor_impl_class = Float8AQTTensorImpl + base_class = AffineQuantizedTensor + decompose_grouped_mm = True self._test_impl_moe_quant( config=config, num_tokens=num_tokens, tensor_impl_class=tensor_impl_class, + base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) @parameterized.expand( [ - ("single_token", 1, False), - ("multiple_tokens", 8, False), + ("single_token_base", 1, True, UseFakeExtraDimTensor.FALSE, True), + ("multiple_token_base", 8, False, UseFakeExtraDimTensor.FALSE, True), + ("single_token_fake", 1, False, UseFakeExtraDimTensor.TRUE, True), + ("multiple_token_fake", 8, False, UseFakeExtraDimTensor.TRUE, True), + ( + "single_token_grouped_mm_base", + 1, + True, + UseFakeExtraDimTensor.FALSE, + False, + ), + ( + "multiple_token_grouped_mm_base", + 8, + True, + UseFakeExtraDimTensor.FALSE, + False, + ), ] ) - def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph): + def test_fp8dq( + self, + name, + num_tokens, + fullgraph, + use_fake_extra_dim_tensor, + decompose_grouped_mm, + ): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") + if not (decompose_grouped_mm or TORCH_VERSION_AT_LEAST_2_8): + self.skipTest("Test only enabled for 2.8+ for grouped mm") config = MoEQuantConfig( - Float8DynamicActivationFloat8WeightConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) - base_class = LinearActivationQuantizedTensor - - self._test_impl_moe_quant( - config=config, - num_tokens=num_tokens, - base_class=base_class, - fullgraph=fullgraph, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + use_fake_extra_dim_tensor=use_fake_extra_dim_tensor, ) - - @parameterized.expand( - [ - ("single_token", 1, True), - ("multiple_tokens", 8, False), - ] - ) - def test_fp8dq_base(self, name, num_tokens, fullgraph): - if not torch.cuda.is_available(): - self.skipTest("Need CUDA available") - if not is_sm_at_least_90(): - self.skipTest("Requires CUDA capability >= 9.0") - - config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig()) base_class = LinearActivationQuantizedTensor self._test_impl_moe_quant( @@ -361,6 +361,7 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph): num_tokens=num_tokens, base_class=base_class, fullgraph=fullgraph, + decompose_grouped_mm=decompose_grouped_mm, ) diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 11a53043ad..0cc94d409e 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -12,6 +12,7 @@ import torch import torch._dynamo.config import torch._inductor.config +from model import MoEFeedForward from torchao.utils import get_model_size_in_bytes @@ -199,7 +200,9 @@ def main( 
checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), compile: bool = True, compile_prefill: bool = False, + compile_mode: str = "reduce-overhead", moe_quant: Optional[str] = None, + decompose_grouped_mm: bool = False, profile: Optional[Path] = None, memory_profile: Optional[Path] = None, device="cuda", @@ -212,6 +215,13 @@ def main( precision = torch.bfloat16 is_chat = "chat" in str(checkpoint_path) + if batch_size > 1 and moe_quant is None: + print( + "Warning: Detected no moe_quant but batchsize>1. The default MoE implementation uses a lot of memory when batched," + + " if it OOMs you can instead run without quantization by specifying --moe_quant noquant which uses the AO quantizable" + + "module without quantization to run the quantizable module without quantization" + ) + if device == "cuda" and memory_profile is not None: torch.cuda.memory._record_memory_history( True, trace_alloc_max_entries=500000, trace_alloc_record_context=True @@ -236,10 +246,10 @@ def main( ] ) - from torchao.prototype.moe_quant.utils import ( + from torchao.prototype.moe_quant import ( + MoEMapping, MoEQuantConfig, UseFakeExtraDimTensor, - cond_ffn_filter, ) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, @@ -255,71 +265,72 @@ def main( if moe_quant: torch._dynamo.config.capture_dynamic_output_shape_ops = True - config = None + config = MoEQuantConfig( + mapping=MoEMapping( + target_module_type=MoEFeedForward, + decompose_grouped_mm=decompose_grouped_mm, + ) + ) if "int8wo-base" in moe_quant: - config = MoEQuantConfig(Int8WeightOnlyConfig()) + config.base_config = Int8WeightOnlyConfig() elif "int8wo" in moe_quant: - config = MoEQuantConfig( - Int8WeightOnlyConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) + config.base_config = Int8WeightOnlyConfig() + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE elif "int8dq-base" in moe_quant: - config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) + config.base_config = Int8DynamicActivationInt8WeightConfig() elif "int8dq" in moe_quant: - config = MoEQuantConfig( - Int8DynamicActivationInt8WeightConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) + config.base_config = Int8DynamicActivationInt8WeightConfig() + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE elif "int4wo-base" in moe_quant: - config = MoEQuantConfig(Int4WeightOnlyConfig()) + config.base_config = Int4WeightOnlyConfig() elif "int4wo" in moe_quant: - config = MoEQuantConfig( - Int4WeightOnlyConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) + config.base_config = Int4WeightOnlyConfig() + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE elif "fp8wo-base" in moe_quant: - config = MoEQuantConfig(Float8WeightOnlyConfig()) + config.base_config = Float8WeightOnlyConfig() elif "fp8wo" in moe_quant: - config = MoEQuantConfig( - Float8WeightOnlyConfig(), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, - ) + config.base_config = Float8WeightOnlyConfig() + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE elif "fp8dq-base" in moe_quant: - config = MoEQuantConfig( - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + config.base_config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() ) elif "fp8dq" in moe_quant: - config = MoEQuantConfig( - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, + config.base_config = 
Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() ) + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE elif "intxdq" in moe_quant: - config = MoEQuantConfig( + config.base_config = ( Int8DynamicActivationIntxWeightConfig( layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), ), - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) + config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE + elif "noquant" in moe_quant: + pass else: assert config is not None, ( f"expected moe_quant to match one of the options but got {moe_quant}" ) - if config is not None: - quantize_(model, config, filter_fn=cond_ffn_filter, device=device) - print( - f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" - ) + def filter_fn(mod, fqn): + return isinstance(mod, MoEFeedForward) + + quantize_(model, config, filter_fn=filter_fn, device=device) + print( + f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" + ) model.to(device=device) device_sync(device=device) @@ -335,12 +346,14 @@ def main( global decode_one_token, prefill - if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant): + if not decompose_grouped_mm or ( + batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant) + ): decode_one_token = torch.compile( - decode_one_token, mode="reduce-overhead", fullgraph=True + decode_one_token, mode=compile_mode, fullgraph=True ) else: - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + decode_one_token = torch.compile(decode_one_token, mode=compile_mode) if args.compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) @@ -474,11 +487,22 @@ def callback(x): action="store_true", help="Whether to compile the prefill (improves prefill perf, but higher compile times)", ) - # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') + parser.add_argument( + "--compile_mode", + type=str, + default="reduce-overhead", + help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.", + ) parser.add_argument( "--moe_quant", type=str, - help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq", + help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant", + ) + parser.add_argument( + "--decompose_grouped_mm", + action="store_true", + default=False, + help="Whether to decompose grouped_mm into linear ops for the MoE module, only relevant when moe_quant is set", ) parser.add_argument("--profile", type=Path, default=None, help="Profile path.") parser.add_argument( @@ -499,7 +523,9 @@ def callback(x): args.checkpoint_path, args.compile, args.compile_prefill, + args.compile_mode, args.moe_quant, + args.decompose_grouped_mm, args.profile, args.memory_profile, args.device, diff --git a/torchao/_models/mixtral-moe/model.py b/torchao/_models/mixtral-moe/model.py index 685323843d..bab15df948 100644 --- a/torchao/_models/mixtral-moe/model.py +++ b/torchao/_models/mixtral-moe/model.py @@ -11,8 +11,6 @@ from torch import Tensor from torch.nn import functional as F -from torchao.prototype.moe_quant.utils import FakeExtraDimTensor - def find_multiple(n: int, k: int) -> int: if n % k == 0: @@ -156,7 +154,7 @@ class TransformerBlock(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.attention = Attention(config) - 
self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + self.block_sparse_moe = MoEFeedForward(config) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps) @@ -226,40 +224,49 @@ def forward( return y -# class ConditionalFeedForward(nn.Module): -# def __init__(self, config): -# super().__init__() -# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) -# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) -# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) - -# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: -# w1_weights = self.w1[expert_indices] # [T, A, D, D] -# w3_weights = self.w3[expert_indices] # [T, A, D, D] -# w2_weights = self.w2[expert_indices] # [T, A, D, D] -# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) -# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) -# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) -# return expert_outs - - -# class MOEFeedForward(nn.Module): -# def __init__(self, config) -> None: -# super().__init__() -# self.gate = nn.Linear(config.dim, config.num_experts, bias=False) -# self.cond_ffn = ConditionalFeedForward(config) -# self.dim = config.dim -# self.num_activated_experts = config.num_activated_experts -# def forward(self, x: Tensor) -> Tensor: -# x = x.view(-1, self.dim) -# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts -# # x: [T, D] -# scores = self.gate(x) # [T, E] -# expert_weights = F.softmax(scores, dim=-1) -# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] -# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] -# expert_outs = self.cond_ffn(x, expert_indices) -# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) +class MoEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) + + +class ConditionalFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = 
torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) + return expert_outs class RMSNorm(nn.Module): @@ -300,165 +307,3 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: x_out2 = x_out2.flatten(3) return x_out2.type_as(x) - - -# T tokens -# E experts -# D dim -# I intermediate dim -# A activated experts -# T'(e) tokens for expert e - - -class MOEFeedForwardAOQuantizable(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.gate = nn.Linear(config.dim, config.num_experts, bias=False) - self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) - self.dim = config.dim - self.num_activated_experts = config.num_activated_experts - - def forward(self, x: Tensor) -> Tensor: - batch_size = x.shape[0] - x = x.view(-1, self.dim) # x: [T, D] - scores = self.gate(x) # [T, E] - expert_weights = F.softmax(scores, dim=-1) - expert_weights, expert_indices = torch.topk( - expert_weights, self.num_activated_experts, dim=-1 - ) # [T, A], [T, A] - expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] - out = self.cond_ffn( - x, expert_indices, expert_weights, self.num_activated_experts - ) - return out.reshape(batch_size, -1, self.dim) - - -class ConditionalFeedForwardAOQuantizable(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.w1 = nn.Parameter( - torch.empty(config.num_experts, config.intermediate_size, config.dim) - ) # E, I, D - self.w2 = nn.Parameter( - torch.empty(config.num_experts, config.dim, config.intermediate_size) - ) # E, D, I - self.w3 = nn.Parameter( - torch.empty(config.num_experts, config.intermediate_size, config.dim) - ) # E, I, D - self.num_experts = config.num_experts - - def forward( - self, - x: Tensor, # T, D - expert_indices: Tensor, # T, A - expert_weights: Tensor, # T, A - num_activated_experts: int, - ) -> Tensor: - num_tokens, dim = x.shape - num_token_activations = num_tokens * num_activated_experts - if x.shape[0] == 1 and not isinstance( - self.w1, FakeExtraDimTensor - ): # only 1 token (can be done without graph breaks when compiled) - outs = [] - expert_indices = expert_indices.view(num_activated_experts) - # collect used experts - w1 = self.w1[expert_indices] - w2 = self.w2[expert_indices] - w3 = self.w3[expert_indices] - - # run token through each expert - for index in range(num_activated_experts): - y1 = F.silu(F.linear(x, w1[index])) - y3 = F.linear(x, w3[index]) - y2 = w2[index] - cur_out = F.linear(y1 * y3, y2) - outs.append(cur_out) - - # combine outputs - final_out = ( - (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) - .sum(dim=0) - .unsqueeze(-1) - ) - return final_out - else: - expert_list = [x for x in range(self.num_experts)] - - # shuffle tokens into groups for each expert - ordered_token_activations = expert_indices.view(-1).argsort( - stable=True - ) # [A] - ordered_token_indices = ( - ordered_token_activations.div(num_activated_experts) - .floor() - .to(torch.int64) - ) # [T] - - if not expert_indices.is_cuda: # histc doesn't work on cpu for integers - num_tokens_per_expert = torch.bincount( - expert_indices.view(-1) + 1, minlength=self.num_experts + 1 - ) - else: - num_tokens_per_expert = torch.histc( - expert_indices, - bins=self.num_experts + 1, - min=-1, - max=self.num_experts, - ) # [E+1] (added leading 0 so can be used for indexing) - cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( - torch.int64 - ) # [E+1] - - @torch._dynamo.disable() - def group_tokens_by_expert( - ordered_token_indices, cum_tokens_per_expert, 
expert_list - ): - token_indices_per_expert = [ - ordered_token_indices[ - cum_tokens_per_expert[expert] : cum_tokens_per_expert[ - expert + 1 - ] - ] - for expert in expert_list - ] # [T'(e1)], [T'(e2)] ... - return token_indices_per_expert - - token_indices_per_expert = group_tokens_by_expert( - ordered_token_indices, cum_tokens_per_expert, expert_list - ) - tokens_grouped_by_expert = [ - x[indices] for indices in token_indices_per_expert - ] - - # calculate outputs for each expert - outs = [] - for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): - w1 = self.w1[expert] # I, D - w2 = self.w2[expert] # D, I - w3 = self.w3[expert] # I, D - - cur_out = F.linear( - F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2 - ) # [T'(e), D] - outs.append(cur_out) - - # weigh outputs - ordered_outs = torch.cat(outs, dim=0) # [T*A, D] - ordered_token_activation_weights = expert_weights.view(-1, 1)[ - ordered_token_activations - ].view(-1, 1) # [T*A, 1] - weighted_ordered_outs = ( - ordered_outs * ordered_token_activation_weights - ) # [T*A, D] - - # sum weighted token-activation outputs together for each token - final_out = torch.zeros_like(x) # [T, D] - final_out = final_out.scatter_add( - dim=0, - index=ordered_token_indices.unsqueeze(-1) - .expand(num_token_activations, dim) - .to(torch.int64), - src=weighted_ordered_outs, - ) - return final_out diff --git a/torchao/_models/mixtral-moe/run.sh b/torchao/_models/mixtral-moe/run.sh index d9e3a50405..48af045190 100644 --- a/torchao/_models/mixtral-moe/run.sh +++ b/torchao/_models/mixtral-moe/run.sh @@ -1,39 +1,64 @@ export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 export CHECKPOINT_PATH=checkpoints/ -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile +######### GROUPED_MM ####### + +# noquant +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant noquant --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant noquant --compile --compile_mode "max-autotune" + +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant noquant --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant noquant --compile --compile_mode "max-autotune" + +# scaled_grouped_mm +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile --compile_mode "max-autotune" + +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile --compile_mode "max-autotune" + +# ######### MULTI TOKEN ####### -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo --compile +# noquant +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant noquant --compile --decompose_grouped_mm -python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo-base --compile +# int8wo-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8wo-base --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo --compile +# needs balanced tokens due to minimum matmul sizes +# int8dq-base +# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile +# int4wo-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile --decompose_grouped_mm + +# fp8wo-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo-base --compile --decompose_grouped_mm + +# fp8dq-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile --decompose_grouped_mm + + +######### SINGLE TOKEN ####### + +# einsum +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile -# # EXPERT CHOICE -# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile -# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile -# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile -# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile +#noquant +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant noquant --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile +# int8wo-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo-base --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo-base --compile +# int4wo-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq --compile +# fp8wo-base +python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo-base --compile --decompose_grouped_mm -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8dq-base --compile +# fp8dq-base +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8dq-base --compile --decompose_grouped_mm -# ARM -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --device cpu -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --compile --device cpu +########## ARM ########## +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --device cpu +# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant intxdq --compile --device cpu diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8b028352e4..22decf4bec 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -109,6 +109,11 @@ aten = torch.ops.aten +grouped_mm = [ + aten._grouped_mm.default if hasattr(aten, "_grouped_mm") else None, + torch._grouped_mm if hasattr(torch, "_grouped_mm") else None, +] + _AQT_QLINEAR_DISPATCH_TABLE = {} @@ -296,6 +301,25 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) +@implements(grouped_mm) +def _(func, types, args, kwargs): + new_arg0 = ( + args[0].tensor_impl if isinstance(args[0], AffineQuantizedTensor) else args[0] + ) + new_arg1 = ( + args[1].tensor_impl if isinstance(args[1], AffineQuantizedTensor) else args[1] + ) + out = func( + *( + new_arg0, + new_arg1, + *args[2:], + ), + **kwargs, + ) + return out + + @implements(torch.nn.functional.embedding) def _(func, types, args, kwargs): if _embedding_q_dq_check(args, kwargs): @@ -484,6 +508,26 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + tensor, dim0, dim1 = args + block_size = list(tensor.block_size) + block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] + new_shape = list(tensor.shape) + new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0] + new = tensor.__class__( + func(tensor.tensor_impl, *args[1:]), + block_size, + new_shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + @implements(aten.slice.Tensor) def _(func, types, args, kwargs): self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 4afc5fdfee..3bf5b6ad66 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -32,6 +32,12 @@ FLOAT8_IMPL_OPS_TABLE: Dict[Any, Any] = {} +grouped_mm = [ + aten._grouped_mm.default if hasattr(aten, "_grouped_mm") else None, + torch._grouped_mm if hasattr(torch, "_grouped_mm") else None, +] + + def implements(aten_ops: List[Any]): """Register aten ops to the float8 op table""" @@ -100,7 +106,11 @@ def 
__new__( ) kwargs["dtype"] = float8_data.dtype kwargs["requires_grad"] = False - shape = float8_data.shape + shape = ( + float8_data.shape + if not transposed + else float8_data.shape[:-2] + float8_data.shape[-1:-3:-1] + ) return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( @@ -243,13 +253,41 @@ def _(func, types, args, kwargs): ) -@implements([aten.t.default]) +@implements([aten.t.default, aten.transpose.int]) def _(func, types, args, kwargs): """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose """ - args[0].transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, args[0]) + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8AQTTensorImpl( + args[0].float8_data, + args[0].scale, + not args[0].transposed, + args[0]._layout, + ), + ) + + +@implements(grouped_mm) +def _(func, types, args, kwargs): + input, weight, offs = args[0], args[1], args[2] + assert len(args) == 3, ( + "scaled_grouped_mm only implemented with 3 args for float8 in torchao" + ) + assert weight.transposed, ( + "weight tensor must be transposed before being called in scaled_grouped_mm" + ) + in_f8 = input.float8_data + in_scale = input.scale.squeeze() + w_f8 = weight.float8_data.transpose(-2, -1) + w_scale = weight.scale.squeeze() + out = torch._scaled_grouped_mm( + in_f8, w_f8, in_scale, w_scale, offs, out_dtype=torch.bfloat16 + ) + return out @implements([aten.copy_.default]) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 8c65f6f891..07edc2d2a7 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -639,13 +639,12 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( def test_moe_quant_intx(self): from torchao.prototype.moe_quant.quantizable_moe_modules import ( - MOEFeedForwardAOQuantizable, + MoEFeedForwardAOQuantizable, ) from torchao.prototype.moe_quant.utils import ( FakeExtraDimTensor, MoEQuantConfig, UseFakeExtraDimTensor, - cond_ffn_filter, ) from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -655,10 +654,12 @@ def test_moe_quant_intx(self): from torchao.quantization.utils import compute_error with torch.device("cpu"): - model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to( - torch.float32 + model = torch.nn.Sequential( + MoEFeedForwardAOQuantizable( + 2, 512, 256, 2, empty_init=False, decompose_grouped_mm=True + ).to(torch.float32) ) - x = torch.randn(8, 512, dtype=torch.float32) + x = torch.randn(64, 512, dtype=torch.float32) out = model(x).clone() @@ -666,13 +667,17 @@ def test_moe_quant_intx(self): layout=PackedLinearInt8DynamicActivationIntxWeightLayout() ) moe_config = MoEQuantConfig( - base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + base_config=base_config, + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) - quantize_(model, moe_config, cond_ffn_filter) + def _moe_filter(mod, fqn): + return isinstance(mod, MoEFeedForwardAOQuantizable) + + quantize_(model, moe_config, _moe_filter) out_q = model(x).clone() - assert isinstance(model.experts.w1, FakeExtraDimTensor) + assert isinstance(model.experts.up_proj, FakeExtraDimTensor) mod_c = torch.compile(model, mode="reduce-overhead") diff --git 
a/torchao/prototype/moe_quant/README.md b/torchao/prototype/moe_quant/README.md index 734b409f65..54ad95513a 100644 --- a/torchao/prototype/moe_quant/README.md +++ b/torchao/prototype/moe_quant/README.md @@ -1,51 +1,130 @@ # MoE Quantization -Our goal with this prototype implementation of moe quantization is to enable usage of existing linear quantization techniques for moe quantization. While it would likely be more performant to use a fused kernel for quantized moe, by decomposing the moe operation into a sequence of linear operations, we can utilize the existing tools and UX that work for lienar quantization and apply them to moe. +This prototype implementation enables quantization of Mixture of Experts (MoE) models using two complementary approaches: -Examples of the usage of these apis can be found in both the llama4_quant.py and ao/torchao/_models/mixtral-moe/generate.py +1. **Grouped Matrix Multiplication (`_grouped_mm`)**: Leverages PyTorch's dedicated grouped MM kernels for optimal performance +2. **Linear Decomposition Fallback**: Decomposes MoE operations into linear operations when grouped MM is unavailable or quantized kernels don't exist -## Quantization API +## Recent Updates -The API for moe quantization is very similar to linear quantization, given a moe module that is decomposed into linear operations, is quantizable and compilable. In practice this requires us to use the modules found in quantizable_moe_modules.py or something similar. Once this change has been made the API is as follows for a few different quantization techniques: +This implementation has been significantly refactored to prioritize PyTorch's optimized `_grouped_mm` kernels: + +- **Primary `_grouped_mm` implementation**: Now uses PyTorch's improved grouped matrix multiplication without padding requirements +- **Enhanced module swapping**: Generic `MoEMapping` class for converting existing MoE implementations +- **Flexible execution modes**: Automatic selection between grouped MM and linear decomposition based on availability and tensor types +- **Improved quantization compatibility**: Better integration with existing quantized tensor subclasses + +Examples of the usage of these APIs can be found in both the `llama4_quant.py` and `torchao/_models/mixtral-moe/generate.py` + +## API + +### Supported Techniques + +- **BFloat16**: 16-bit floating point inference using `torch._grouped_mm` +- **Float8DynamicActivationFloat8WeightConfig**: Float8 dynamic activation and weight quantization using `torch.scaled_grouped_mm` +- **Int8WeightOnlyConfig**: 8-bit weight-only quantization using linear decomposition +- **Int4WeightOnlyConfig**: 4-bit weight-only quantization using linear decomposition + +### Basic Usage + +Going forward the intended direction of TorchAO's MoE quantization will be for model owners to use torch._grouped_mm and then quantize the parameters directly similar to linear quantization without requiring a module swap. + +However currently the existing space has a variety of implementations and so the expectage usage will be to swap to the AO Quantizable MoE module using the new `MoEMapping` to facilitate transfering the necessary information. 
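
For reference, a rough sketch of what that direct `torch._grouped_mm` path looks like at the kernel level. This is not an API added by this PR; the shapes and the cumulative `offs` convention are assumptions based on how `torch._scaled_grouped_mm` is invoked elsewhere in this repo, and it requires PyTorch 2.8+ on a GPU with grouped-mm support:

```python
import torch

E, T, H, I = 8, 64, 512, 256  # experts, tokens (pre-sorted by expert), hidden dim, expert dim
x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")     # tokens grouped contiguously per expert
w = torch.randn(E, H, I, dtype=torch.bfloat16, device="cuda")  # one [H, I] weight matrix per expert
# cumulative end offset of each expert's token group (int32); here every expert gets T // E tokens
offs = torch.arange(1, E + 1, device="cuda", dtype=torch.int32) * (T // E)
out = torch._grouped_mm(x, w, offs)  # [T, I]: expert e multiplies its slice of x by w[e]
```

For models whose MoE layers do not already use `torch._grouped_mm`, the expected path is the `MoEMapping`-based module swap described above.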
As an example: ```python +from torchao.prototype.moe_quant.utils import MoEMapping, MoEQuantConfig +from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig + +moe_mapping = MoEMapping( + target_module_type=Llama4TextMoe, + router_fqn="router", + top_k_fqn="top_k", + up_proj_fqn="experts.gate_up_proj", + up_proj_part2_fqn=None, + down_proj_fqn="experts.down_proj", + order_of_weight_indices=(0, 2, 1), + act_fn_fqn="experts.act_fn", + shared_expert_fqn="shared_expert", + return_scores=True, + decompose_grouped_mm=True, # change this to false if doing bf16 or fp8dq +) +base_config = Int4WeightOnlyConfig() # this can be set to None to just do the swap to the AO Quantizable module -from torchao.prototype.moe_quant.utils import cond_ffn_filter, -from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig +config = MoEQuantConfig(base_config, moe_mapping) -quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter) -model=torch.compile(model, mode="reduce-overhead") -# you can also use fullgraph=True for single token inference +def moe_filter(module, fqn): + return isinstance(module, YourModelsMoEModuleClass) + +quantize_(model, config, moe_filter) +model = torch.compile(model, mode="reduce-overhead", fullgraph=False) # can use fullgraph for grouped_mm or single_token inference ``` -This api is the same as for normal linear quantization but with a specific filter function. This works for several different quantization techniques where the quantized tensor subclass has been adapted to work with 3D tensors. Specifically this means Int8WeightOnlyConfig, Int4WeightOnlyConfig, Int4WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, and Int8DynamicActivationInt8WeightConfig. It should be noted that due to the requirements on minimum tensor input size (>16), Int8DynamicActivationInt8WeightConfig is best used for expert choice moe rather than token choice which is what the rest of the framework in this folder supports. +### Production Examples +- **Llama4**: Complete integration example in `llama4_quant.py` +- **Mixtral**: Full pipeline with benchmarking in `torchao/_models/mixtral-moe/generate.py` -## Alternative Quantization API +Both examples demonstrate end-to-end workflows including model loading, conversion, quantization, and performance evaluation. -To make the above api work, each tensor subclass had to be edited to work as 3D tensors. However the only ops we actually need to support are a few indexing and slicing ops on the 0th dimension, the majority of the work was removing hard coded assumptions about the tensor dimensionality. This means its possible to instead create a new tensor subclass that pretends to be a 3D tensor by storing a series of 2D tensors and simulating the slicing and indexing ops until eventually just returning the singular desired 2D quantized tensor subclass. This can be achieved using the alternative api by changing the fake_extra_dim_tensor flag of the MoEQuantConfig: +## Execution Modes +### 1. 
Grouped Matrix Multiplication (Primary) ```python +# Uses torch._grouped_mm for optimal performance +final_out = self._forward_grouped_mm(x, expert_indices, scores, up_proj, down_proj, act_fn) +``` +- **Best performance**: Leverages optimized grouped MM kernels +- **No padding required**: Uses PyTorch's improved implementation +- **Selection**: Used when `decompose_grouped_mm=False` (default) -from torchao.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig, UseFakeExtraDimTensor -from torchao.quantization.quant_api import quantize_, Int8DynamicActivationIntxWeightConfig +### 2. Linear Decomposition (Fallback) +```python +# Falls back to linear operations when needed +if x.shape[0] > 1: # Multi-token + final_out = self._forward_multi_token_linear_decomposition(...) +else: # Single token + final_out = self._forward_single_token_linear_decomposition(...) +``` +- **Quantization compatibility**: Works with all quantized tensor subclasses +- **Selection**: Used when `decompose_grouped_mm=True` or for quantized tensors -config = MoEQuantConfig( - Int8DynamicActivationIntxWeightConfig(), - # this is the only difference from the above api - use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, -) -quantize_(model, , filter_fn=cond_ffn_filter) -model=torch.compile(model, mode="reduce-overhead") +## Performance Notes + +### Grouped MM vs Linear Decomposition + +- **Grouped MM**: Provides optimal performance for multi-token MoE operations using PyTorch's dedicated kernels +- **Linear Decomposition**: Enables quantization compatibility for existing techniques and is generally faster for single-token inference + +## Alternative Quantization Technique: FakeExtraDimTensor + +For broader compatibility with existing quantization techniques, we provide an alternative approach using `FakeExtraDimTensor`. This method simulates 3D tensors by storing multiple 2D tensors and implementing slicing/indexing operations, enabling compatibility with all existing linear quantization techniques without modifications. This can be done using the same API as above but adding the option for use_fake_extra_dim_tensor + +### Usage + +```python +from torchao.prototype.moe_quant.utils import UseFakeExtraDimTensor + +# Configure with FakeExtraDimTensor +config = MoEQuantConfig(... + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, # Key difference +) ``` -It should also be noted that the default value for use_fake_extra_dim_tensor is AS_FALLBACK which means that it will try to use the base method but if not, will use the more general but less performant fake_extra_dim_tensor method. +### Configuration Options -While this approach turns out to not be especially performant, it does allow for slightly better memory characteristics since all the tensors are held seperately and aren't actually modified or indexed. It is flexible enough to work with all of the existing linear quantization techniques that make use of quantized tensor subclasses without any changes being made to those classes. It is compilable though neither single token nor multi token inference works with fullgraph compilation. 
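
As a quick reference, a minimal sketch of selecting between these modes, using only names this PR exports from `torchao.prototype.moe_quant` and mirroring the test usage (`AS_FALLBACK` is the default, so it needs no argument):

```python
from torchao.prototype.moe_quant import MoEQuantConfig, UseFakeExtraDimTensor
from torchao.quantization.quant_api import Int8WeightOnlyConfig

# AS_FALLBACK (the default): try the direct 3D quantized tensor path first,
# fall back to FakeExtraDimTensor if the technique does not support 3D tensors.
config_default = MoEQuantConfig(base_config=Int8WeightOnlyConfig())

# Force the FakeExtraDimTensor path, e.g. for techniques with no 3D tensor support.
config_fake = MoEQuantConfig(
    base_config=Int8WeightOnlyConfig(),
    use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
```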
+- **`UseFakeExtraDimTensor.TRUE`**: Always use the fake tensor approach +- **`UseFakeExtraDimTensor.FALSE`**: Use the direct 3D tensor approach +- **`UseFakeExtraDimTensor.AS_FALLBACK`** (default): Try direct approach first, fallback to fake tensor if needed -## Model API +### Trade-offs -In practice the moe implementations of known models tend to not be easy to quantize and even of those that are, they are often either compiled with many graph breaks or impossible to torch.compile at all. +**Benefits:** +- Compatible with all existing quantization techniques without modifications +- Better memory characteristics (tensors stored separately) +- Flexible and general approach -The modules in the quantizable_moe_modules.py file were carefully written to satisfy both of those necessary characteristics but to apply moe quantization to your own model, it will require first a module swap from the existing MoE module type, to these more flexible ones. While there isn't a one size fits all way to do this, an example of how it was done for huggingface's llama4 implementation can be found in llama4_quant.py which can be seen as a proof of concept. +**Limitations:** +- Less performant than direct 3D tensor approach +- Fullgraph compilation not supported for single/multi-token inference +- Additional overhead from tensor simulation diff --git a/torchao/prototype/moe_quant/__init__.py b/torchao/prototype/moe_quant/__init__.py index e69de29bb2..f568221b65 100644 --- a/torchao/prototype/moe_quant/__init__.py +++ b/torchao/prototype/moe_quant/__init__.py @@ -0,0 +1,20 @@ +from .quantizable_moe_modules import ( + ExpertsAOQuantizable, + MoEFeedForwardAOQuantizable, +) +from .utils import ( + FakeExtraDimTensor, + MoEMapping, + MoEQuantConfig, + UseFakeExtraDimTensor, +) + +__all__ = [ + "MoEQuantConfig", + "MoEMappingFakeExtraDimTensor", + "FakeExtraDimTensor", + "MoEMapping", + "UseFakeExtraDimTensor", + "MoEFeedForwardAOQuantizable", + "ExpertsAOQuantizable", +] diff --git a/torchao/prototype/moe_quant/kernels.py b/torchao/prototype/moe_quant/kernels.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/moe_quant/llama4_quant.py b/torchao/prototype/moe_quant/llama4_quant.py index 36e684d47d..80eee534f4 100644 --- a/torchao/prototype/moe_quant/llama4_quant.py +++ b/torchao/prototype/moe_quant/llama4_quant.py @@ -11,82 +11,64 @@ # A activated experts # T'(e) tokens for expert e +from time import time + import torch -import torch.nn as nn from transformers import AutoTokenizer, Llama4ForCausalLM from transformers.models.llama4.modeling_llama4 import Llama4TextMoe -from torchao.prototype.moe_quant.quantizable_moe_modules import ( - MOEFeedForwardAOQuantizable, +from torchao.prototype.moe_quant.utils import ( + MoEMapping, + MoEQuantConfig, ) -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_ def llama4_moe_filter_fn(module, fqn): return isinstance(module, Llama4TextMoe) -def convert_fn(module): - # get data - hidden_dim = module.hidden_dim - expert_dim = module.experts.expert_dim - num_experts = module.num_experts - top_k = module.top_k - act_fn = module.experts.act_fn - shared_expert = module.shared_expert - return_scores = True - new_mod = MOEFeedForwardAOQuantizable( - hidden_dim, - expert_dim, - num_experts, - top_k, - act_fn, - shared_expert, - return_scores, - ) - - router = module.router - up_proj = module.experts.gate_up_proj - w1, w3 = up_proj.permute(0, 2, 
1).chunk(2, dim=1) - w2 = module.experts.down_proj.permute(0, 2, 1) - - new_mod.router = router - new_mod.experts.w1 = nn.Parameter(w1, requires_grad=False) - new_mod.experts.w2 = nn.Parameter(w2, requires_grad=False) - new_mod.experts.w3 = nn.Parameter(w3, requires_grad=False) - return new_mod - - +max_tok = 200 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) -_replace_with_custom_fn_if_matches_filter( - model, - convert_fn, - llama4_moe_filter_fn, +moe_mapping = MoEMapping( + target_module_type=Llama4TextMoe, + router_fqn="router", + top_k_fqn="top_k", + up_proj_fqn="experts.gate_up_proj", + up_proj_part2_fqn=None, + down_proj_fqn="experts.down_proj", + order_of_weight_indices=(0, 2, 1), + act_fn_fqn="experts.act_fn", + shared_expert_fqn="shared_expert", + return_scores=True, + decompose_grouped_mm=True, ) +base_config = Int4WeightOnlyConfig() -model = model - -from torchao.prototype.moe_quant.utils import ( - MoEQuantConfig, - cond_ffn_filter, -) -from torchao.quantization import Int4WeightOnlyConfig, quantize_ - -quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda") - -model.cuda() - +config = MoEQuantConfig(base_config, moe_mapping) +quantize_(model, config, llama4_moe_filter_fn, device="cuda") model = torch.compile(model, mode="reduce-overhead") prompt = "He is here, the one who will tear apart the very stars" inputs = tokenizer(prompt, return_tensors="pt") -model.generate(inputs.input_ids.cuda(), max_length=30) -model.generate(inputs.input_ids.cuda(), max_length=30) -generate_ids = model.generate(inputs.input_ids.cuda(), max_length=50) +inputs.input_ids = inputs.input_ids.cuda() +model.generate(inputs.input_ids, max_length=30) +model.generate(inputs.input_ids, max_length=30) +generate_ids = model.generate(inputs.input_ids, max_length=max_tok) out = tokenizer.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] print(out) + +do_bench = True +if do_bench: + start = time() + for i in range(10): + model.generate(inputs.input_ids, max_length=max_tok) + elapsed = (time() - start) / 10 + print( + f"took {elapsed:.2f} seconds, {(max_tok - inputs.input_ids.numel()) / elapsed:.2f} tok/s" + ) diff --git a/torchao/prototype/moe_quant/quantizable_moe_modules.py b/torchao/prototype/moe_quant/quantizable_moe_modules.py index d806f50b4f..6a7c9a9182 100644 --- a/torchao/prototype/moe_quant/quantizable_moe_modules.py +++ b/torchao/prototype/moe_quant/quantizable_moe_modules.py @@ -1,191 +1,275 @@ +from typing import Callable, List + import torch import torch.nn.functional as F from torch import Tensor, nn -from torchao.prototype.moe_quant.utils import FakeExtraDimTensor +__all__ = [ + "MoEFeedForwardAOQuantizable", + "ExpertsAOQuantizable", +] -class MOEFeedForwardAOQuantizable(nn.Module): +class MoEFeedForwardAOQuantizable(nn.Module): def __init__( self, - hidden_dim, - expert_dim, - num_experts, - top_k, - act_fn=F.silu, - shared_expert=None, - return_scores=False, - empty_init=True, + num_experts: int, + hidden_dim: int, + expert_dim: int, + top_k: int, + act_fn: Callable[Tensor, Tensor] = F.silu, + shared_expert: torch.nn.Module = None, + return_scores: bool = False, + decompose_grouped_mm: bool = False, + empty_init: bool = True, ) -> None: super().__init__() self.router = nn.Linear(hidden_dim, num_experts, bias=False) - self.experts = ConditionalFeedForwardAOQuantizable( - 
num_experts, hidden_dim, expert_dim, act_fn, empty_init + self.experts = ExpertsAOQuantizable( + num_experts, + hidden_dim, + expert_dim, + act_fn, + decompose_grouped_mm, + empty_init, ) - self.hidden_dim = hidden_dim self.top_k = top_k self.shared_expert = shared_expert self.return_scores = return_scores def forward(self, x: Tensor) -> Tensor: - batch_size = x.shape[0] - x = x.view(-1, self.hidden_dim) # x: [T, D] + orig_shape = x.shape + x = x.view(-1, orig_shape[-1]) # x: [T, H] scores = self.router(x) # [T, E] scores = F.softmax(scores, dim=-1) scores, expert_indices = torch.topk( scores, self.top_k, dim=-1 - ) # [T, A], [T, A] - scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] + ) # [T, K], [T, K] + scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype) # [T, K] - out = self.experts(x, expert_indices, scores, self.top_k) + out = self.experts(x, expert_indices, scores) if self.shared_expert: out += self.shared_expert(x) if self.return_scores: - return out.reshape(batch_size, -1, self.hidden_dim), scores + return out.view(*orig_shape), scores else: - return out.reshape(batch_size, -1, self.hidden_dim) + return out.view(*orig_shape) + +class ExpertsAOQuantizable(nn.Module): + weight_attrs: List[str] = ["up_proj", "down_proj"] -class ConditionalFeedForwardAOQuantizable(nn.Module): - def __init__(self, num_experts, hidden_dim, expert_dim, act_fn, empty_init=True): + def __init__( + self, + num_experts: int, + hidden_dim: int, + expert_dim: int, + act_fn: Callable[Tensor, Tensor] = F.silu, + decompose_grouped_mm: bool = False, + empty_init: bool = True, + ): super().__init__() if empty_init: - self.w1 = nn.Parameter( - torch.empty(num_experts, expert_dim, hidden_dim) - ) # E, I, D - self.w2 = nn.Parameter( + # E, 2H, D + self.up_proj = nn.Parameter( + torch.empty(num_experts, 2 * expert_dim, hidden_dim) + ) + # E, D, H + self.down_proj = nn.Parameter( torch.empty(num_experts, hidden_dim, expert_dim) - ) # E, D, I - self.w3 = nn.Parameter( - torch.empty(num_experts, expert_dim, hidden_dim) - ) # E, I, D + ) else: - self.w1 = nn.Parameter( - torch.randn(num_experts, expert_dim, hidden_dim) - ) # E, I, D - self.w2 = nn.Parameter( + self.up_proj = nn.Parameter( + torch.randn(num_experts, 2 * expert_dim, hidden_dim) + ) + self.down_proj = nn.Parameter( torch.randn(num_experts, hidden_dim, expert_dim) - ) # E, D, I - self.w3 = nn.Parameter( - torch.randn(num_experts, expert_dim, hidden_dim) - ) # E, I, D - self.num_experts = num_experts + ) + self.act_fn = act_fn - self.hidden_dim = hidden_dim - self.expert_dim = expert_dim + self.decompose_grouped_mm = decompose_grouped_mm def forward( self, x: Tensor, # T, D expert_indices: Tensor, # T, A - expert_weights: Tensor, # T, A - top_k: int, + scores: Tensor, # T, A ) -> Tensor: - num_tokens, _hidden_dim = x.shape - num_token_activations = num_tokens * top_k - - if x.shape[0] == 1 and not isinstance( - self.w1, FakeExtraDimTensor - ): # only 1 token (can be done without graph breaks when compiled) - outs = [] - expert_indices = expert_indices.view(top_k) - # collect used experts - w1 = self.w1[expert_indices] - w2 = self.w2[expert_indices] - w3 = self.w3[expert_indices] - # run token through each expert - for index in range(top_k): - y1 = F.silu(F.linear(x, w1[index])) - y3 = F.linear(x, w3[index]) - y2 = w2[index] - - cur_out = F.linear(y1 * y3, y2) - outs.append(cur_out) - - # combine outputs - final_out = ( - (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)) - .sum(dim=0) - .reshape(x.shape) + if not self.decompose_grouped_mm: + 
final_out = self._forward_grouped_mm( + x, expert_indices, scores, self.up_proj, self.down_proj, self.act_fn ) - return final_out - else: - expert_list = [x for x in range(self.num_experts)] - - # shuffle tokens into groups for each expert - ordered_token_activations = expert_indices.view(-1).argsort( - stable=True - ) # [A] - ordered_token_indices = ( - ordered_token_activations.div(top_k).floor().to(torch.int64) - ) # [T] - if not expert_indices.is_cuda: # histc doesn't work on cpu for integers - num_tokens_per_expert = torch.bincount( - expert_indices.view(-1) + 1, minlength=self.num_experts + 1 - ) - else: - num_tokens_per_expert = torch.histc( - expert_indices, - bins=self.num_experts + 1, - min=-1, - max=self.num_experts, - ) # [E+1] (added leading 0 so can be used for indexing) - cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to( - torch.int64 - ) # [E+1] - - @torch._dynamo.disable() - def group_tokens_by_expert( - ordered_token_indices, cum_tokens_per_expert, expert_list - ): - token_indices_per_expert = [ - ordered_token_indices[ - cum_tokens_per_expert[expert] : cum_tokens_per_expert[ - expert + 1 - ] - ].to(torch.int64) - for expert in expert_list - ] # [T'(e1)], [T'(e2)] ... - return token_indices_per_expert - - token_indices_per_expert = group_tokens_by_expert( - ordered_token_indices, cum_tokens_per_expert, expert_list + elif x.shape[0] > 1 or "FakeExtraDimTensor" in str(type(self.up_proj)): + final_out = self._forward_multi_token_linear_decomposition( + x, expert_indices, scores, self.up_proj, self.down_proj, self.act_fn ) - tokens_grouped_by_expert = [ - x[indices] for indices in token_indices_per_expert - ] - - # calculate outputs for each expert - outs = [] - for cur_x, expert in zip(tokens_grouped_by_expert, expert_list): - w1 = self.w1[expert] # I, D - w2 = self.w2[expert] # D, I - w3 = self.w3[expert] # I, D - - y1 = F.silu(F.linear(cur_x, w1)) - y3 = F.linear(cur_x, w3) - y2 = w2 - - cur_out = F.linear(y1 * y3, y2) # [T'(e), D] - outs.append(cur_out) - - # weigh outputs - ordered_outs = torch.cat(outs, dim=0) # [T*A, D] - ordered_token_activation_weights = expert_weights.view(-1, 1)[ - ordered_token_activations - ].view(-1, 1) # [T*A, 1] - weighted_ordered_outs = ( - ordered_outs * ordered_token_activation_weights - ) # [T*A, D] - - # sum weighted token-activation outputs together for each token - final_out = torch.zeros_like(x) # [T, D] - final_out = final_out.scatter_add( - dim=0, - index=ordered_token_indices.unsqueeze(-1) - .expand(num_token_activations, self.hidden_dim) - .to(torch.int64), - src=weighted_ordered_outs, + else: + final_out = self._forward_single_token_linear_decomposition( + x, expert_indices, scores, self.up_proj, self.down_proj, self.act_fn ) + + return final_out + + @staticmethod + def _forward_grouped_mm( + x: Tensor, + expert_indices: Tensor, + scores: Tensor, + up_proj: Tensor, + down_proj: Tensor, + act_fn: Callable[Tensor, Tensor], + ): + assert hasattr(torch, "_grouped_mm"), ( + "the _grouped_mm op was not found, try installing pytorch nightly or test with: >python -c 'import torch; print(torch._grouped_mm)'" + ) + + # get shapes + num_experts, hidden_dim, expert_dim = down_proj.shape + num_tokens, top_k = expert_indices.shape + num_token_activations = num_tokens * top_k + + # token shuffle + expert_indices = expert_indices.view(-1) + ordered_token_activations = expert_indices.argsort(stable=True) + ordered_token_indices = ( + ordered_token_activations.div(top_k).floor().to(torch.int32) + ) + indices_for_histc = ( + expert_indices 
if expert_indices.is_cuda else expert_indices.float() + ) + num_tokens_per_expert = torch.histc( # histc doesn't work on cpu for integers + indices_for_histc, + bins=num_experts, + min=0, + max=num_experts, + ) + offs = num_tokens_per_expert.cumsum(dim=0).to(torch.int32) + ordered_inputs = x[ordered_token_indices] + ordered_scores = scores.view(-1, 1)[ordered_token_activations] + + # calculate outputs + gate, up = torch._grouped_mm( + ordered_inputs, up_proj.transpose(-2, -1), offs + ).chunk(2, dim=1) + y1 = act_fn(gate) * up + ordered_outs = torch._grouped_mm(y1, down_proj.transpose(-2, -1), offs) + ordered_weighted_outs = ordered_scores * ordered_outs + + # un-shuffle output + final_out = torch.zeros_like(x) + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, hidden_dim) + .to(torch.int64), + src=ordered_weighted_outs, + ) + return final_out + + @staticmethod + def _forward_single_token_linear_decomposition( + x: Tensor, + expert_indices: Tensor, + scores: Tensor, + up_proj: Tensor, + down_proj: Tensor, + act_fn: Callable[Tensor, Tensor], + ): + # get shapes + assert x.shape[0] == 1 and x.dim() == 2, ( + f"single_token_moe_kernel_linear_decomposition only works with inputs of shape [1, hidden_dim] but got {x.shape}" + ) + num_activated_experts = expert_indices.numel() + expert_indices = expert_indices.view(-1) + + # collect only experts that get activated + cur_up_proj = up_proj[expert_indices] + cur_down_proj = down_proj[expert_indices] + + # calculate outputs + outs = [] + for index in range(num_activated_experts): + gate, up = F.linear(x, cur_up_proj[index]).chunk(2, dim=-1) + y1 = act_fn(gate) * up + cur_out = F.linear(y1, cur_down_proj[index]) + outs.append(cur_out) + + # combine output + out = torch.cat(outs, dim=0) + final_out = (out * scores.view(-1, 1)).sum(dim=0).unsqueeze(0) + return final_out + + @staticmethod + def _forward_multi_token_linear_decomposition( + x: Tensor, + expert_indices: Tensor, + scores: Tensor, + up_proj: Tensor, + down_proj: Tensor, + act_fn: Callable[Tensor, Tensor], + ): + @torch._dynamo.disable() + def _group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert): + num_experts = cum_tokens_per_expert.numel() - 1 + token_indices_per_expert = [ + ordered_token_indices[ + cum_tokens_per_expert[expert] : cum_tokens_per_expert[expert + 1] + ] + for expert in range(num_experts) + if cum_tokens_per_expert[expert] < cum_tokens_per_expert[expert + 1] + ] # [T'(e1)], [T'(e2)] ... 
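+            # each entry of token_indices_per_expert holds the positions (in the
+            # expert-sorted token order) that were routed to one expert; experts that
+            # received no tokens are skipped. The slicing depends on runtime token
+            # counts, which is why this helper is kept out of the compiled region via
+            # torch._dynamo.disable.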
+ return token_indices_per_expert + + # get shapes + num_experts, hidden_dim, expert_dim = down_proj.shape + num_tokens, top_k = expert_indices.shape + num_token_activations = num_tokens * top_k + + # token shuffle + expert_indices = expert_indices.view(-1) + ordered_token_activations = expert_indices.argsort(stable=True) + ordered_token_indices = ( + ordered_token_activations.div(top_k).floor().to(torch.int32) + ) + indices_for_histc = ( + expert_indices if expert_indices.is_cuda else expert_indices.float() + ) + num_tokens_per_expert = torch.histc( # histc doesn't work on cpu for integers + indices_for_histc, + bins=num_experts + 1, + min=-1, + max=num_experts, + ) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) + token_indices_per_expert = _group_tokens_by_expert( + ordered_token_indices, cum_tokens_per_expert + ) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + # calculate outputs for each expert + outs = [] + for expert, cur_x in enumerate(tokens_grouped_by_expert): + cur_up_proj = up_proj[expert] + cur_down_proj = down_proj[expert] + + gate, up = F.linear(cur_x, cur_up_proj).chunk(2, dim=1) + y1 = act_fn(gate) * up + cur_out = F.linear(y1, cur_down_proj) + outs.append(cur_out) + + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_scores = scores.view(-1, 1)[ordered_token_activations] # [T*A, 1] + ordered_weighted_outs = ordered_scores * ordered_outs + + # un-shuffle outputs + final_out = torch.zeros_like(x) + final_out = final_out.scatter_add( + dim=0, + index=ordered_token_indices.unsqueeze(-1) + .expand(num_token_activations, hidden_dim) + .to(torch.int64), + src=ordered_weighted_outs, + ) return final_out diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py index 28291afdf4..1a8542a92d 100644 --- a/torchao/prototype/moe_quant/utils.py +++ b/torchao/prototype/moe_quant/utils.py @@ -5,14 +5,20 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch.utils._python_dispatch import ( return_and_correct_aliasing, ) +import torchao + aten = torch.ops.aten +import warnings from enum import Enum, auto -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union + +from torch.ao.quantization.utils import getattr_from_fqn from torchao.quantization.quant_api import ( _QUANTIZE_CONFIG_HANDLER, @@ -22,6 +28,233 @@ ) from torchao.utils import DummyModule, fill_defaults +from .quantizable_moe_modules import MoEFeedForwardAOQuantizable + +warnings.simplefilter("ignore", lineno=84) +warnings.simplefilter("ignore", lineno=105) + +__all__ = [ + "MoEQuantConfig", + "MoEMapping", + "FakeExtraDimTensor", + "UseFakeExtraDimTensor", +] + + +class UseFakeExtraDimTensor(Enum): + """Enum that indicate whether to use FakeExtraDimTensor""" + + TRUE = auto() + FALSE = auto() + AS_FALLBACK = auto() + + +@dataclass +class MoEQuantConfig(AOBaseConfig): + """Configuration for applying quantization to MoE + Args: + `Optional[base_config]`: normal AO Config to be applied to a MoEFeedforwardAOQuantizable module, + if None, then will only do the conversion to MoEFeedforwardAOQuantizable using the mapping + `Optional[mapping]`: MoEMapping, if None, then this will do no conversion, note: only + MoEFeedforwardAOQuantizable modules can be quantized. 
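+        `use_fake_extra_dim_tensor`: UseFakeExtraDimTensor, controls whether the 3D expert
+            weights are quantized directly or split into per-expert 2D tensors wrapped in a
+            FakeExtraDimTensor, defaults to trying the direct path first and falling back to
+            FakeExtraDimTensor if the base config fails on the 3D weights
+        `set_inductor_config`: bool, whether to set the recommended inductor config options
+            when quantizing, defaults to True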
+ """ + + base_config: Optional[AOBaseConfig] = None + mapping: Optional["MoEMapping"] = None + + use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK + set_inductor_config: bool = True + + +@register_quantize_module_handler(MoEQuantConfig) +def moe_convert_and_quant_fn(module: torch.nn.Module, config: MoEQuantConfig): + mapping = config.mapping + base_config = config.base_config + assert mapping is not None or base_config is not None, ( + "need one of mapping or base_config to use MoEQuantConfig" + ) + + # maybe convert module to quantizable + if mapping is not None and isinstance(module, mapping.target_module_type): + module = _convert_module_to_ao_quantizable(module, mapping) + + # maybe quantize module + if base_config is not None and isinstance(module, MoEFeedForwardAOQuantizable): + module = _quantize_moe_module(module, config) + + return module + + +def _quantize_moe_module(module: torch.nn.Module, config: MoEQuantConfig): + assert isinstance(module, MoEFeedForwardAOQuantizable), ( + f"can only apply quantization to MoEFeedForwardAOQuantizable modules but got {type(module)}" + ) + + experts = module.experts + + for weight_attr in experts.weight_attrs: + param = getattr(experts, weight_attr) + assert param.dim() == 3, ( + f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" + ) + assert isinstance(config.base_config, AOBaseConfig), ( + f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" + + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" + ) + new_param = _quantize_moe_tensor(param, config) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(experts, weight_attr, new_param) + del param + return module + + +# Module-level flag to track if we've already printed the error +_quantize_moe_tensor_has_printed_error = False + + +def _quantize_moe_tensor(weight: torch.Tensor, config: MoEQuantConfig): + def _quantize_moe_tensor_base(weight, config): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(weight) + quant_mod = base_config_handler(dummy_mod, config.base_config) + return quant_mod.weight + + def _quantize_moe_tensor_fake_extra_dim_tensor( + weight: torch.Tensor, config: MoEQuantConfig + ): + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + # break 3D tensor + tensors = [weight[i] for i in range(weight.shape[0])] + # put tensors into modules since the handlers target modules not tensors + dummy_modules = [DummyModule(tensor) for tensor in tensors] + # apply handler to each module + quant_mods = list( + map(lambda x: base_config_handler(x, config.base_config), dummy_modules) + ) + # pack quantized subclasses into FakeExtraDimTensor + quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods]) + return quant_weight + + global _quantize_moe_tensor_has_printed_error + + use_fake = config.use_fake_extra_dim_tensor + if use_fake == UseFakeExtraDimTensor.FALSE: + return _quantize_moe_tensor_base(weight, config) + elif use_fake == UseFakeExtraDimTensor.AS_FALLBACK: + try: + return _quantize_moe_tensor_base(weight, config) + except Exception as e: + if not _quantize_moe_tensor_has_printed_error: + print(f"tried to do moe_quant but got error: {e}") + _quantize_moe_tensor_has_printed_error = True + return _quantize_moe_tensor_fake_extra_dim_tensor(weight, config) + else: # This handles UseFakeExtraDimTensor.TRUE 
+ return _quantize_moe_tensor_fake_extra_dim_tensor(weight, config) + + +@dataclass +class MoEMapping: + """This mapping dataclass is used to map an existing MoE module to the AOQuantizable one + and is used with the convert_moe_with_mapping fn to convert a model to use the AO moe modules + """ + + target_module_type: type + + router_fqn: str = "gate" + top_k_fqn: Optional[str] = "num_activated_experts" + + # if up_proj is a single tensor, leave up_proj_part2_fqn as None, otherwise list the fqn + # for w1 and up_proj_fqn and w3 as up_proj_part2_fqn + up_proj_fqn: str = "cond_ffn.w1" + up_proj_part2_fqn: Optional[str] = "cond_ffn.w3" + down_proj_fqn: str = "cond_ffn.w2" # also known as down_proj + + # what is the order of indices of the weights, + # specifically which order are the experts, out_features, in_features indices in? + # for up_proj this would be experts, expert_dim*2, hidden_dim, + # for down_proj this would be experts, hidden_dim, expert_dim, + order_of_weight_indices: Union[Tuple[int], Tuple[int]] = ( + 0, + 1, + 2, + ) + + # can't both be None + act_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = F.silu + act_fn_fqn: Optional[str] = None + + # Options + shared_expert_fqn: Optional[str] = None + return_scores: bool = False + decompose_grouped_mm: bool = False + + +def _convert_module_to_ao_quantizable(module: torch.nn.Module, mapping: MoEMapping): + assert isinstance(module, mapping.target_module_type), ( + f"_convert_module_to_ao_quantizable only works on modules of type {mapping.target_module_type} but got {type(module)}" + ) + + # get router and top_k + router = getattr_from_fqn(module, mapping.router_fqn) + top_k = getattr_from_fqn(module, mapping.top_k_fqn) + + # get up and down_proj + order_of_indices = mapping.order_of_weight_indices + if mapping.up_proj_part2_fqn is None: + up_proj = ( + getattr_from_fqn(module, mapping.up_proj_fqn) + .permute(*order_of_indices) + .contiguous() + ) + else: + w1 = getattr_from_fqn(module, mapping.up_proj_fqn).permute(*order_of_indices) + w3 = getattr_from_fqn(module, mapping.up_proj_part2_fqn).permute( + *order_of_indices + ) + up_proj = torch.cat((w1, w3), dim=1).contiguous() + + down_proj = ( + getattr_from_fqn(module, mapping.down_proj_fqn) + .permute(*order_of_indices) + .contiguous() + ) + + # get sizes + num_experts, hidden_dim, expert_dim = down_proj.shape + + # get act_fn + act_fn = mapping.act_fn + if act_fn is None: + act_fn = getattr_from_fqn(module, mapping.act_fn_fqn) + assert act_fn is not None, ( + "both act_fn and act_fn_fqn can't be None in the MoEMapping" + ) + + # get final options + shared_expert = None + if isinstance(mapping.shared_expert_fqn, str): + shared_expert = getattr_from_fqn(module, mapping.shared_expert_fqn) + return_scores = mapping.return_scores + decompose_grouped_mm = mapping.decompose_grouped_mm + + # make new module + new_module = torchao.prototype.moe_quant.MoEFeedForwardAOQuantizable( + num_experts=num_experts, + hidden_dim=hidden_dim, + expert_dim=expert_dim, + top_k=top_k, + act_fn=act_fn, + shared_expert=shared_expert, + return_scores=return_scores, + decompose_grouped_mm=decompose_grouped_mm, + ) + + new_module.router = router + new_module.experts.up_proj = torch.nn.Parameter(up_proj) + new_module.experts.down_proj = torch.nn.Parameter(down_proj) + + return new_module + class FakeExtraDimTensor(torch.Tensor): """This is a subclass of torch.Tensor that simulates a tensor of n+1 dimensions, akin to concatenating several tensors along the 0th dimension. 
@@ -207,96 +440,3 @@ def __torch_dispatch__(cls, func, types, args, kwargs): "run function on its elements: " ) raise e - - -class UseFakeExtraDimTensor(Enum): - """Enum that indicate whether to use FakeExtraDimTensor""" - - TRUE = auto() - FALSE = auto() - AS_FALLBACK = auto() - - -@dataclass -class MoEQuantConfig(AOBaseConfig): - """Configuration for applying quantization to MoE - Args: - `base_config`: normal AO Config - """ - - base_config: AOBaseConfig - use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK - set_inductor_config: bool = True - - -# Module-level flag to track if we've already printed the error -_moe_quant_tensor_has_printed_error = False - - -def _moe_quant_tensor(weight, config): - def _moe_quant_tensor_base(weight, config): - base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] - dummy_mod = DummyModule(weight) - quant_mod = base_config_handler(dummy_mod, config.base_config) - return quant_mod.weight - - def _moe_quant_tensor_fake_extra_dim_tensor(weight, config): - base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] - # break 3D tensor - tensors = [weight[i] for i in range(weight.shape[0])] - # put tensors into modules since the handlers target modules not tensors - dummy_modules = [DummyModule(tensor) for tensor in tensors] - # apply handler to each module - quant_mods = list( - map(lambda x: base_config_handler(x, config.base_config), dummy_modules) - ) - # pack quantized subclasses into FakeExtraDimTensor - quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods]) - return quant_weight - - global _moe_quant_tensor_has_printed_error - - use_fake = config.use_fake_extra_dim_tensor - if use_fake == UseFakeExtraDimTensor.FALSE: - return _moe_quant_tensor_base(weight, config) - elif use_fake == UseFakeExtraDimTensor.AS_FALLBACK: - try: - return _moe_quant_tensor_base(weight, config) - except Exception as e: - if not _moe_quant_tensor_has_printed_error: - print(f"tried to do moe_quant but got error: {e}") - _moe_quant_tensor_has_printed_error = True - return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) - else: # This handles UseFakeExtraDimTensor.TRUE - return _moe_quant_tensor_fake_extra_dim_tensor(weight, config) - - -@register_quantize_module_handler(MoEQuantConfig) -def moe_quant_fn(module, config: MoEQuantConfig): - import warnings - - warnings.simplefilter("ignore", lineno=84) - warnings.simplefilter("ignore", lineno=105) - - for weight_attr in ["w1", "w2", "w3"]: - param = getattr(module, weight_attr) - assert param.dim() == 3, ( - f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}" - ) - assert isinstance(config.base_config, AOBaseConfig), ( - f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}" - + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())" - ) - new_param = _moe_quant_tensor(param, config) - new_param = torch.nn.Parameter(new_param, requires_grad=False) - setattr(module, weight_attr, new_param) - del param - return module - - -def moe_filter(module, fqn): - return "MOEFeedForwardAOQuantizable" in str(type(module)) - - -def cond_ffn_filter(module, fqn): - return "ConditionalFeedForwardAOQuantizable" in str(type(module)) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 658b172994..772c10dd43 100644 --- 
a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -21,6 +21,12 @@ aten = torch.ops.aten +grouped_mm = [ + aten._grouped_mm.default if hasattr(aten, "_grouped_mm") else None, + torch._grouped_mm if hasattr(torch, "_grouped_mm") else None, +] + + class LinearActivationQuantizedTensor(TorchAOBaseTensor): """ Applies activation quantization for linear operator, this is used to support @@ -82,8 +88,15 @@ def __tensor_unflatten__( def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor ): - if input_tensor.numel() == 0: - return input_tensor + if ( + input_tensor.numel() == 0 + ): # need to actually return the correct shape for compile + return torch.empty( + input_tensor.shape[0], + weight_tensor.shape[0], + dtype=input_tensor.dtype, + device=input_tensor.device, + ) input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor quant_kwargs = weight_tensor.quant_kwargs @@ -186,6 +199,19 @@ def _(func, types, args, kwargs): return func(qtensor, original_weight_tensor) +@implements(grouped_mm) +def _(func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + input_tensor, weight_tensor, offs = args[0], args[1], args[2] + assert len(args) == 3 and kwargs == {}, ( + "grouped_mm_only implemented for 3 args in ao" + ) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs) + return func(qtensor, original_weight_tensor, offs) + + @implements([aten.detach.default, aten.alias.default]) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -231,6 +257,16 @@ def _(func, types, args, kwargs): ) +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: torch.transpose(x, *args[1:])), + ) + + @implements(aten.slice.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing(