 import torch._inductor.config
 
 from torchao.utils import get_model_size_in_bytes
+from torchao.prototype.moe_quant import MoEFeedForwardAOQuantizable
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from model import MoEFeedForward
 
 torch.manual_seed(0)
 
@@ -199,7 +202,9 @@ def main(
     checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     compile: bool = True,
     compile_prefill: bool = False,
+    compile_mode: str = "reduce-overhead",
     moe_quant: Optional[str] = None,
+    decompose_grouped_mm: bool = False,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
     device="cuda",
@@ -212,6 +217,13 @@ def main(
     precision = torch.bfloat16
     is_chat = "chat" in str(checkpoint_path)
 
+    if batch_size > 1 and moe_quant is None:
+        print(
+            "Warning: detected batch_size > 1 but no moe_quant. The default MoE implementation "
+            "uses a lot of memory when batched; if it OOMs, pass --moe_quant noquant to run the "
+            "AO quantizable MoE module without applying quantization."
+        )
+
     if device == "cuda" and memory_profile is not None:
         torch.cuda.memory._record_memory_history(
             True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +248,11 @@ def main(
         ]
     )
 
-    from torchao.prototype.moe_quant.utils import (
+    from torchao.prototype.moe_quant import (
         MoEQuantConfig,
+        MoEMapping,
         UseFakeExtraDimTensor,
-        cond_ffn_filter,
+        MoEFeedForwardAOQuantizable,
     )
     from torchao.quantization.quant_api import (
         Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +268,61 @@ def main(
 
     if moe_quant:
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        config = None
+        config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward, decompose_grouped_mm=decompose_grouped_mm))
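+        # the branches below only fill in base_config (and, for the non "-base" variants, use_fake_extra_dim_tensor) on this shared config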
         if "int8wo-base" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig())
+            config.base_config = Int8WeightOnlyConfig()
 
         elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "int8dq-base" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
 
         elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationInt8WeightConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "int4wo-base" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig())
+            config.base_config = Int4WeightOnlyConfig()
 
         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int4WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int4WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "fp8wo-base" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config.base_config = Float8WeightOnlyConfig()
 
         elif "fp8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Float8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "fp8dq-base" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
         elif "fp8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "intxdq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationIntxWeightConfig(
-                    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
-                ),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationIntxWeightConfig(
+                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+            )
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
+        elif "noquant" in moe_quant:
+            pass
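+            # nothing else to set: the MoE module is still swapped for the AO quantizable implementation, just without quantized weights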
         else:
             assert config is not None, (
                 f"expected moe_quant to match one of the options but got {moe_quant}"
             )
 
-        if config is not None:
-            quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
-            print(
-                f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
-            )
+        # quantize (or, for "noquant", just convert) every MoEFeedForward module in the model
+        def filter_fn(mod, fqn):
+            return isinstance(mod, MoEFeedForward)
+
+        quantize_(model, config, filter_fn=filter_fn, device=device)
+        print(
+            f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
+        )
 
     model.to(device=device)
     device_sync(device=device)
@@ -335,12 +338,12 @@ def main(
 
         global decode_one_token, prefill
 
-        if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
+        if not decompose_grouped_mm or (batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)):
             decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode=compile_mode, fullgraph=True
             )
         else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode=compile_mode)
 
         if args.compile_prefill:
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,11 +477,22 @@ def callback(x):
         action="store_true",
         help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
     )
-    # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
+    parser.add_argument(
+        "--compile_mode",
+        type=str,
+        default="reduce-overhead",
+        help="Which torch.compile mode to use: reduce-overhead or max-autotune; has no effect unless --compile is set.",
+    )
     parser.add_argument(
         "--moe_quant",
         type=str,
-        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq",
+        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant",
+    )
+    parser.add_argument(
+        "--decompose_grouped_mm",
+        action="store_true",
+        default=False,
+        help="Whether to decompose grouped_mm into linear ops for the MoE module; only relevant when moe_quant is set.",
     )
     parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
     parser.add_argument(
@@ -499,7 +513,9 @@ def callback(x):
         args.checkpoint_path,
         args.compile,
         args.compile_prefill,
+        args.compile_mode,
         args.moe_quant,
+        args.decompose_grouped_mm,
         args.profile,
         args.memory_profile,
         args.device,
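
For reference, every moe_quant branch above reduces to the same three steps: build one MoEQuantConfig whose MoEMapping targets MoEFeedForward, pick a base_config (plus use_fake_extra_dim_tensor for the non "-base" variants), and call quantize_ with a module-type filter. Below is a minimal sketch of that flow condensed from this diff; the quantize_moe_ffn helper name and the int8wo-only branch are illustrative and not part of the commit.

from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8WeightOnlyConfig
from torchao.prototype.moe_quant import MoEQuantConfig, MoEMapping, UseFakeExtraDimTensor
from model import MoEFeedForward  # the Mixtral MoE block defined in this repo's model.py


def quantize_moe_ffn(model, moe_quant="int8wo-base", decompose_grouped_mm=False, device="cuda"):
    # one shared config; the mapping decides which module type gets converted and
    # whether grouped_mm is decomposed into per-expert linear ops
    config = MoEQuantConfig(
        mapping=MoEMapping(
            target_module_type=MoEFeedForward,
            decompose_grouped_mm=decompose_grouped_mm,
        )
    )
    if "int8wo" in moe_quant:  # only one of generate.py's branches, for brevity
        config.base_config = Int8WeightOnlyConfig()
    if "base" not in moe_quant and moe_quant != "noquant":
        # the plain (non "-base") variants route through the fake-extra-dim tensor path
        config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

    # only the MoE feed-forward modules are converted; other modules are left as-is
    quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward), device=device)
    return model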