Skip to content

Commit 5bf5c1b

Browse files
stduhpfleejet
authored andcommitted
Wan MoE: Automatic expert routing based on timestep boundary
1 parent cb1d975 commit 5bf5c1b

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

examples/cli/main.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ struct SDParams {
113113
bool chroma_use_dit_mask = true;
114114
bool chroma_use_t5_mask = false;
115115
int chroma_t5_mask_pad = 1;
116+
float boundary = 0.875;
116117

117118
SDParams() {
118119
sd_sample_params_init(&sample_params);
119120
sd_sample_params_init(&high_noise_sample_params);
121+
high_noise_sample_params.sample_steps = -1;
120122
}
121123
};
122124

@@ -243,7 +245,7 @@ void print_usage(int argc, const char* argv[]) {
243245
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
244246
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
245247
printf(" (high noise) sampling method (default: \"euler_a\")\n");
246-
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
248+
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
247249
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
248250
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
249251
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
@@ -274,6 +276,8 @@ void print_usage(int argc, const char* argv[]) {
274276
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
275277
printf(" --video-frames video frames (default: 1)\n");
276278
printf(" --fps fps (default: 24)\n");
279+
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)");
280+
printf(" Only enabled if `--high-noise-steps` is set to -1");
277281
printf(" -v, --verbose print extra info\n");
278282
}
279283

@@ -507,6 +511,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
507511
{"", "--strength", "", &params.strength},
508512
{"", "--style-ratio", "", &params.style_ratio},
509513
{"", "--control-strength", "", &params.control_strength},
514+
{"", "--moe-boundary", "", &params.boundary},
510515
};
511516

512517
options.bool_options = {
@@ -767,8 +772,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
767772
}
768773

769774
if (params.high_noise_sample_params.sample_steps <= 0) {
770-
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
771-
exit(1);
775+
params.high_noise_sample_params.sample_steps = -1;
772776
}
773777

774778
if (params.strength < 0.f || params.strength > 1.f) {
@@ -1225,6 +1229,7 @@ int main(int argc, const char* argv[]) {
12251229
params.strength,
12261230
params.seed,
12271231
params.video_frames,
1232+
params.boundary
12281233
};
12291234

12301235
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

stable-diffusion.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
17271727
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
17281728
sd_sample_params_init(&sd_vid_gen_params->sample_params);
17291729
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
1730-
sd_vid_gen_params->width = 512;
1731-
sd_vid_gen_params->height = 512;
1732-
sd_vid_gen_params->strength = 0.75f;
1733-
sd_vid_gen_params->seed = -1;
1734-
sd_vid_gen_params->video_frames = 6;
1730+
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
1731+
sd_vid_gen_params->width = 512;
1732+
sd_vid_gen_params->height = 512;
1733+
sd_vid_gen_params->strength = 0.75f;
1734+
sd_vid_gen_params->seed = -1;
1735+
sd_vid_gen_params->video_frames = 6;
1736+
sd_vid_gen_params->boundary = 0.875f;
17351737
}
17361738

17371739
struct sd_ctx_t {
@@ -2381,7 +2383,24 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
23812383
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
23822384
}
23832385

2384-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
2386+
int total_steps = sample_steps;
2387+
2388+
if (high_noise_sample_steps > 0) {
2389+
total_steps += high_noise_sample_steps;
2390+
}
2391+
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);
2392+
2393+
if(high_noise_sample_steps < 0) {
2394+
// timesteps∝sigmas for Flow models (like wan2.2 a14b)
2395+
for (size_t i = 0; i < sigmas.size(); ++i) {
2396+
if (sigmas[i] < sd_vid_gen_params->boundary) {
2397+
high_noise_sample_steps = i;
2398+
break;
2399+
}
2400+
}
2401+
LOG_DEBUG("Switching from high noise model at step %d", high_noise_sample_steps);
2402+
sample_steps = total_steps - high_noise_sample_steps;
2403+
}
23852404

23862405
struct ggml_init_params params;
23872406
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ typedef struct {
208208
float strength;
209209
int64_t seed;
210210
int video_frames;
211+
float boundary;
211212
} sd_vid_gen_params_t;
212213

213214
typedef struct sd_ctx_t sd_ctx_t;

0 commit comments

Comments
 (0)