@@ -39,7 +39,7 @@ const char* model_version_to_str[] = {
3939 " Flux Fill" };
4040
4141const char * sampling_methods_str[] = {
42- " Euler A " ,
42+ " default " ,
4343 " Euler" ,
4444 " Heun" ,
4545 " DPM2" ,
@@ -50,7 +50,8 @@ const char* sampling_methods_str[] = {
5050 " iPNDM_v" ,
5151 " LCM" ,
5252 " DDIM \" trailing\" " ,
53- " TCD" };
53+ " TCD" ,
54+ " Euler A" };
5455
5556/* ================================================== Helper Functions ================================================*/
5657
@@ -1251,7 +1252,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
12511252}
12521253
12531254const char * sample_method_to_str[] = {
1254- " euler_a " ,
1255+ " default " ,
12551256 " euler" ,
12561257 " heun" ,
12571258 " dpm2" ,
@@ -1263,6 +1264,7 @@ const char* sample_method_to_str[] = {
12631264 " lcm" ,
12641265 " ddim_trailing" ,
12651266 " tcd" ,
1267+ " euler_a" ,
12661268};
12671269
12681270const char * sd_sample_method_name (enum sample_method_t sample_method) {
@@ -1399,7 +1401,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
13991401 sd_img_gen_params->ref_images_count = 0 ;
14001402 sd_img_gen_params->width = 512 ;
14011403 sd_img_gen_params->height = 512 ;
1402- sd_img_gen_params->sample_method = EULER_A ;
1404+ sd_img_gen_params->sample_method = SAMPLE_METHOD_DEFAULT ;
14031405 sd_img_gen_params->sample_steps = 20 ;
14041406 sd_img_gen_params->eta = 0 .f ;
14051407 sd_img_gen_params->strength = 0 .75f ;
@@ -1524,6 +1526,18 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
15241526 free (sd_ctx);
15251527}
15261528
1529+ SD_API enum sample_method_t sd_get_default_sample_method (const sd_ctx_t * sd_ctx)
1530+ {
1531+ if (sd_ctx != NULL && sd_ctx->sd != NULL ) {
1532+ SDVersion version = sd_ctx->sd ->version ;
1533+ if (sd_version_is_dit (version))
1534+ return EULER;
1535+ else
1536+ return EULER_A;
1537+ }
1538+ return SAMPLE_METHOD_COUNT;
1539+ }
1540+
15271541sd_image_t * generate_image_internal (sd_ctx_t * sd_ctx,
15281542 struct ggml_context * work_ctx,
15291543 ggml_tensor* init_latent,
@@ -2076,6 +2090,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20762090 LOG_INFO (" encode_first_stage completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
20772091 }
20782092
2093+ enum sample_method_t sample_method = sd_img_gen_params->sample_method ;
2094+ if (sample_method == SAMPLE_METHOD_DEFAULT) {
2095+ sample_method = sd_get_default_sample_method (sd_ctx);
2096+ }
2097+
20792098 sd_image_t * result_images = generate_image_internal (sd_ctx,
20802099 work_ctx,
20812100 init_latent,
@@ -2086,7 +2105,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20862105 sd_img_gen_params->eta ,
20872106 width,
20882107 height,
2089- sd_img_gen_params-> sample_method ,
2108+ sample_method,
20902109 sigmas,
20912110 seed,
20922111 sd_img_gen_params->batch_count ,
0 commit comments