Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add checkbox, config and handling for saving only the final enhanced image #61

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion args_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
args_parser.parser.add_argument("--disable-image-log", action='store_true',
help="Prevent writing images and logs to hard drive.")
help="Prevent writing images and logs to the outputs folder.")

args_parser.parser.add_argument("--disable-analytics", action='store_true',
help="Disables analytics for Gradio.")
Expand Down
3 changes: 3 additions & 0 deletions language/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@
"Read wildcards in order": "Read wildcards in order",
"Black Out NSFW": "Black Out NSFW",
"Use black image if NSFW is detected.": "Use black image if NSFW is detected.",
"Save only final enhanced image": "Save only final enhanced image",
"Save Metadata to Images": "Save Metadata to Images",
"Adds parameters to generated images allowing manual regeneration.": "Adds parameters to generated images allowing manual regeneration.",
"\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
"Image Style": "Image Style",
"Fooocus V2": "Fooocus V2",
Expand Down
52 changes: 32 additions & 20 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self, args):
self.inpaint_advanced_masking_checkbox = args.pop()
self.invert_mask_checkbox = args.pop()
self.inpaint_erode_or_dilate = args.pop()
self.save_final_enhanced_image_only = args.pop() if not args_manager.args.disable_image_log else False
self.save_metadata_to_images = args.pop() if not args_manager.args.disable_metadata else False
self.metadata_scheme = MetadataScheme(
args.pop()) if not args_manager.args.disable_metadata else MetadataScheme.FOOOCUS
Expand Down Expand Up @@ -278,7 +279,7 @@ def build_image_wall(async_task):
def process_task(all_steps, async_task, callback, controlnet_canny_path, controlnet_cpds_path, current_task_id,
denoising_strength, final_scheduler_name, goals, initial_latent, steps, switch, positive_cond,
negative_cond, task, loras, tiled, use_expansion, width, height, base_progress, preparation_steps,
total_count, show_intermediate_results):
total_count, show_intermediate_results, persist_image=True):
if async_task.last_stop is not False:
ldm_patched.modules.model_management.interrupt_current_processing()
if 'cn' in goals:
Expand Down Expand Up @@ -315,9 +316,8 @@ def process_task(all_steps, async_task, callback, controlnet_canny_path, control
if modules.config.default_black_out_nsfw or async_task.black_out_nsfw:
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
imgs = default_censor(imgs)
progressbar(async_task, current_progress,
f'Saving image {current_task_id + 1}/{total_count} to system ...')
img_paths = save_and_log(async_task, height, imgs, task, use_expansion, width, loras)
progressbar(async_task, current_progress, f'Saving image {current_task_id + 1}/{total_count} to system ...')
img_paths = save_and_log(async_task, height, imgs, task, use_expansion, width, loras, persist_image)
yield_result(async_task, img_paths, current_progress, async_task.black_out_nsfw, False,
do_not_show_finished_images=not show_intermediate_results or async_task.disable_intermediate_results)

Expand All @@ -333,7 +333,7 @@ def apply_patch_settings(async_task):
async_task.adaptive_cfg
)

