Commit ea9190e: Route all experts

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Parent: 4e4bf16

File tree: 2 files changed, +29 −1 lines


examples/deepseek/ptq.py (28 additions, 0 deletions)
@@ -198,6 +198,26 @@ def _setup(self):
         self.kv_bmm_quantizer = TensorQuantizer()
         self.pe_bmm_quantizer = TensorQuantizer()
 
+class CalibMoe(deekseep_model.MoE):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._setup()
+
+    def _setup(self):
+        self._original_topk = self.gate.topk
+        self._original_topk_groups = self.gate.topk_groups
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Forward all tokens to all experts for calibration
+        self.gate.topk = self.n_routed_experts
+        self.gate.topk_groups = self.gate.n_groups
+        super().forward(x)
+        # Restore the original topk and topk_groups
+        self.gate.topk = self._original_topk
+        self.gate.topk_groups = self._original_topk_groups
+
+        return super().forward(x)
+
 mtq.register(
     original_cls=deekseep_model.RowParallelLinear,
     quantized_cls=QuantRowParallelLinear,

@@ -208,6 +228,7 @@ def _setup(self):
 )
 mtq.register(original_cls=deekseep_model.Linear, quantized_cls=QuantLinear)
 mtq.register(original_cls=deekseep_model.MLA, quantized_cls=QuantMLA)
+mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe)
 
 
 def load_deepseek_model(model_config: str, model_path: str, batch_size: int):

@@ -319,6 +340,13 @@ def state_dict_filter(state_dict):
         os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"),
     )
 
+    # if rank == 0:
+    #     with open("expert_activation_counts.txt", "w") as f:
+    #         for name, module in model.named_modules():
+    #             if isinstance(module, deekseep_model.MoE):
+    #                 counts = module.activated_expert_counts()
+    #                 f.writelines(f"{name}: {count}\n" for count in counts)
+
     quant_config = get_quant_config(model.named_modules())
 
     if enable_fp8_kvcache:
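Why ptq.py routes all experts: with top-k routing, many experts see few or no tokens from a small calibration set, leaving their quantizer amax statistics empty or unrepresentative. CalibMoe widens the gate to all experts (and all groups) for one forward pass whose output is discarded, so every expert's quantizers observe every token, then restores the original routing and returns the normally routed result. The added mtq.register call maps deekseep_model.MoE to CalibMoe so modelopt swaps the class in automatically during quantization. Below is a minimal, self-contained sketch of the same pattern on a toy MoE; ToyGate, ToyMoE, and everything about them are illustrative stand-ins, not the DeepSeek classes (only gate.topk and n_routed_experts mirror attributes the diff touches, and group-limited routing is omitted):

import torch
import torch.nn as nn


class ToyGate(nn.Module):
    """Illustrative router: scores experts and keeps the top-k per token."""

    def __init__(self, dim: int, n_experts: int, topk: int):
        super().__init__()
        self.proj = nn.Linear(dim, n_experts, bias=False)
        self.topk = topk  # mutable on purpose, like the attribute the diff pokes

    def forward(self, x: torch.Tensor):
        scores = self.proj(x).softmax(dim=-1)          # (tokens, n_experts)
        weights, idx = scores.topk(self.topk, dim=-1)  # (tokens, topk) each
        return weights, idx


class ToyMoE(nn.Module):
    """Illustrative MoE layer: dispatch each token to its selected experts."""

    def __init__(self, dim: int, n_experts: int, topk: int):
        super().__init__()
        self.n_routed_experts = n_experts
        self.gate = ToyGate(dim, n_experts, topk)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights, idx = self.gate(x)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            hit = idx == i                # (tokens, topk) slots routed to expert i
            if hit.any():
                tok = hit.any(dim=-1)     # tokens that selected expert i at all
                w = (weights * hit).sum(dim=-1, keepdim=True)[tok]
                out[tok] += w * expert(x[tok])
        return out


class CalibMoE(ToyMoE):
    """Calibration wrapper: one all-experts pass, then the real routed pass."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_topk = self.gate.topk
        self.gate.topk = self.n_routed_experts  # route every token to every expert
        super().forward(x)                      # output discarded; only quantizer
                                                # statistics (amax) matter here
        self.gate.topk = original_topk          # restore normal sparse routing
        return super().forward(x)               # the value actually returned


moe = CalibMoE(dim=8, n_experts=4, topk=2)
y = moe(torch.randn(16, 8))                     # every expert saw all 16 tokens once
print(y.shape)                                  # torch.Size([16, 8])

Note the cost: each calibration step runs the layer twice, and the dense pass touches every expert, so calibration gets slower in exchange for fully populated amax statistics.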

examples/deepseek/quantize_to_nvfp4.py (1 addition, 1 deletion)

@@ -151,7 +151,7 @@ def convert_fp8_ckpt_to_nvfp4(
     per_layer_quant_config,
 ):
     def amax_to_nvfp4_scaling_factor_2(amax):
-        return amax.float() / 6.0 / 448.0
+        return amax.float() / (6.0 * 448.0)
 
     def amax_to_fp8_scaling_factor(amax):
         return amax.float() / 448.0
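Why the quantize_to_nvfp4.py change is safe: 6.0 is the largest magnitude representable in FP4 (E2M1) and 448.0 the largest in FP8 (E4M3); NVFP4 stores FP4 elements under FP8 per-block scales, so the tensor-level scale divides amax by the product of both range limits. Chained division computes the same value, since 6.0 * 448.0 = 2688.0 is exact in float32; the rewrite just reads as a single scale and drops a division. A small check (amax_to_nvfp4_scaling_factor_2 matches the name in the diff; the harness around it is illustrative):

import torch

E2M1_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)
E4M3_MAX = 448.0  # largest magnitude representable in FP8 (E4M3)


def amax_to_nvfp4_scaling_factor_2(amax: torch.Tensor) -> torch.Tensor:
    # Tensor-level NVFP4 scale: amax spread across the FP4 element range
    # and the FP8 per-block scale range at once.
    return amax.float() / (E2M1_MAX * E4M3_MAX)


amax = torch.tensor(1344.0)
print(amax_to_nvfp4_scaling_factor_2(amax))  # tensor(0.5000)
print(amax.float() / 6.0 / 448.0)            # tensor(0.5000), identical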
