diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py
index 39d0dae2d..d25989ae4 100644
--- a/slime/backends/megatron_utils/loss.py
+++ b/slime/backends/megatron_utils/loss.py
@@ -771,7 +771,7 @@ def loss_function(
 
     return (
         loss,
-        num_tokens if args.calculate_per_token_loss else 1,
+        torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
         {
             "keys": list(log.keys()),
             "values": torch.tensor(
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index 7a1b73927..86b21ccdd 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -84,9 +84,6 @@ def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer)
 def setup_model_and_optimizer(
     args: Namespace,
     role: str = "actor",
-    no_wd_decay_cond: Callable[..., bool] | None = None,
-    scale_lr_cond: Callable[..., bool] | None = None,
-    lr_mult: float = 1.0,
 ) -> tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler]:
     """Build model(s), wrap with DDP, and construct optimizer and scheduler.
 
@@ -119,11 +116,8 @@ def setup_model_and_optimizer(
     config.timers = None
 
     optimizer = get_megatron_optimizer(
-        config,
-        model,
-        no_wd_decay_cond,
-        scale_lr_cond,
-        lr_mult,
+        config=config,
+        model_chunks=model,
         use_gloo_process_groups=args.enable_gloo_process_groups,
     )
     opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py
index d31ecf994..c4553d24f 100644
--- a/slime/backends/megatron_utils/model_provider.py
+++ b/slime/backends/megatron_utils/model_provider.py
@@ -94,19 +94,19 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage
     # Define the decoder layer spec
     if use_te:
         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.moe_use_legacy_grouped_gemm,
+            num_experts=args.num_experts,
+            moe_grouped_gemm=args.moe_grouped_gemm,
+            qk_layernorm=args.qk_layernorm,
+            multi_latent_attention=args.multi_latent_attention,
+            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         )
     else:
         transformer_layer_spec = get_gpt_layer_local_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.moe_use_legacy_grouped_gemm,
+            num_experts=args.num_experts,
+            moe_grouped_gemm=args.moe_grouped_gemm,
+            qk_layernorm=args.qk_layernorm,
+            multi_latent_attention=args.multi_latent_attention,
+            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         )
 
     build_model_context = nullcontext
diff --git a/slime_plugins/models/glm4.py b/slime_plugins/models/glm4.py
index d3e920efd..ba42ea1a6 100644
--- a/slime_plugins/models/glm4.py
+++ b/slime_plugins/models/glm4.py
@@ -3,11 +3,11 @@
 
 def get_glm_spec(args, config, vp_stage):
     transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-        args.num_experts,
-        args.moe_grouped_gemm,
-        args.qk_layernorm,
-        args.multi_latent_attention,
-        args.moe_use_legacy_grouped_gemm,
+        num_experts=args.num_experts,
+        moe_grouped_gemm=args.moe_grouped_gemm,
+        qk_layernorm=args.qk_layernorm,
+        multi_latent_attention=args.multi_latent_attention,
+        moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         post_self_attn_layernorm=args.post_self_attn_layernorm,
         post_mlp_layernorm=args.post_mlp_layernorm,
     )