Skip to content

Commit bb16a01

Browse files
committed
feat: use Euler sampling by default for SD3 and Flux
1 parent c82d9aa commit bb16a01

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

examples/cli/main.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct SDParams {
8383
int fps = 6;
8484
float augmentation_level = 0.f;
8585

86-
sample_method_t sample_method = EULER_A;
86+
sample_method_t sample_method = SAMPLE_METHOD_DEFAULT;
8787
schedule_t schedule = DEFAULT;
8888
int sample_steps = 20;
8989
float strength = 0.75f;
@@ -222,7 +222,7 @@ void print_usage(int argc, const char* argv[]) {
222222
printf(" -H, --height H image height, in pixel space (default: 512)\n");
223223
printf(" -W, --width W image width, in pixel space (default: 512)\n");
224224
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
225-
printf(" sampling method (default: \"euler_a\")\n");
225+
printf(" sampling method (default: \"euler\" for Flux/SD3, \"euler_a\" otherwise)\n");
226226
printf(" --steps STEPS number of sample steps (default: 20)\n");
227227
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
228228
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
@@ -925,6 +925,10 @@ int main(int argc, const char* argv[]) {
925925
return 1;
926926
}
927927

928+
if (params.sample_method == SAMPLE_METHOD_DEFAULT) {
929+
params.sample_method = sd_get_default_sample_method (sd_ctx);
930+
}
931+
928932
sd_image_t input_image = {(uint32_t)params.width,
929933
(uint32_t)params.height,
930934
3,

stable-diffusion.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const char* model_version_to_str[] = {
3939
"Flux Fill"};
4040

4141
const char* sampling_methods_str[] = {
42-
"Euler A",
42+
"default",
4343
"Euler",
4444
"Heun",
4545
"DPM2",
@@ -50,7 +50,8 @@ const char* sampling_methods_str[] = {
5050
"iPNDM_v",
5151
"LCM",
5252
"DDIM \"trailing\"",
53-
"TCD"};
53+
"TCD",
54+
"Euler A"};
5455

5556
/*================================================== Helper Functions ================================================*/
5657

@@ -1251,7 +1252,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
12511252
}
12521253

12531254
const char* sample_method_to_str[] = {
1254-
"euler_a",
1255+
"default",
12551256
"euler",
12561257
"heun",
12571258
"dpm2",
@@ -1263,6 +1264,7 @@ const char* sample_method_to_str[] = {
12631264
"lcm",
12641265
"ddim_trailing",
12651266
"tcd",
1267+
"euler_a",
12661268
};
12671269

12681270
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -1399,7 +1401,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
13991401
sd_img_gen_params->ref_images_count = 0;
14001402
sd_img_gen_params->width = 512;
14011403
sd_img_gen_params->height = 512;
1402-
sd_img_gen_params->sample_method = EULER_A;
1404+
sd_img_gen_params->sample_method = SAMPLE_METHOD_DEFAULT;
14031405
sd_img_gen_params->sample_steps = 20;
14041406
sd_img_gen_params->eta = 0.f;
14051407
sd_img_gen_params->strength = 0.75f;
@@ -1524,6 +1526,18 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
15241526
free(sd_ctx);
15251527
}
15261528

1529+
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx)
1530+
{
1531+
if (sd_ctx != NULL && sd_ctx->sd != NULL) {
1532+
SDVersion version = sd_ctx->sd->version;
1533+
if (sd_version_is_dit(version))
1534+
return EULER;
1535+
else
1536+
return EULER_A;
1537+
}
1538+
return SAMPLE_METHOD_COUNT;
1539+
}
1540+
15271541
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
15281542
struct ggml_context* work_ctx,
15291543
ggml_tensor* init_latent,
@@ -2076,6 +2090,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20762090
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
20772091
}
20782092

2093+
enum sample_method_t sample_method = sd_img_gen_params->sample_method;
2094+
if (sample_method == SAMPLE_METHOD_DEFAULT) {
2095+
sample_method = sd_get_default_sample_method (sd_ctx);
2096+
}
2097+
20792098
sd_image_t* result_images = generate_image_internal(sd_ctx,
20802099
work_ctx,
20812100
init_latent,
@@ -2086,7 +2105,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20862105
sd_img_gen_params->eta,
20872106
width,
20882107
height,
2089-
sd_img_gen_params->sample_method,
2108+
sample_method,
20902109
sigmas,
20912110
seed,
20922111
sd_img_gen_params->batch_count,

stable-diffusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ enum rng_type_t {
3535
};
3636

3737
enum sample_method_t {
38-
EULER_A,
38+
SAMPLE_METHOD_DEFAULT,
3939
EULER,
4040
HEUN,
4141
DPM2,
@@ -47,6 +47,7 @@ enum sample_method_t {
4747
LCM,
4848
DDIM_TRAILING,
4949
TCD,
50+
EULER_A,
5051
SAMPLE_METHOD_COUNT
5152
};
5253

@@ -227,6 +228,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
227228

228229
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
229230
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
231+
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
230232

231233
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
232234
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);

0 commit comments

Comments
 (0)