Skip to content

Commit e2d8457

Browse files
committed
feat: use Euler sampling by default for SD3 and Flux
1 parent fce6afc commit e2d8457

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

examples/cli/main.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ void print_usage(int argc, const char* argv[]) {
240240
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
241241
printf(" --scheduler {discrete, karras, exponential, ays, gits, smoothstep} Denoiser sigma scheduler (default: discrete)\n");
242242
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
243-
printf(" sampling method (default: \"euler_a\")\n");
243+
printf(" sampling method (default: \"euler\" for Flux/SD3, \"euler_a\" otherwise)\n");
244244
printf(" --steps STEPS number of sample steps (default: 20)\n");
245245
printf(" --high-noise-cfg-scale SCALE (high noise) unconditional guidance scale: (default: 7.0)\n");
246246
printf(" --high-noise-img-cfg-scale SCALE (high noise) image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
@@ -1202,6 +1202,10 @@ int main(int argc, const char* argv[]) {
12021202
return 1;
12031203
}
12041204

1205+
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1206+
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1207+
}
1208+
12051209
sd_image_t* results;
12061210
int num_results = 1;
12071211
if (params.mode == IMG_GEN) {

stable-diffusion.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ const char* model_version_to_str[] = {
4343
};
4444

4545
const char* sampling_methods_str[] = {
46-
"Euler A",
46+
"default",
4747
"Euler",
4848
"Heun",
4949
"DPM2",
@@ -55,6 +55,7 @@ const char* sampling_methods_str[] = {
5555
"LCM",
5656
"DDIM \"trailing\"",
5757
"TCD",
58+
"Euler A",
5859
};
5960

6061
/*================================================== Helper Functions ================================================*/
@@ -1502,7 +1503,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
15021503
}
15031504

15041505
const char* sample_method_to_str[] = {
1505-
"euler_a",
1506+
"default",
15061507
"euler",
15071508
"heun",
15081509
"dpm2",
@@ -1514,6 +1515,7 @@ const char* sample_method_to_str[] = {
15141515
"lcm",
15151516
"ddim_trailing",
15161517
"tcd",
1518+
"euler_a",
15171519
};
15181520

15191521
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -1652,7 +1654,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
16521654
sample_params->guidance.slg.layer_end = 0.2f;
16531655
sample_params->guidance.slg.scale = 0.f;
16541656
sample_params->scheduler = DEFAULT;
1655-
sample_params->sample_method = EULER_A;
1657+
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
16561658
sample_params->sample_steps = 20;
16571659
}
16581660

@@ -1794,6 +1796,18 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
17941796
free(sd_ctx);
17951797
}
17961798

1799+
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx)
1800+
{
1801+
if (sd_ctx != NULL && sd_ctx->sd != NULL) {
1802+
SDVersion version = sd_ctx->sd->version;
1803+
if (sd_version_is_dit(version))
1804+
return EULER;
1805+
else
1806+
return EULER_A;
1807+
}
1808+
return SAMPLE_METHOD_COUNT;
1809+
}
1810+
17971811
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17981812
struct ggml_context* work_ctx,
17991813
ggml_tensor* init_latent,
@@ -2358,6 +2372,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23582372
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
23592373
}
23602374

2375+
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
2376+
if (sample_method == SAMPLE_METHOD_DEFAULT) {
2377+
sample_method = sd_get_default_sample_method(sd_ctx);
2378+
}
2379+
23612380
sd_image_t* result_images = generate_image_internal(sd_ctx,
23622381
work_ctx,
23632382
init_latent,
@@ -2368,7 +2387,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23682387
sd_img_gen_params->sample_params.eta,
23692388
width,
23702389
height,
2371-
sd_img_gen_params->sample_params.sample_method,
2390+
sample_method,
23722391
sigmas,
23732392
seed,
23742393
sd_img_gen_params->batch_count,

stable-diffusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ enum rng_type_t {
3535
};
3636

3737
enum sample_method_t {
38-
EULER_A,
38+
SAMPLE_METHOD_DEFAULT,
3939
EULER,
4040
HEUN,
4141
DPM2,
@@ -47,6 +47,7 @@ enum sample_method_t {
4747
LCM,
4848
DDIM_TRAILING,
4949
TCD,
50+
EULER_A,
5051
SAMPLE_METHOD_COUNT
5152
};
5253

@@ -238,6 +239,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
238239

239240
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
240241
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
242+
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
241243

242244
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
243245
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);

0 commit comments

Comments
 (0)