@@ -43,7 +43,7 @@ const char* model_version_to_str[] = {
4343};
4444
4545const char * sampling_methods_str[] = {
46- " Euler A " ,
46+ " default " ,
4747 " Euler" ,
4848 " Heun" ,
4949 " DPM2" ,
@@ -55,6 +55,7 @@ const char* sampling_methods_str[] = {
5555 " LCM" ,
5656 " DDIM \" trailing\" " ,
5757 " TCD" ,
58+ " Euler A" ,
5859};
5960
6061/* ================================================== Helper Functions ================================================*/
@@ -1502,7 +1503,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
15021503}
15031504
15041505const char * sample_method_to_str[] = {
1505- " euler_a " ,
1506+ " default " ,
15061507 " euler" ,
15071508 " heun" ,
15081509 " dpm2" ,
@@ -1514,6 +1515,7 @@ const char* sample_method_to_str[] = {
15141515 " lcm" ,
15151516 " ddim_trailing" ,
15161517 " tcd" ,
1518+ " euler_a" ,
15171519};
15181520
15191521const char * sd_sample_method_name (enum sample_method_t sample_method) {
@@ -1652,7 +1654,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
16521654 sample_params->guidance .slg .layer_end = 0 .2f ;
16531655 sample_params->guidance .slg .scale = 0 .f ;
16541656 sample_params->scheduler = DEFAULT;
1655- sample_params->sample_method = EULER_A ;
1657+ sample_params->sample_method = SAMPLE_METHOD_DEFAULT ;
16561658 sample_params->sample_steps = 20 ;
16571659}
16581660
@@ -1794,6 +1796,18 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
17941796 free (sd_ctx);
17951797}
17961798
1799+ enum sample_method_t sd_get_default_sample_method (const sd_ctx_t * sd_ctx)
1800+ {
1801+ if (sd_ctx != NULL && sd_ctx->sd != NULL ) {
1802+ SDVersion version = sd_ctx->sd ->version ;
1803+ if (sd_version_is_dit (version))
1804+ return EULER;
1805+ else
1806+ return EULER_A;
1807+ }
1808+ return SAMPLE_METHOD_COUNT;
1809+ }
1810+
17971811sd_image_t * generate_image_internal (sd_ctx_t * sd_ctx,
17981812 struct ggml_context * work_ctx,
17991813 ggml_tensor* init_latent,
@@ -2358,6 +2372,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23582372 LOG_INFO (" encode_first_stage completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
23592373 }
23602374
2375+ enum sample_method_t sample_method = sd_img_gen_params->sample_params .sample_method ;
2376+ if (sample_method == SAMPLE_METHOD_DEFAULT) {
2377+ sample_method = sd_get_default_sample_method (sd_ctx);
2378+ }
2379+
23612380 sd_image_t * result_images = generate_image_internal (sd_ctx,
23622381 work_ctx,
23632382 init_latent,
@@ -2368,7 +2387,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23682387 sd_img_gen_params->sample_params .eta ,
23692388 width,
23702389 height,
2371- sd_img_gen_params-> sample_params . sample_method ,
2390+ sample_method,
23722391 sigmas,
23732392 seed,
23742393 sd_img_gen_params->batch_count ,
0 commit comments