[Model] add kimi-k2-thinking fp4 support #167
Conversation
def should_ignore_layer(quantization_config: Optional[QuantizationConfig], prefix: str) -> bool:
    exclude_layers: List[str] = quantization_config["exclude_layers"]
If quantization_config is None, accessing ["exclude_layers"] will raise TypeError.
            regex_pattern = exclude_layer[3:]
            if re.search(regex_pattern, prefix):
                return True
        elif exclude_layer.startswith(prefix):
exclude_layer.startswith(prefix) is backwards; it should be prefix.startswith(exclude_layer).
| # case "lm_head". Common practice won't quant lm_head, however. | ||
| if prefix.split(".")[-1] == exclude_layer: | ||
| return True | ||
| return False No newline at end of file |
Consider something like this?
def should_ignore_layer(quantization_config: Optional[QuantizationConfig], prefix: str) -> bool:
    if quantization_config is None:
        return False
    exclude_layers: List[str] = quantization_config.get("exclude_layers", [])
    if not exclude_layers:
        return False
    for exclude_layer in exclude_layers:
        if exclude_layer.startswith("re:"):
            # case "re:model.layers.self_attn.", remove the 're:' prefix
            regex_pattern = exclude_layer[3:]
            if re.search(regex_pattern, prefix):
                return True
        elif prefix.startswith(exclude_layer):
            # case "model.layers.0.self_attn.q_a_proj"
            return True
        elif prefix.split(".")[-1] == exclude_layer:
            # case "lm_head". Common practice won't quant lm_head, however.
            return True
    return False
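For reference, a minimal usage sketch of the suggested function above (it assumes the definition just shown plus `import re`; the config values and prefixes here are hypothetical, chosen only to exercise each branch):

```python
# Hypothetical config for illustration only.
quantization_config = {"exclude_layers": ["re:.*self_attn.*", "lm_head"]}

should_ignore_layer(quantization_config, "model.layers.0.self_attn.q_a_proj")  # True  (regex entry matches)
should_ignore_layer(quantization_config, "model.layers.0.mlp.down_proj")       # False (no entry matches)
should_ignore_layer(None, "model.layers.0.mlp.down_proj")                      # False (no quantization config)
should_ignore_layer(quantization_config, "lm_head")                            # True  (startswith/exact match)
```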
     hidden_size,
     bias=False,
-    quant_config=quant_config,
+    quant_config=None if should_ignore_layer(quant_config, prefix=f"{prefix}.down_proj") else quant_config,
The pattern None if should_ignore_layer(...) else quant_config is repeated 4 times. Consider a helper function or local variable.
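For illustration, one way the repetition could be factored out; a minimal sketch only, and the helper name _maybe_quant_config is hypothetical, not from the PR:

```python
# Hypothetical helper; only should_ignore_layer is taken from the PR.
def _maybe_quant_config(quant_config, prefix):
    """Return None when the layer at `prefix` is excluded from quantization, else the config."""
    return None if should_ignore_layer(quant_config, prefix=prefix) else quant_config

# Each linear-layer construction site then becomes, e.g.:
#   quant_config=_maybe_quant_config(quant_config, f"{prefix}.down_proj"),
```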
Motivation
This PR adds support for amd/Kimi-K2-Thinking-MXFP4, which quantizes only the linear layers in the MoE experts to the MXFP4 format; the attention layers, the MoE gate, the dense MLP layers, the MoE shared experts, and lm_head remain in bf16.
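To connect this layout to the exclusion mechanism above, a hypothetical exclude_layers list covering those bf16 layers might look like the following (the real checkpoint's quantization config may use different patterns):

```python
# Hypothetical illustration only; the actual amd/Kimi-K2-Thinking-MXFP4 config may differ.
quantization_config = {
    "exclude_layers": [
        "re:.*self_attn.*",        # attention layers stay in bf16
        "re:.*shared_experts.*",   # MoE shared experts
        r"re:.*mlp\.gate$",        # MoE router gate
        "lm_head",                 # output head
    ]
}
# Routed-expert linear layers match no entry, so they are quantized to MXFP4.
```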
Test Result
TP4 result
Accuracy
NOTE: torch.compile will hit an error with triton==3.5.1. It's a known issue, detailed in pytorch/pytorch#161618, and has already been fixed by pytorch/pytorch@05eeb29.

Either upgrading torch or downgrading triton resolves the issue. The result above was obtained by downgrading triton to 3.4.0.
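A quick environment check to confirm which versions are installed (a trivial sketch, not part of the PR):

```python
# Quick version check; not part of the PR.
import torch
import triton

print("torch:", torch.__version__)
print("triton:", triton.__version__)  # affected: 3.5.1; the result above used 3.4.0
```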
TP8 result
Since Kimi-K2 has 64 attention heads (num_heads), each rank handles 8 heads when running TP8. Some existing kernels are not applicable in this case:
Submission Checklist