def save_and_log(async_task, height, imgs, task, use_expansion, width, loras) -> list:
def save_and_log(async_task, height, imgs, task, use_expansion, width, loras, persist_image=True) -> list:
img_paths = []
for x in imgs:
d = [('Prompt', 'prompt', task['log_positive_prompt']),
Expand Down Expand Up @@ -388,7 +388,7 @@ def save_and_log(async_task, height, imgs, task, use_expansion, width, loras) ->
d.append(('Metadata Scheme', 'metadata_scheme',
async_task.metadata_scheme.value if async_task.save_metadata_to_images else async_task.save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
img_paths.append(log(x, d, metadata_parser, async_task.output_format, task))
img_paths.append(log(x, d, metadata_parser, async_task.output_format, task, persist_image))

return img_paths

Expand Down Expand Up @@ -963,7 +963,7 @@ def process_enhance(all_steps, async_task, callback, controlnet_canny_path, cont
inpaint_engine, inpaint_respective_field, inpaint_strength,
prompt, negative_prompt, final_scheduler_name, goals, height, img, mask,
preparation_steps, steps, switch, tiled, total_count, use_expansion, use_style,
use_synthetic_refiner, width, show_intermediate_results=True):
use_synthetic_refiner, width, show_intermediate_results=True, persist_image=True):
base_model_additional_loras = []
inpaint_head_model_path = None
inpaint_parameterized = inpaint_engine != 'None' # inpaint_engine = None, improve detail
Expand All @@ -985,7 +985,7 @@ def process_enhance(all_steps, async_task, callback, controlnet_canny_path, cont
progressbar(async_task, current_progress, 'Checking for NSFW content ...')
img = default_censor(img)
progressbar(async_task, current_progress, f'Saving image {current_task_id + 1}/{total_count} to system ...')
uov_image_path = log(img, d, output_format=async_task.output_format)
uov_image_path = log(img, d, output_format=async_task.output_format, persist_image=persist_image)
yield_result(async_task, uov_image_path, current_progress, async_task.black_out_nsfw, False,
do_not_show_finished_images=not show_intermediate_results or async_task.disable_intermediate_results)
return current_progress, img, prompt, negative_prompt
Expand Down Expand Up @@ -1019,15 +1019,16 @@ def process_enhance(all_steps, async_task, callback, controlnet_canny_path, cont
final_scheduler_name, goals, initial_latent, steps, switch,
task_enhance['c'], task_enhance['uc'], task_enhance, loras,
tiled, use_expansion, width, height, current_progress,
preparation_steps, total_count, show_intermediate_results)
preparation_steps, total_count, show_intermediate_results,
persist_image)

del task_enhance['c'], task_enhance['uc'] # Save memory
return current_progress, imgs[0], prompt, negative_prompt

def enhance_upscale(all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps,
prompt, negative_prompt, final_scheduler_name, height, img, preparation_steps, switch, tiled,
total_count, use_expansion, use_style, use_synthetic_refiner, width):
total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image=True):
# reset inpaint worker to prevent tensor size issues and not mix upscale and inpainting
inpaint_worker.current_task = None

Expand All @@ -1045,7 +1046,7 @@ def enhance_upscale(all_steps, async_task, base_progress, callback, controlnet_c
controlnet_cpds_path, current_progress, current_task_id, denoising_strength, False,
'None', 0.0, 0.0, prompt, negative_prompt, final_scheduler_name,
goals_enhance, height, img, None, preparation_steps, steps, switch, tiled, total_count,
use_expansion, use_style, use_synthetic_refiner, width)
use_expansion, use_style, use_synthetic_refiner, width, persist_image=persist_image)

except ldm_patched.modules.model_management.InterruptProcessingException:
if async_task.last_stop == 'skip':
Expand Down Expand Up @@ -1168,6 +1169,8 @@ def handler(async_task: AsyncTask):
current_progress += 1
progressbar(async_task, current_progress, 'Image processing ...')

should_enhance = async_task.enhance_checkbox and (async_task.enhance_uov_method != flags.disabled.casefold() or len(async_task.enhance_ctrls) > 0)

if 'vary' in goals:
async_task.uov_input_image, denoising_strength, initial_latent, width, height, current_progress = apply_vary(
async_task, async_task.uov_method, denoising_strength, async_task.uov_input_image, switch,
Expand Down Expand Up @@ -1273,8 +1276,8 @@ def callback(step, x0, x, total_steps, y):
int(current_progress + async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])

should_enhance = async_task.enhance_checkbox and (async_task.enhance_uov_method != flags.disabled.casefold() or len(async_task.enhance_ctrls) > 0)
show_intermediate_results = len(tasks) > 1 or should_enhance
persist_image = not should_enhance or not async_task.save_final_enhanced_image_only

for current_task_id, task in enumerate(tasks):
progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{async_task.image_number} ...')
Expand All @@ -1287,7 +1290,8 @@ def callback(step, x0, x, total_steps, y):
initial_latent, async_task.steps, switch, task['c'],
task['uc'], task, loras, tiled, use_expansion, width,
height, current_progress, preparation_steps,
async_task.image_number, show_intermediate_results)
async_task.image_number, show_intermediate_results,
persist_image)

current_progress = int(preparation_steps + (100 - preparation_steps) / float(all_steps) * async_task.steps * (current_task_id + 1))
images_to_enhance += imgs
Expand All @@ -1314,8 +1318,12 @@ def callback(step, x0, x, total_steps, y):

active_enhance_tabs = len(async_task.enhance_ctrls)
should_process_enhance_uov = async_task.enhance_uov_method != flags.disabled.casefold()
enhance_uov_before = False
enhance_uov_after = False
if should_process_enhance_uov:
active_enhance_tabs += 1
enhance_uov_before = async_task.enhance_uov_processing_order == flags.enhancement_uov_before
enhance_uov_after = async_task.enhance_uov_processing_order == flags.enhancement_uov_after
total_count = len(images_to_enhance) * active_enhance_tabs

base_progress = current_progress
Expand All @@ -1330,13 +1338,14 @@ def callback(step, x0, x, total_steps, y):
last_enhance_prompt = async_task.prompt
last_enhance_negative_prompt = async_task.negative_prompt

