46 changes: 46 additions & 0 deletions scripts/models/glm4.7-30B-A3B.sh
@@ -0,0 +1,46 @@
MOE_SHARED_EXPERTS=1

MOE_FFN_HIDDEN=1536
MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$((MOE_FFN_HIDDEN * MOE_SHARED_EXPERTS))
N_DENSE_LAYERS=1
N_MOE_LAYERS=46

MODEL_ARGS=(
    --moe-layer-freq [0]*$N_DENSE_LAYERS+[1]*$N_MOE_LAYERS
    --num-experts 64
    --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE
    --moe-router-topk 4
    --moe-grouped-gemm
    --moe-permute-fusion
    --moe-ffn-hidden-size $MOE_FFN_HIDDEN
    --moe-router-score-function sigmoid
    --moe-router-pre-softmax
    --moe-router-enable-expert-bias
    --moe-router-bias-update-rate 0
    --moe-router-load-balancing-type seq_aux_loss
    --moe-router-topk-scaling-factor 1.8
    --moe-aux-loss-coeff 0
    --moe-router-dtype fp32
    --num-layers $((N_DENSE_LAYERS + N_MOE_LAYERS))
    --hidden-size 2048
    --ffn-hidden-size 10240
    --num-attention-heads 20
    --disable-bias-linear
    --add-qkv-bias
    --swiglu
    --untie-embeddings-and-output-weights
    --position-embedding-type rope
    --no-position-embedding
    --normalization RMSNorm
    --qk-layernorm
    --multi-latent-attention
    --q-lora-rank 768
    --kv-lora-rank 512
    --qk-head-dim 192
    --v-head-dim 256
    --kv-channels 192
    --qk-pos-emb-head-dim 64
    --vocab-size 154880
    --rotary-base 1000000
    --enable-experimental
)
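For reference, the `--moe-layer-freq [0]*$N_DENSE_LAYERS+[1]*$N_MOE_LAYERS` value is a Python-list-style pattern that Megatron expands into a per-layer MoE mask. The sketch below shows that expansion; the eval-style parsing mirrors how Megatron-LM handles such string patterns and is an assumption, not code from this PR.

```python
# Sketch: expanding the "[0]*1+[1]*46" pattern into a per-layer mask,
# where 0 = dense FFN layer and 1 = MoE layer. The eval-based expansion
# is an assumption about Megatron-LM's handling of --moe-layer-freq.
n_dense_layers, n_moe_layers = 1, 46
pattern = f"[0]*{n_dense_layers}+[1]*{n_moe_layers}"
moe_layer_mask = eval(pattern)  # [0, 1, 1, ..., 1]

assert len(moe_layer_mask) == 47            # matches --num-layers
assert sum(moe_layer_mask) == n_moe_layers  # 46 MoE layers, 1 dense layer
```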
7 changes: 3 additions & 4 deletions slime/backends/megatron_utils/megatron_to_hf/__init__.py
@@ -31,7 +31,9 @@ def convert_to_hf(args, model_name, name, param, quantization_config=None):

 # TODO optimize code details
 def _convert_to_hf_core(args, model_name, name, param):
-    if "glm4moe" in model_name:
+    if "glm4moelite" in model_name or "deepseekv3" in model_name:
+        converted_named_tensors = convert_deepseekv3_to_hf(args, name, param)
+    elif "glm4moe" in model_name:
         converted_named_tensors = convert_glm4moe_to_hf(args, name, param)
     elif "glm4" in model_name:
         converted_named_tensors = convert_glm4_to_hf(args, name, param)
@@ -41,9 +43,6 @@ def _convert_to_hf_core(args, model_name, name, param):
         converted_named_tensors = convert_qwen3_next_to_hf(args, name, param)
     elif "qwen2" in model_name or "qwen3" in model_name:
         converted_named_tensors = convert_qwen2_to_hf(args, name, param)
-    elif "deepseekv3" in model_name:
-        converted_named_tensors = convert_deepseekv3_to_hf(args, name, param)
-
     elif "llama" in model_name:
         converted_named_tensors = convert_llama_to_hf(args, name, param)
     elif "mimo" in model_name:
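The branch order above matters: "glm4moelite" contains "glm4moe" as a substring, so the lite/DeepSeek-V3 path must be tested before the generic GLM4 MoE path. A minimal, self-contained sketch of that dispatch follows; the converter names come from the diff, but wrapping them in a standalone function that returns strings is purely illustrative.

```python
# Illustrative sketch of the substring-based dispatch in _convert_to_hf_core.
# Returning converter names as strings is for demonstration only.
def pick_converter(model_name: str) -> str:
    if "glm4moelite" in model_name or "deepseekv3" in model_name:
        return "convert_deepseekv3_to_hf"
    elif "glm4moe" in model_name:
        return "convert_glm4moe_to_hf"
    elif "glm4" in model_name:
        return "convert_glm4_to_hf"
    return "unhandled"

assert pick_converter("glm4moelite") == "convert_deepseekv3_to_hf"
assert pick_converter("glm4moe") == "convert_glm4moe_to_hf"
assert pick_converter("glm4") == "convert_glm4_to_hf"
```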
3 changes: 2 additions & 1 deletion slime_plugins/mbridge/__init__.py
@@ -1,6 +1,7 @@
 from .glm4 import GLM4Bridge
 from .glm4moe import GLM4MoEBridge
+from .glm4moe_lite import GLM4MoELiteBridge
 from .mimo import MimoBridge
 from .qwen3_next import Qwen3NextBridge

-__all__ = ["GLM4Bridge", "GLM4MoEBridge", "Qwen3NextBridge", "MimoBridge"]
+__all__ = ["GLM4Bridge", "GLM4MoEBridge", "GLM4MoELiteBridge", "Qwen3NextBridge", "MimoBridge"]
7 changes: 7 additions & 0 deletions slime_plugins/mbridge/glm4moe_lite.py
@@ -0,0 +1,7 @@
from mbridge.core import register_model
from mbridge.models import DeepseekV3Bridge


@register_model("glm4_moe_lite")
class GLM4MoELiteBridge(DeepseekV3Bridge):
    pass
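Because GLM4MoELiteBridge subclasses DeepseekV3Bridge without overrides, the GLM-4 MoE lite architecture reuses DeepSeek-V3's weight mapping wholesale; the register_model("glm4_moe_lite") decorator only makes it discoverable by architecture name once the plugin package is imported. A hedged usage sketch, assuming mbridge's AutoBridge.from_pretrained/get_model entry points and a placeholder checkpoint path:

```python
# Hypothetical usage sketch. AutoBridge.from_pretrained and get_model follow
# mbridge's documented pattern, but the checkpoint path is a placeholder and
# the exact call sequence for this new bridge is an assumption.
from mbridge import AutoBridge

import slime_plugins.mbridge  # importing runs @register_model("glm4_moe_lite")

# An HF checkpoint whose config resolves to the glm4_moe_lite architecture
# would be handled by GLM4MoELiteBridge, i.e. by DeepseekV3Bridge's mapping.
bridge = AutoBridge.from_pretrained("path/to/glm4-moe-lite-hf-checkpoint")
megatron_model = bridge.get_model()
```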