diff --git a/.gitignore b/.gitignore
index 38fe570d..552d5673 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ test/
 *.gguf
 output*.png
 models*
-*.log
\ No newline at end of file
+*.log
+preview.png
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index cf8f5b13..cc1c14d3 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -60,6 +60,13 @@ const char* modes_str[] = {
     "convert",
 };
 
+const char* previews_str[] = {
+    "none",
+    "proj",
+    "tae",
+    "vae",
+};
+
 enum SDMode {
     TXT2IMG,
     IMG2IMG,
@@ -129,6 +136,11 @@ struct SDParams {
     float slg_scale = 0.;
     float skip_layer_start = 0.01;
     float skip_layer_end = 0.2;
+
+    sd_preview_t preview_method = SD_PREVIEW_NONE;
+    int preview_interval = 1;
+    std::string preview_path = "preview.png";
+    bool taesd_preview = false;
 };
 
 void print_params(SDParams params) {
@@ -174,10 +186,12 @@ void print_params(SDParams params) {
     printf(" sample_steps: %d\n", params.sample_steps);
     printf(" strength(img2img): %.2f\n", params.strength);
     printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
-    printf(" seed: %ld\n", params.seed);
+    printf(" seed: %lld\n", params.seed);
     printf(" batch_count: %d\n", params.batch_count);
     printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
     printf(" upscale_repeats: %d\n", params.upscale_repeats);
+    printf(" preview_mode: %s\n", previews_str[params.preview_method]);
+    printf(" preview_interval: %d\n", params.preview_interval);
 }
 
 void print_usage(int argc, const char* argv[]) {
@@ -185,16 +199,17 @@ void print_usage(int argc, const char* argv[]) {
     printf("\n");
     printf("arguments:\n");
     printf(" -h, --help show this help message and exit\n");
-    printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
+    printf(" -M, --mode [MODE] run mode (txt2img or img2img or convert, default: txt2img)\n");
     printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
     printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
     printf(" -m, --model [MODEL] path to full model\n");
-    printf(" --diffusion-model path to the standalone diffusion model\n");
-    printf(" --clip_l path to the clip-l text encoder\n");
-    printf(" --clip_g path to the clip-g text encoder\n");
-    printf(" --t5xxl path to the the t5xxl text encoder\n");
+    printf(" --diffusion-model [MODEL] path to the standalone diffusion model\n");
+    printf(" --clip_l [ENCODER] path to the clip-l text encoder\n");
+    printf(" --clip_g [ENCODER] path to the clip-g text encoder\n");
+    printf(" --t5xxl [ENCODER] path to the t5xxl text encoder\n");
     printf(" --vae [VAE] path to vae\n");
-    printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
+    printf(" --taesd [TAESD] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
+    printf(" --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview %s)\n", previews_str[SD_PREVIEW_TAE]);
     printf(" --control-net [CONTROL_PATH] path to control net model\n");
     printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
     printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n");
@@ -243,6 +258,10 @@ void print_usage(int argc, const char* argv[]) {
     printf(" This might crash if it is not supported by the backend.\n");
     printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
     printf(" --canny apply canny preprocessor (edge detection)\n");
+    printf(" --preview {%s,%s,%s,%s} preview method. (default is %s(disabled))\n", previews_str[0], previews_str[1], previews_str[2], previews_str[3], previews_str[SD_PREVIEW_NONE]);
+    printf(" %s is the fastest\n", previews_str[SD_PREVIEW_PROJ]);
+    printf(" --preview-interval [N] How often to save the image preview\n");
+    printf(" --preview-path [PATH] path to write preview image to (default: ./preview.png)\n");
     printf(" --color Colors the logging tags according to level\n");
     printf(" -v, --verbose print extra info\n");
 }
@@ -507,6 +526,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             params.diffusion_flash_attn = true; // can reduce MEM significantly
         } else if (arg == "--canny") {
             params.canny_preprocess = true;
+        } else if (arg == "--taesd-preview-only") {
+            params.taesd_preview = true;
         } else if (arg == "-b" || arg == "--batch-count") {
             if (++i >= argc) {
                 invalid_arg = true;
                 break;
             }
@@ -629,6 +650,35 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.skip_layer_end = std::stof(argv[i]);
+        } else if (arg == "--preview") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            const char* preview = argv[i];
+            int preview_method = -1;
+            for (int m = 0; m < N_PREVIEWS; m++) {
+                if (!strcmp(preview, previews_str[m])) {
+                    preview_method = m;
+                }
+            }
+            if (preview_method == -1) {
+                invalid_arg = true;
+                break;
+            }
+            params.preview_method = (sd_preview_t)preview_method;
+        } else if (arg == "--preview-interval") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.preview_interval = std::stoi(argv[i]);
+        } else if (arg == "--preview-path") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.preview_path = argv[i];
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -787,12 +837,20 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
     fflush(out_stream);
 }
 
+const char* preview_path;
+
+void step_callback(int step, sd_image_t image) {
+    stbi_write_png(preview_path, image.width, image.height, image.channel, image.data, 0);
+}
+
 int main(int argc, const char* argv[]) {
     SDParams params;
 
     parse_args(argc, argv, params);
+    preview_path = params.preview_path.c_str();
 
     sd_set_log_callback(sd_log_cb, (void*)&params);
+    sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
 
     if (params.verbose) {
         print_params(params);
@@ -900,7 +958,8 @@ int main(int argc, const char* argv[]) {
                                       params.clip_on_cpu,
                                       params.control_net_cpu,
                                       params.vae_on_cpu,
-                                      params.diffusion_flash_attn);
+                                      params.diffusion_flash_attn,
+                                      params.taesd_preview);
 
     if (sd_ctx == NULL) {
         printf("new_sd_ctx_t failed\n");
@@ -1075,11 +1134,11 @@ int main(int argc, const char* argv[]) {
         std::string dummy_name, ext, lc_ext;
         bool is_jpg;
-        size_t last = params.output_path.find_last_of(".");
+        size_t last = params.output_path.find_last_of(".");
         size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\"));
-        if (last != std::string::npos // filename has extension
-            && (last_path == std::string::npos || last > last_path)) {
+        if (last != std::string::npos // filename has extension
+            && (last_path == std::string::npos || last > last_path)) {
             dummy_name = params.output_path.substr(0, last);
             ext = lc_ext = params.output_path.substr(last);
             std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
         } else {
             dummy_name = params.output_path;
             ext = lc_ext = "";
-            is_jpg = false;
+            is_jpg = false;
         }
         // appending ".png" to absent or unknown extension
         if (!is_jpg && lc_ext != ".png") {
@@ -1099,7 +1158,7 @@ int main(int argc, const char* argv[]) {
             continue;
         }
         std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
-        if(is_jpg) {
+        if (is_jpg) {
             stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(params, params.seed + i).c_str());
             printf("save result JPEG image to '%s'\n", final_image_path.c_str());
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index c5913be4..8404a997 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -627,7 +627,7 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
     ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
     on_processing(input_tile, NULL, true);
     int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
-    LOG_INFO("processing %i tiles", num_tiles);
+    LOG_DEBUG("processing %i tiles", num_tiles);
     pretty_progress(1, num_tiles, 0.0f);
     int tile_count = 1;
     bool last_y = false, last_x = false;
diff --git a/latent-preview.h b/latent-preview.h
new file mode 100644
index 00000000..ca4d132f
--- /dev/null
+++ b/latent-preview.h
@@ -0,0 +1,83 @@
+
+// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
+const float flux_latent_rgb_proj[16][3] = {
+    {-0.0346f, 0.0244f, 0.0681f},
+    {0.0034f, 0.0210f, 0.0687f},
+    {0.0275f, -0.0668f, -0.0433f},
+    {-0.0174f, 0.0160f, 0.0617f},
+    {0.0859f, 0.0721f, 0.0329f},
+    {0.0004f, 0.0383f, 0.0115f},
+    {0.0405f, 0.0861f, 0.0915f},
+    {-0.0236f, -0.0185f, -0.0259f},
+    {-0.0245f, 0.0250f, 0.1180f},
+    {0.1008f, 0.0755f, -0.0421f},
+    {-0.0515f, 0.0201f, 0.0011f},
+    {0.0428f, -0.0012f, -0.0036f},
+    {0.0817f, 0.0765f, 0.0749f},
+    {-0.1264f, -0.0522f, -0.1103f},
+    {-0.0280f, -0.0881f, -0.0499f},
+    {-0.1262f, -0.0982f, -0.0778f}};
+
+// https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
+const float sd3_latent_rgb_proj[16][3] = {
+    {-0.0645f, 0.0177f, 0.1052f},
+    {0.0028f, 0.0312f, 0.0650f},
+    {0.1848f, 0.0762f, 0.0360f},
+    {0.0944f, 0.0360f, 0.0889f},
+    {0.0897f, 0.0506f, -0.0364f},
+    {-0.0020f, 0.1203f, 0.0284f},
+    {0.0855f, 0.0118f, 0.0283f},
+    {-0.0539f, 0.0658f, 0.1047f},
+    {-0.0057f, 0.0116f, 0.0700f},
+    {-0.0412f, 0.0281f, -0.0039f},
+    {0.1106f, 0.1171f, 0.1220f},
+    {-0.0248f, 0.0682f, -0.0481f},
+    {0.0815f, 0.0846f, 0.1207f},
+    {-0.0120f, -0.0055f, -0.0867f},
+    {-0.0749f, -0.0634f, -0.0456f},
+    {-0.1418f, -0.1457f, -0.1259f},
+};
+
+// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
+const float sdxl_latent_rgb_proj[4][3] = {
+    {0.3651f, 0.4232f, 0.4341f},
+    {-0.2533f, -0.0042f, 0.1068f},
+    {0.1076f, 0.1111f, -0.0362f},
+    {-0.3165f, -0.2492f, -0.2188f}};
+
+// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
+const float sd_latent_rgb_proj[4][3]{
+    {0.3512f, 0.2297f, 0.3227f},
+    {0.3250f, 0.4974f, 0.2350f},
+    {-0.2829f, 0.1762f, 0.2721f},
+    {-0.2120f, -0.2616f, -0.7177f}};
+
+void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], int width, int height, int dim) {
+    size_t buffer_head = 0;
+    for (int j = 0; j < height; j++) {
+        for (int i = 0; i < width; i++) {
+            size_t latent_id = (i * latents->nb[0] + j * latents->nb[1]);
+            float r = 0, g = 0, b = 0;
+            for (int d = 0; d < dim; d++) {
+                float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[2]);
+                r += value * latent_rgb_proj[d][0];
+                g += value * latent_rgb_proj[d][1];
+                b += value * latent_rgb_proj[d][2];
+            }
+
+            // change range
+            r = r * .5f + .5f;
+            g = g * .5f + .5f;
+            b = b * .5f + .5f;
+
+            // clamp rgb values to [0,1] range
+            r = r >= 0 ? r <= 1 ? r : 1 : 0;
+            g = g >= 0 ? g <= 1 ? g : 1 : 0;
+            b = b >= 0 ? b <= 1 ? b : 1 : 0;
+
+            buffer[buffer_head++] = (uint8_t)(r * 255);
+            buffer[buffer_head++] = (uint8_t)(g * 255);
+            buffer[buffer_head++] = (uint8_t)(b * 255);
+        }
+    }
+}
\ No newline at end of file
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index a2d33bca..7d325321 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -20,6 +20,8 @@
 #define STB_IMAGE_STATIC
 #include "stb_image.h"
 
+#include "latent-preview.h"
+
 // #define STB_IMAGE_WRITE_IMPLEMENTATION
 // #define STB_IMAGE_WRITE_STATIC
 // #include "stb_image_write.h"
@@ -48,8 +50,7 @@ const char* sampling_methods_str[] = {
     "iPNDM_v",
     "LCM",
     "DDIM \"trailing\"",
-    "TCD"
-};
+    "TCD"};
 
 /*================================================== Helper Functions ================================================*/
 
@@ -68,6 +69,14 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
     }
 }
 
+void suppress_pp(int step, int steps, float time, void* data) {
+    (void)step;
+    (void)steps;
+    (void)time;
+    (void)data;
+    return;
+}
+
 /*=============================================== StableDiffusionGGML ================================================*/
 
 class StableDiffusionGGML {
@@ -159,7 +168,8 @@ class StableDiffusionGGML {
               bool clip_on_cpu,
               bool control_net_cpu,
               bool vae_on_cpu,
-              bool diffusion_flash_attn) {
+              bool diffusion_flash_attn,
+              bool tae_preview_only) {
        use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUDA
        LOG_DEBUG("Using CUDA backend");
@@ -351,7 +361,7 @@ class StableDiffusionGGML {
            diffusion_model->alloc_params_buffer();
            diffusion_model->get_param_tensors(tensors);

-            if (!use_tiny_autoencoder) {
+            if (!use_tiny_autoencoder || tae_preview_only) {
                if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
                    LOG_INFO("VAE Autoencoder: Using CPU backend");
                    vae_backend = ggml_backend_cpu_init();
@@ -361,7 +371,8 @@ class StableDiffusionGGML {
                first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, false, version);
                first_stage_model->alloc_params_buffer();
                first_stage_model->get_param_tensors(tensors, "first_stage_model");
-            } else {
+            }
+            if (use_tiny_autoencoder) {
                tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version);
            }
            // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
@@ -453,9 +464,10 @@ class StableDiffusionGGML {
        size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
        size_t unet_params_mem_size = diffusion_model->get_params_buffer_size();
        size_t vae_params_mem_size = 0;
-        if (!use_tiny_autoencoder) {
+        if (!use_tiny_autoencoder || tae_preview_only) {
            vae_params_mem_size = first_stage_model->get_params_buffer_size();
-        } else {
+        }
+        if (use_tiny_autoencoder) {
            if (!tae_first_stage->load_from_file(taesd_path)) {
                return false;
            }
@@ -599,6 +611,7 @@ class StableDiffusionGGML {
        LOG_DEBUG("finished loaded file");
        ggml_free(ctx);
+        use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
        return true;
    }
@@ -682,7 +695,7 @@ class StableDiffusionGGML {
            float curr_multiplier = kv.second;
            lora_state_diff[lora_name] -= curr_multiplier;
        }
-
+
        size_t rm = lora_state_diff.size() - lora_state.size();
        if (rm != 0) {
            LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
@@ -785,38 +798,135 @@ class StableDiffusionGGML {
        return {c_crossattn, y, c_concat};
    }

-    ggml_tensor* sample(ggml_context* work_ctx,
-                        ggml_tensor* init_latent,
-                        ggml_tensor* noise,
-                        SDCondition cond,
-                        SDCondition uncond,
-                        ggml_tensor* control_hint,
-                        float control_strength,
-                        float min_cfg,
-                        float cfg_scale,
-                        float guidance,
-                        float eta,
-                        sample_method_t method,
-                        const std::vector<float>& sigmas,
-                        int start_merge_step,
-                        SDCondition id_cond,
-                        std::vector<int> skip_layers = {},
-                        float slg_scale = 0,
-                        float skip_layer_start = 0.01,
-                        float skip_layer_end = 0.2,
-                        ggml_tensor* noise_mask = nullptr) {
-        LOG_DEBUG("Sample");
-        struct ggml_init_params params;
-        size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
-        for (int i = 1; i < 4; i++) {
-            data_size *= init_latent->ne[i];
+    void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
+        sd_progress_cb_t cb = sd_get_progress_callback();
+        void* cbd = sd_get_progress_callback_data();
+        sd_set_progress_callback((sd_progress_cb_t)suppress_pp, NULL);
+        sd_tiling(input, output, scale, tile_size, tile_overlap_factor, on_processing);
+        sd_set_progress_callback(cb, cbd);
+    }
+
+    void preview_image(ggml_context* work_ctx,
+                       int step,
+                       struct ggml_tensor* latents,
+                       enum SDVersion version,
+                       sd_preview_t preview_mode,
+                       ggml_tensor* result,
+                       std::function<void(int, sd_image_t)> step_callback) {
+        const uint32_t channel = 3;
+        uint32_t width = latents->ne[0];
+        uint32_t height = latents->ne[1];
+        uint32_t dim = latents->ne[2];
+        if (preview_mode == SD_PREVIEW_PROJ) {
+            const float(*latent_rgb_proj)[channel];
+
+            if (dim == 16) {
+                // 16 channels VAE -> Flux or SD3
+
+                if (sd_version_is_sd3(version)) {
+                    latent_rgb_proj = sd3_latent_rgb_proj;
+                } else if (sd_version_is_flux(version)) {
+                    latent_rgb_proj = flux_latent_rgb_proj;
+                } else {
+                    LOG_WARN("No latent to RGB projection known for this model");
+                    // unknown model
+                    return;
+                }
+
+            } else if (dim == 4) {
+                // 4 channels VAE
+                if (sd_version_is_sdxl(version)) {
+                    latent_rgb_proj = sdxl_latent_rgb_proj;
+                } else if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
+                    latent_rgb_proj = sd_latent_rgb_proj;
+                } else {
+                    // unknown model
+                    LOG_WARN("No latent to RGB projection known for this model");
+                    return;
+                }
+            } else {
+                LOG_WARN("No latent to RGB projection known for this model");
+                // unknown latent space
+                return;
+            }
+            uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
+
+            preview_latent_image(data, latents, latent_rgb_proj, width, height, dim);
+            sd_image_t image = {
+                width,
+                height,
+                channel,
+                data};
+            step_callback(step, image);
+            free(image.data);
+        } else {
+            if (preview_mode == SD_PREVIEW_VAE) {
+                ggml_tensor_scale(latents, 1.0f / scale_factor);
+                if (vae_tiling) {
+                    // split latent in 32x32 tiles and compute in several steps
+                    auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
+                        first_stage_model->compute(n_threads, in, true, &out);
+                    };
+                    silent_tiling(latents, result, 8, 32, 0.5f, on_tiling);
+
+                } else {
+                    first_stage_model->compute(n_threads, latents, true, &result);
+                }
+                first_stage_model->free_compute_buffer();
+                ggml_tensor_scale(latents, scale_factor);
+
+                ggml_tensor_scale_output(result);
+            } else if (preview_mode == SD_PREVIEW_TAE) {
+                if (tae_first_stage == nullptr) {
+                    LOG_WARN("TAE not found for preview");
+                    return;
+                }
+                if (vae_tiling) {
+                    // split latent in 64x64 tiles and compute in several steps
+                    auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
+                        tae_first_stage->compute(n_threads, in, true, &out);
+                    };
+                    silent_tiling(latents, result, 8, 64, 0.5f, on_tiling);
+                } else {
+                    tae_first_stage->compute(n_threads, latents, true, &result);
+                }
+                tae_first_stage->free_compute_buffer();
+            } else {
+                return;
+            }
+            ggml_tensor_clamp(result, 0.0f, 1.0f);
+            sd_image_t image = {
+                width * 8,
+                height * 8,
+                channel,
+                sd_tensor_to_image(result)};
+            ggml_tensor_scale(result, 0);
+            step_callback(step, image);
+            free(image.data);
         }
-        data_size += 1024;
-        params.mem_size = data_size * 3;
-        params.mem_buffer = NULL;
-        params.no_alloc = false;
-        ggml_context* tmp_ctx = ggml_init(params);
+    }
+    ggml_tensor*
+    sample(ggml_context* work_ctx,
+           ggml_tensor* init_latent,
+           ggml_tensor* noise,
+           SDCondition cond,
+           SDCondition uncond,
+           ggml_tensor* control_hint,
+           float control_strength,
+           float min_cfg,
+           float cfg_scale,
+           float guidance,
+           float eta,
+           sample_method_t method,
+           const std::vector<float>& sigmas,
+           int start_merge_step,
+           SDCondition id_cond,
+           std::vector<int> skip_layers = {},
+           float slg_scale = 0,
+           float skip_layer_start = 0.01,
+           float skip_layer_end = 0.2,
+           ggml_tensor* noise_mask = nullptr) {
        size_t steps = sigmas.size() - 1;
        // noise = load_tensor_from_file(work_ctx, "./rand0.bin");
        // print_ggml_tensor(noise);
@@ -847,6 +957,16 @@ class StableDiffusionGGML {
        }
        struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
+        struct ggml_tensor* preview_tensor = NULL;
+        auto sd_preview_mode = sd_get_preview_mode();
+        if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
+            preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
+                                                (denoised->ne[0] * 8),
+                                                (denoised->ne[1] * 8),
+                                                3,
+                                                denoised->ne[3]);
+        }
+
        auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
            if (step == 1) {
                pretty_progress(0, (int)steps, 0);
@@ -971,10 +1091,6 @@ class StableDiffusionGGML {
                vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
            }
            int64_t t1 = ggml_time_us();
-            if (step > 0) {
-                pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
-                // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
-            }
            if (noise_mask != nullptr) {
                for (int64_t x = 0; x < denoised->ne[0]; x++) {
                    for (int64_t y = 0; y < denoised->ne[1]; y++) {
@@ -987,7 +1103,17 @@ class StableDiffusionGGML {
                    }
                }
            }
-
+            if (step > 0) {
+                pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
+                // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
+            }
+            auto sd_preview_cb = sd_get_preview_callback();
+            auto sd_preview_mode = sd_get_preview_mode();
+            if (sd_preview_cb != NULL) {
+                if (step % sd_get_preview_interval() == 0) {
+                    preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
+                }
+            }
            return denoised;
        };
@@ -1130,7 +1256,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                     bool keep_clip_on_cpu,
                     bool keep_control_net_cpu,
                     bool keep_vae_on_cpu,
-                     bool diffusion_flash_attn) {
+                     bool diffusion_flash_attn,
+                     bool tae_preview_only) {
    sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
    if (sd_ctx == NULL) {
        return NULL;
    }
@@ -1172,7 +1299,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                             keep_clip_on_cpu,
                             keep_control_net_cpu,
                             keep_vae_on_cpu,
-                             diffusion_flash_attn)) {
+                             diffusion_flash_attn,
+                             tae_preview_only)) {
        delete sd_ctx->sd;
        sd_ctx->sd = NULL;
        free(sd_ctx);
@@ -1561,6 +1689,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
    if (sd_ctx->sd->stacked_id) {
        params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
    }
+    auto sd_preview_mode = sd_get_preview_mode();
+    if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
+        params.mem_size *= 2;
+    }
    params.mem_size += width * height * 3 * sizeof(float);
    params.mem_size *= batch_count;
    params.mem_buffer = NULL;
@@ -1621,7 +1753,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
                                                    skip_layers_vec,
                                                    slg_scale,
                                                    skip_layer_start,
-                                                    skip_layer_end);
+                                                    skip_layer_end,
+                                                    NULL);

    size_t t1 = ggml_time_ms();
@@ -1904,7 +2037,9 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
                                                    sample_method,
                                                    sigmas,
                                                    -1,
-                                                    SDCondition(NULL, NULL, NULL));
+                                                    SDCondition(NULL, NULL, NULL),
+                                                    {},
+                                                    0, 0, 0, NULL);

    int64_t t2 = ggml_time_ms();
    LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 8872bbaa..d422cea6 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -109,13 +109,13 @@ enum sd_log_level_t {
     SD_LOG_ERROR
 };
 
-typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
-typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
-
-SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
-SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
-SD_API int32_t get_num_physical_cores();
-SD_API const char* sd_get_system_info();
+enum sd_preview_t {
+    SD_PREVIEW_NONE,
+    SD_PREVIEW_PROJ,
+    SD_PREVIEW_TAE,
+    SD_PREVIEW_VAE,
+    N_PREVIEWS
+};
 
 typedef struct {
     uint32_t width;
@@ -124,6 +124,17 @@ typedef struct {
     uint8_t* data;
 } sd_image_t;
 
+typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
+typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
+typedef void (*sd_preview_cb_t)(int, sd_image_t);
+
+
+SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
+SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
+SD_API void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode, int interval);
+SD_API int32_t get_num_physical_cores();
+SD_API const char* sd_get_system_info();
+
 typedef struct sd_ctx_t sd_ctx_t;
 
 SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
@@ -147,7 +158,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
                            bool keep_clip_on_cpu,
                            bool keep_control_net_cpu,
                            bool keep_vae_on_cpu,
-                            bool diffusion_flash_attn);
+                            bool diffusion_flash_attn,
+                            bool tae_preview_only);
 
 SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
diff --git a/util.cpp b/util.cpp
index da11a14d..bf7178ca 100644
--- a/util.cpp
+++ b/util.cpp
@@ -247,6 +247,10 @@ int32_t get_num_physical_cores() {
 static sd_progress_cb_t sd_progress_cb = NULL;
 void* sd_progress_cb_data = NULL;
 
+static sd_preview_cb_t sd_preview_cb = NULL;
+sd_preview_t sd_preview_mode = SD_PREVIEW_NONE;
+int sd_preview_interval = 1;
+
 std::u32string utf8_to_utf32(const std::string& utf8_str) {
     std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
     return converter.from_bytes(utf8_str);
 }
@@ -420,6 +424,29 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
     sd_progress_cb = cb;
     sd_progress_cb_data = data;
 }
+void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode = SD_PREVIEW_PROJ, int interval = 1) {
+    sd_preview_cb = cb;
+    sd_preview_mode = mode;
+    sd_preview_interval = interval;
+}
+
+sd_preview_cb_t sd_get_preview_callback() {
+    return sd_preview_cb;
+}
+
+sd_preview_t sd_get_preview_mode() {
+    return sd_preview_mode;
+}
+int sd_get_preview_interval() {
+    return sd_preview_interval;
+}
+
+sd_progress_cb_t sd_get_progress_callback() {
+    return sd_progress_cb;
+}
+void* sd_get_progress_callback_data() {
+    return sd_progress_cb_data;
+}
 const char* sd_get_system_info() {
     static char buffer[1024];
     std::stringstream ss;
diff --git a/util.h b/util.h
index 14fa812e..36a2e18a 100644
--- a/util.h
+++ b/util.h
@@ -54,6 +54,13 @@ std::string trim(const std::string& s);
 
 std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
 
+sd_progress_cb_t sd_get_progress_callback();
+void* sd_get_progress_callback_data();
+
+sd_preview_cb_t sd_get_preview_callback();
+sd_preview_t sd_get_preview_mode();
+int sd_get_preview_interval();
+
 #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
 #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
 #define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
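
Usage sketch (illustrative, not part of the patch): how a client of the library might register the preview callback added above, using only sd_set_preview_callback(), sd_preview_cb_t, and sd_preview_t from the new stable-diffusion.h declarations. The save_preview name, the hard-coded "preview.png" path, and the use of stb_image_write mirror the CLI example in this diff and are assumptions, not required usage.

    #include "stable-diffusion.h"

    #define STB_IMAGE_WRITE_IMPLEMENTATION
    #include "stb_image_write.h"

    // Receives an RGB preview every `interval` sampling steps; the pixel buffer is
    // only valid for the duration of the call (the library frees it afterwards).
    static void save_preview(int step, sd_image_t image) {
        stbi_write_png("preview.png", image.width, image.height, image.channel, image.data, 0);
    }

    int main() {
        // Cheap latent-space projection every 2 steps; SD_PREVIEW_TAE / SD_PREVIEW_VAE
        // instead decode the latents and cost more per preview.
        sd_set_preview_callback(save_preview, SD_PREVIEW_PROJ, 2);
        // ... create the context with new_sd_ctx(..., tae_preview_only) and run txt2img as usual ...
        return 0;
    }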