if should_process_enhance_uov and async_task.enhance_uov_processing_order == flags.enhancement_uov_before:
if enhance_uov_before:
current_task_id += 1
persist_image = not async_task.save_final_enhanced_image_only or active_enhance_tabs == 0
current_task_id, done_steps_inpainting, done_steps_upscaling, img, exception_result = enhance_upscale(
all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps,
async_task.prompt, async_task.negative_prompt, final_scheduler_name, height, img, preparation_steps,
switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width)
switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image)
if exception_result == 'continue':
continue
elif exception_result == 'break':
Expand All @@ -1348,6 +1357,8 @@ def callback(step, x0, x, total_steps, y):
current_progress = int(base_progress + (100 - preparation_steps) / float(all_steps) * (done_steps_upscaling + done_steps_inpainting))
progressbar(async_task, current_progress, f'Preparing enhancement {current_task_id + 1}/{total_count} ...')
enhancement_task_start_time = time.perf_counter()
is_last_enhance_for_image = (current_task_id + 1) % active_enhance_tabs == 0 and not enhance_uov_after
persist_image = not async_task.save_final_enhanced_image_only or is_last_enhance_for_image

extras = {}
if enhance_mask_model == 'sam':
Expand Down Expand Up @@ -1383,8 +1394,7 @@ def callback(step, x0, x, total_steps, y):
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
print(f'[Enhance] {sam_detection_on_mask_count} segments applied to mask')

if enhance_mask_model == 'sam' and (
dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
if enhance_mask_model == 'sam' and (dino_detection_count == 0 or not async_task.debugging_dino and sam_detection_on_mask_count == 0):
print(f'[Enhance] No "{enhance_mask_dino_prompt_text}" detected, skipping')
continue

Expand All @@ -1397,7 +1407,7 @@ def callback(step, x0, x, total_steps, y):
enhance_inpaint_engine, enhance_inpaint_respective_field, enhance_inpaint_strength,
enhance_prompt, enhance_negative_prompt, final_scheduler_name, goals_enhance, height, img, mask,
preparation_steps, enhance_steps, switch, tiled, total_count, use_expansion, use_style,
use_synthetic_refiner, width)
use_synthetic_refiner, width, persist_image=persist_image)

if (should_process_enhance_uov and async_task.enhance_uov_processing_order == flags.enhancement_uov_after
and async_task.enhance_uov_prompt_type == flags.enhancement_uov_prompt_type_last_filled):
Expand All @@ -1424,14 +1434,16 @@ def callback(step, x0, x, total_steps, y):
if exception_result == 'break':
break

if should_process_enhance_uov and async_task.enhance_uov_processing_order == flags.enhancement_uov_after:
if enhance_uov_after:
current_task_id += 1
# last step in enhance, always save
persist_image = True
current_task_id, done_steps_inpainting, done_steps_upscaling, img, exception_result = enhance_upscale(
all_steps, async_task, base_progress, callback, controlnet_canny_path, controlnet_cpds_path,
current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps,
last_enhance_prompt, last_enhance_negative_prompt, final_scheduler_name, height, img,
preparation_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner,
width)
width, persist_image)
if exception_result == 'continue':
continue
elif exception_result == 'break':
Expand Down
6 changes: 6 additions & 0 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,12 @@ def init_temp_path(path: str | None, default_path: str) -> str:
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_save_only_final_enhanced_image = get_config_item_or_set_default(
key='default_save_only_final_enhanced_image',
default_value=False,
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_save_metadata_to_images = get_config_item_or_set_default(
key='default_save_metadata_to_images',
default_value=False,
Expand Down
4 changes: 2 additions & 2 deletions modules/private_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def get_current_html_path(output_format=None):
return html_name


def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None, task=None) -> str:
path_outputs = modules.config.temp_path if args_manager.args.disable_image_log else modules.config.path_outputs
def log(img, metadata, metadata_parser: MetadataParser | None = None, output_format=None, task=None, persist_image=True) -> str:
path_outputs = modules.config.temp_path if args_manager.args.disable_image_log or not persist_image else modules.config.path_outputs
output_format = output_format if output_format else modules.config.default_output_format
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
Expand Down
7 changes: 7 additions & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,10 @@ def update_history_link():
inputs=black_out_nsfw, outputs=disable_preview, queue=False,
show_progress=False)

if not args_manager.args.disable_image_log:
save_final_enhanced_image_only = gr.Checkbox(label='Save only final enhanced image',
value=modules.config.default_save_only_final_enhanced_image)

if not args_manager.args.disable_metadata:
save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images,
info='Adds parameters to generated images allowing manual regeneration.')
Expand Down Expand Up @@ -992,6 +996,9 @@ def inpaint_engine_state_change(inpaint_engine_version, *args):
ctrls += freeu_ctrls
ctrls += inpaint_ctrls

if not args_manager.args.disable_image_log:
ctrls += [save_final_enhanced_image_only]

if not args_manager.args.disable_metadata:
ctrls += [save_metadata_to_images, metadata_scheme]

Expand Down