From 70d4b1ca2a27ff6e67aada0a47cb02670adfe056 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Wed, 7 Sep 2022 00:48:13 +0100 Subject: [PATCH] img2img-fix (#717) --- frontend/css_and_js.py | 7 +- frontend/frontend.py | 93 +++++------ frontend/job_manager.py | 168 ++++--------------- frontend/ui_functions.py | 30 +--- scripts/webui.py | 352 +++++++++++---------------------------- 5 files changed, 179 insertions(+), 471 deletions(-) diff --git a/frontend/css_and_js.py b/frontend/css_and_js.py index 0ee5a5c4f..266cf43a9 100644 --- a/frontend/css_and_js.py +++ b/frontend/css_and_js.py @@ -15,14 +15,9 @@ def css(opt): # TODO: @altryne restore this before merge if not opt.no_progressbar_hiding: styling += readTextFile("css", "no_progress_bar.css") - if opt.custom_css: - try: - styling += readTextFile("css", "custom.css") - print("Custom CSS loaded") - except: - pass return styling + def js(opt): data = readTextFile("js", "index.js") data = "(z) => {" + data + "; return z ?? [] }" diff --git a/frontend/frontend.py b/frontend/frontend.py index 3627210e5..ae764c1f5 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -57,21 +57,20 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda output_txt2img_params = gr.Highlightedtext(label="Generation parameters", interactive=False, elem_id='highlight') with gr.Group(): with gr.Row(elem_id='txt2img_output_row'): - output_txt2img_copy_params = gr.Button("Copy all").click( + output_txt2img_copy_params = gr.Button("Copy full parameters").click( inputs=[output_txt2img_params], outputs=[], _js=js_copy_txt2img_output, fn=None, show_progress=False) output_txt2img_seed = gr.Number(label='Seed', interactive=False, visible=False) - output_txt2img_copy_seed = gr.Button("Copy seed").click( + output_txt2img_copy_seed = gr.Button("Copy only seed").click( inputs=[output_txt2img_seed], outputs=[], _js='(x) => navigator.clipboard.writeText(x)', fn=None, show_progress=False) output_txt2img_stats = gr.HTML(label='Stats') with gr.Column(): - with gr.Row(): - txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", - value=txt2img_defaults['ddim_steps']) - txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)', + txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", + value=txt2img_defaults['ddim_steps']) + txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value=txt2img_defaults['sampler_name']) @@ -158,28 +157,22 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn") with gr.Row().style(equal_height=False): with gr.Column(): - with gr.Tabs(): - with gr.TabItem("Img2Img Input"): - #gr.Markdown('#### Img2Img Input') - img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True, - type="pil", tool="select", elem_id="img2img_editor", - image_mode="RGBA") - img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, - type="pil", tool="sketch", visible=False, - elem_id="img2img_mask") - - with gr.TabItem("Img2Img Mask Input"): - img2img_mask_input = gr.Image(label="Mask",source="upload", interactive=False, - type="pil", visible=True) + gr.Markdown('#### Img2Img Input') + img2img_image_editor = 
gr.Image(value=sample_img2img, source="upload", interactive=True, + type="pil", tool="select", elem_id="img2img_editor", image_mode="RGBA" + ) + img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, + type="pil", tool="sketch", visible=False, image_mode="RGBA", + elem_id="img2img_mask") with gr.Tabs(): with gr.TabItem("Editor Options"): with gr.Row(): img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode", value="Crop", elem_id='edit_mode_select') - img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area", "Resize and regenerate only masked area"], + img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", type="index", - value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False) + value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False) img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", @@ -263,16 +256,22 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda img2img_image_editor_mode.change( uifn.change_image_editor_mode, - [img2img_image_editor_mode, img2img_image_editor, img2img_resize, img2img_width, img2img_height], + [img2img_image_editor_mode, + img2img_image_editor, + img2img_image_mask, + img2img_resize, + img2img_width, + img2img_height + ], [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, - img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_input] + img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength] ) - img2img_image_editor.edit( - uifn.update_image_mask, - [img2img_image_editor, img2img_resize, img2img_width, img2img_height], - img2img_image_mask - ) + # img2img_image_editor_mode.change( + # uifn.update_image_mask, + # [img2img_image_editor, img2img_resize, img2img_width, img2img_height], + # img2img_image_mask + # ) output_txt2img_copy_to_input_btn.click( uifn.copy_img_to_input, @@ -306,11 +305,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda ) img2img_func = img2img - img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask, + img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, - img2img_embeddings, img2img_mask_input] + img2img_image_editor, img2img_image_mask, img2img_embeddings] img2img_outputs = [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats] # If a JobManager was passed in then wrap the Generate functions @@ -321,33 +320,23 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda outputs=img2img_outputs, ) - def generate(*args): - args_list = list(args) - init_info_mask = args_list[3] - # Get the mask input and remove it from the list - mask_input = args_list[18] - del args_list[18] - - # If an external mask is set, use it - if mask_input: - init_info_mask['mask'] = mask_input - - args_list[3] = init_info_mask - - # Return the result of img2img - return img2img_func(*args_list) - img2img_btn_mask.click( - generate, + img2img_func, img2img_inputs, img2img_outputs ) - - img2img_btn_editor.click( - img2img_func, 
+ def img2img_submit_params(): + #print([img2img_prompt, img2img_image_editor_mode, img2img_mask, + # img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, + # img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, + # img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, + # img2img_image_editor, img2img_image_mask, img2img_embeddings]) + return (img2img_func, img2img_inputs, img2img_outputs) + img2img_btn_editor.click(*img2img_submit_params()) + # GENERATE ON ENTER img2img_prompt.submit(None, None, None, _js=call_JS("clickFirstVisibleButton", @@ -374,7 +363,7 @@ def generate(*args): # value=gfpgan_defaults['strength']) #select folder with images to process with gr.TabItem('Batch Process'): - imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + imgproc_folder = gr.File(label="Batch Process", file_count="multiple",source="upload", interactive=True, type="file") imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False, max_lines=5) with gr.Row(): imgproc_btn = gr.Button("Process", variant="primary") @@ -580,7 +569,7 @@ def generate(*args):

For help and advanced usage guides, visit the Project Wiki

Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the main repository.
- If you would like to contribute to development or test bleeding edge builds, you can visit the development repository.

+ If you would like to contribute to development or test bleeding edge builds, you can visit the development repository.

""") # Hack: Detect the load event on the frontend diff --git a/frontend/job_manager.py b/frontend/job_manager.py index 038b1d93d..8eda8d9a6 100644 --- a/frontend/job_manager.py +++ b/frontend/job_manager.py @@ -1,7 +1,7 @@ ''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations ''' from __future__ import annotations import gradio as gr -from gradio.components import Component, Gallery, Slider +from gradio.components import Component, Gallery from threading import Event, Timer from typing import Callable, List, Dict, Tuple, Optional, Any from dataclasses import dataclass, field @@ -30,17 +30,7 @@ class JobInfo: session_key: str job_token: Optional[int] = None images: List[Image] = field(default_factory=list) - active_image: Image = None - rec_steps_enabled: bool = False - rec_steps_imgs: List[Image] = field(default_factory=list) - rec_steps_intrvl: int = None - rec_steps_to_gallery: bool = False - rec_steps_to_file: bool = False should_stop: Event = field(default_factory=Event) - refresh_active_image_requested: Event = field(default_factory=Event) - refresh_active_image_done: Event = field(default_factory=Event) - stop_cur_iter: Event = field(default_factory=Event) - active_iteration_cnt: int = field(default_factory=int) job_status: str = field(default_factory=str) finished: bool = False removed_output_idxs: List[int] = field(default_factory=list) @@ -86,7 +76,7 @@ def wrap_func( ''' return self._job_manager._wrap_func( func=func, inputs=inputs, outputs=outputs, - job_ui=self + refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text ) _refresh_btn: gr.Button @@ -94,13 +84,6 @@ def wrap_func( _status_text: gr.Textbox _stop_all_session_btn: gr.Button _free_done_sessions_btn: gr.Button - _active_image: gr.Image - _active_image_stop_btn: gr.Button - _active_image_refresh_btn: gr.Button - _rec_steps_intrvl_sldr: gr.Slider - _rec_steps_checkbox: gr.Checkbox - _save_rec_steps_to_gallery_chkbx: gr.Checkbox - _save_rec_steps_to_file_chkbx: gr.Checkbox _job_manager: JobManager @@ -119,23 +102,11 @@ def draw_gradio_ui(self) -> JobManagerUi: ''' assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context" with gr.Tabs(): - with gr.TabItem("Job Controls"): + with gr.TabItem("Current Session"): with gr.Row(): - stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary") - refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary") + stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary") + refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary") status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False) - with gr.Row(): - active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary") - active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary") - active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image") - with gr.TabItem("Batch Progress Settings"): - with gr.Row(): - record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid") - record_steps_interval_slider = gr.Slider( - value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1) - with gr.Row() as record_steps_box: - steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery") - steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File") 
with gr.TabItem("Maintenance"): with gr.Row(): gr.Markdown( @@ -147,15 +118,9 @@ def draw_gradio_ui(self) -> JobManagerUi: free_done_sessions_btn = gr.Button( "Clear Finished Jobs", elem_id="clear_finished", variant="secondary" ) - return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text, _stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn, - _active_image=active_image, _active_image_stop_btn=active_image_stop_btn, - _active_image_refresh_btn=active_image_refresh_btn, - _rec_steps_checkbox=record_steps_checkbox, - _save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox, - _save_rec_steps_to_file_chkbx=steps_to_file_checkbox, - _rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self) + _job_manager=self) def clear_all_finished_jobs(self): ''' Removes all currently finished jobs, across all sessions. @@ -169,7 +134,6 @@ def stop_all_jobs(self): for session in self._sessions.values(): for job in session.jobs.values(): job.should_stop.set() - job.stop_cur_iter.set() def _get_job_token(self, block: bool = False) -> Optional[int]: ''' Attempts to acquire a job token, optionally blocking until available ''' @@ -211,26 +175,6 @@ def _stop_wrapped_func(self, func_key: FuncKey, session_key: str) -> List[Compon job_info.should_stop.set() return "Stopping after current batch finishes" - def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: - ''' Updates information from the active iteration ''' - session_info, job_info = self._get_call_info(func_key, session_key) - if job_info is None: - return [None, f"Session {session_key} was not running function {func_key}"] - - job_info.refresh_active_image_requested.set() - if job_info.refresh_active_image_done.wait(timeout=20.0): - job_info.refresh_active_image_done.clear() - return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"] - return [gr.Image.update(visible=False), "Timed out getting image"] - - def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: - ''' Marks that the active iteration should be stopped''' - session_info, job_info = self._get_call_info(func_key, session_key) - if job_info is None: - return [None, f"Session {session_key} was not running function {func_key}"] - job_info.stop_cur_iter.set() - return [gr.Image.update(visible=False), "Stopping current iteration"] - def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]: ''' Helper to get the SessionInfo and JobInfo. ''' session_info = self._sessions.get(session_key, None) @@ -263,8 +207,7 @@ def _run_queued_jobs(self) -> None: def _pre_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, - session_key: str) -> List[Component]: + status_text: gr.Textbox, session_key: str) -> List[Component]: ''' Called when a job is about to start ''' session_info, job_info = self._get_call_info(func_key, session_key) @@ -276,9 +219,7 @@ def _pre_call_func( return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="primary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has started. 
Click 'Refresh' to see finished images, 'View Batch Progress' for active images"), - active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value), - active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value), + status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates") } def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -292,7 +233,7 @@ def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]: except Exception as e: job_info.job_status = f"Error: {e}" print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}") - raise + outputs = [] # Filter the function output for any removed outputs filtered_output = [] @@ -313,16 +254,12 @@ def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]: def _post_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, - session_key: str) -> List[Component]: + status_text: gr.Textbox, session_key: str) -> List[Component]: ''' Called when a job completes ''' return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has finished!"), - active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value), - active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value), - active_image: gr.Image.update(visible=False) + status_text: gr.Textbox.update(value="Generation has finished!") } def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -338,15 +275,16 @@ def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Com return job_info.images - def _wrap_func(self, func: Callable, inputs: List[Component], - outputs: List[Component], - job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]: + def _wrap_func( + self, func: Callable, inputs: List[Component], outputs: List[Component], + refresh_btn: gr.Button = None, stop_btn: gr.Button = None, + status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]: ''' handles JobManageUI's wrap_func''' assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context" # Create a unique key for this job - func_key = FuncKey(job_id=uuid.uuid4().hex, func=func) + func_key = FuncKey(job_id=uuid.uuid4(), func=func) # Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/) if self._session_key is None: @@ -364,6 +302,9 @@ def _wrap_func(self, func: Callable, inputs: List[Component], del outputs[idx] break + # Add the session key to the inputs + inputs += [self._session_key] + # Create dummy objects update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject") update_gallery_obj.change( @@ -372,44 +313,20 @@ def _wrap_func(self, func: Callable, inputs: List[Component], [gallery_comp] ) - if job_ui._refresh_btn: - job_ui._refresh_btn.variant = 'secondary' - job_ui._refresh_btn.click( + if refresh_btn: + refresh_btn.variant = 'secondary' + refresh_btn.click( partial(self._refresh_func, func_key), [self._session_key], - [update_gallery_obj, job_ui._status_text] + [update_gallery_obj, status_text] ) - 
if job_ui._stop_btn: - job_ui._stop_btn.variant = 'secondary' - job_ui._stop_btn.click( + if stop_btn: + stop_btn.variant = 'secondary' + stop_btn.click( partial(self._stop_wrapped_func, func_key), [self._session_key], - [job_ui._status_text] - ) - - if job_ui._active_image and job_ui._active_image_refresh_btn: - job_ui._active_image_refresh_btn.click( - partial(self._refresh_cur_iter_func, func_key), - [self._session_key], - [job_ui._active_image, job_ui._status_text] - ) - - if job_ui._active_image_stop_btn: - job_ui._active_image_stop_btn.click( - partial(self._stop_cur_iter_func, func_key), - [self._session_key], - [job_ui._active_image, job_ui._status_text] - ) - - if job_ui._stop_all_session_btn: - job_ui._stop_all_session_btn.click( - self.stop_all_jobs, [], [] - ) - - if job_ui._free_done_sessions_btn: - job_ui._free_done_sessions_btn.click( - self.clear_all_finished_jobs, [], [] + [status_text] ) # (ab)use gr.JSON to forward events. @@ -426,8 +343,7 @@ def _wrap_func(self, func: Callable, inputs: List[Component], # Since some parameters are optional it makes sense to use the 'dict' return value type, which requires # the Component as a key... so group together the UI components that the event listeners are going to update # to make it easy to append to function calls and outputs - job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text, - job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn] + job_ui_params = [refresh_btn, stop_btn, status_text] job_ui_outputs = [comp for comp in job_ui_params if comp is not None] # Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call, @@ -453,39 +369,27 @@ def _wrap_func(self, func: Callable, inputs: List[Component], [call_dummyobj] + job_ui_outputs ) - # Add any components that we want the runtime values for - added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx, - job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr] - # Now replace the original function with one that creates a JobInfo and triggers the dummy obj - def wrapped_func(*wrapped_inputs): - # Remove the added_inputs (pop opposite order of list) - wrapped_inputs = list(wrapped_inputs) - rec_steps_interval: int = wrapped_inputs.pop() - save_rec_steps_file: bool = wrapped_inputs.pop() - save_rec_steps_grid: bool = wrapped_inputs.pop() - record_steps_enabled: bool = wrapped_inputs.pop() - session_key: str = wrapped_inputs.pop() - job_inputs = tuple(wrapped_inputs) + def wrapped_func(*inputs): + session_key = inputs[-1] + inputs = inputs[:-1] # Get or create a session for this key session_info = self._sessions.setdefault(session_key, SessionInfo()) # Is this session already running this job? 
if func_key in session_info.jobs: - return {job_ui._status_text: "This session is already running that function!"} + return {status_text: "This session is already running that function!"} job_token = self._get_job_token(block=False) - job = JobInfo( - inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, - job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval, - rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file) + job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, + job_token=job_token) session_info.jobs[func_key] = job ret = {pre_call_dummyobj: triggerChangeEvent()} if job_token is None: - ret[job_ui._status_text] = "Job is queued" + ret[status_text] = "Job is queued" return ret - return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text] + return wrapped_func, inputs, [pre_call_dummyobj, status_text] diff --git a/frontend/ui_functions.py b/frontend/ui_functions.py index a85d15432..cebe34e0c 100644 --- a/frontend/ui_functions.py +++ b/frontend/ui_functions.py @@ -6,33 +6,17 @@ import re -def change_image_editor_mode(choice, cropped_image, resize_mode, width, height): +def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height): if choice == "Mask": - return [gr.Image.update(visible=False), - gr.Image.update(visible=True), - gr.Button.update("Generate", variant="primary", visible=False), - gr.Button.update("Generate", variant="primary", visible=True), - gr.Button.update("Advanced Editor", visible=False), - gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"], - label="Mask Mode", - value="Regenerate only masked area", visible=True), - gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=True), - gr.Image.update(interactive=True)] - else: - return [gr.Image.update(visible=True), - gr.Image.update(visible=False), - gr.Button.update("Generate", variant="primary", visible=True), - gr.Button.update("Generate", variant="primary", visible=False), - gr.Button.update("Advanced Editor", visible=True), - gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"], - label="Mask Mode", - value="Regenerate only masked area", visible=False), - gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? 
(to avoid hard edges)", value=3, visible=False), - gr.Image.update(interactive=False)] + update_image_result = update_image_mask(cropped_image, resize_mode, width, height) + return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] + + update_image_result = update_image_mask(masked_image["image"], resize_mode, width, height) + return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] def update_image_mask(cropped_image, resize_mode, width, height): resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None - return gr.Image.update(value=resized_cropped_image) + return gr.update(value=resized_cropped_image, visible=True) def toggle_options_gfpgan(selection): if 0 in selection: diff --git a/scripts/webui.py b/scripts/webui.py index 529abb2df..e65de0436 100644 --- a/scripts/webui.py +++ b/scripts/webui.py @@ -1,5 +1,7 @@ import argparse, os, sys, glob, re +import cv2 + from frontend.frontend import draw_gradio_ui from frontend.job_manager import JobManager, JobInfo from frontend.ui_functions import resize_image @@ -37,11 +39,9 @@ parser.add_argument("--share-password", type=str, help="Sharing is open by default, use this to set a password. Username: webui", default=None) parser.add_argument("--share", action='store_true', help="Should share your server on gradio.app, this allows you to use the UI from your mobile app", default=False) parser.add_argument("--skip-grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", default=False) -parser.add_argument("--save-each", action='store_true', help="save individual samples. For speed measurements.", default=False) +parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. 
For speed measurements.", default=False) parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False) parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1) -parser.add_argument("--custom-css", action='store_true', help="Place custom.css in css folder to load a custom theme of the UI", default=False) - opt = parser.parse_args() #Should not be needed anymore @@ -66,12 +66,9 @@ import torch.nn as nn import yaml import glob -import copy -from typing import List, Union, Dict, Callable, Any +from typing import List, Union, Dict from pathlib import Path from collections import namedtuple -import cv2 -from functools import partial from contextlib import contextmanager, nullcontext from einops import rearrange, repeat @@ -109,7 +106,6 @@ GFPGAN_dir = opt.gfpgan_dir RealESRGAN_dir = opt.realesrgan_dir LDSR_dir = opt.ldsr_dir -returned_info = {} if opt.optimized_turbo: opt.optimized = True @@ -140,13 +136,6 @@ grid_quality = abs(grid_quality) -def toImgOpenCV(imgPIL): # Conver imgPIL to imgOpenCV - i = np.array(imgPIL) # After mapping from PIL to numpy : [R,G,B,A] - # numpy Image Channel system: [B,G,R,A] - red = i[:,:,0].copy(); i[:,:,0] = i[:,:,2].copy(); i[:,:,2] = red - return i -def toImgPIL(imgOpenCV): return Image.fromarray(cv2.cvtColor(imgOpenCV, cv2.COLOR_BGR2RGB)) - def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -275,20 +264,14 @@ def __init__(self, m, sampler): self.schedule = sampler def get_sampler_name(self): return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ): + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): sigmas = self.model_wrap.get_sigmas(S) x = x_T * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) - return samples_ddim, None + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) - @classmethod - def img_callback_wrapper(cls, callback: Callable, *args): - ''' Converts a KDiffusion callback to the standard img_callback ''' - if callback: - arg_dict = args[0] - callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i']) + return samples_ddim, None def create_random_tensors(shape, seeds): @@ -524,7 +507,6 @@ def seed_to_int(s): n = n >> 32 return n - def draw_prompt_matrix(im, width, height, all_prompts): def wrap(text, d, font, line_length): lines = [''] @@ -590,63 +572,6 @@ def draw_texts(pos, x, y, texts, sizes): return result -def round_to_multiple(dimension, dimension_ceiling, multiple=64, round_down=True): - if round_down: - rounded_dimension = multiple * math.ceil(dimension / multiple) - else: - rounded_dimension = multiple * math.floor(dimension / multiple) - return rounded_dimension - - -def crop_image(img, mask, width, height): - def get_mask_and_img(img, mask,dimension, coords, target_width, target_height): - 
longest_target_dimension = round_to_multiple(dimension, dimension) - func_crop_coords = (coords[0], coords[1], coords[0]+longest_target_dimension, coords[1]+longest_target_dimension) - resized_img = img.crop(func_crop_coords) - scale_dimension = target_width if target_width > target_height else target_height - resized_img = resized_img.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS) - - resized_mask = mask.crop(func_crop_coords) - cropped_img_width, cropped_img_height = resized_mask.size - resized_mask = resized_mask.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS) - - alpha_mask = resized_mask.convert("RGBA") - mask_data = alpha_mask.getdata() - container = [] - for item in mask_data: - if item[0] == 0 and item[1] == 0 and item[2] == 0: - container.append((255, 255, 255, 0)) - else: - container.append(item) - alpha_mask.putdata(container) - - results = { - "cropped_img": resized_img, - "org_img": rgb_image, - "cropped_mask": alpha_mask, - "coords": crop_coords, - "scale_width": width, - "scale_height": height, - "org_width": cropped_img_width, - "org_height": cropped_img_height - } - return results - - rgb_image = img.convert("RGB") - rgb_mask = mask.convert("RGB") - np_mask = np.array(rgb_mask) - white_columns = np.where(np_mask.max(axis=0)>= 255)[0] - white_rows = np.where(np_mask.max(axis=1)>= 255)[0] - crop_coords = (min(white_columns), min(white_rows), max(white_columns), max(white_rows)) - crop_to_size = rgb_image.crop(crop_coords) - cropped_img_width, cropped_img_height = crop_to_size.size - - if cropped_img_width > cropped_img_height: - results_dict = get_mask_and_img(rgb_image, mask, cropped_img_width, crop_coords, width, height) - else: - results_dict = get_mask_and_img(rgb_image, mask, cropped_img_height, crop_coords, width, height) - - return results_dict def check_prompt_length(prompt, comments): @@ -668,8 +593,8 @@ def check_prompt_length(prompt, comments): comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True): +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True): filename_i = os.path.join(sample_path_i, filename) if not jpg_sample: if opt.save_metadata and not skip_metadata: @@ -702,7 +627,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt toggles.append(2) if uses_random_seed_loopback: toggles.append(3) - if save_each: + if not skip_save: toggles.append(2 + offset) if not skip_grid: toggles.append(3 + offset) @@ -852,12 +777,12 @@ def classToArrays( items, seed, n_iter ): def process_images( - outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, save_each, batch_size, + outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, 
realesrgan_model_name, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False, - variant_amount=0.0, variant_seed=None,imgProcessorTask=False,resize_mask=False, job_info: JobInfo = None): + variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" prompt = prompt or '' torch_gc() @@ -956,7 +881,6 @@ def process_images( if job_info: job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}" - job_info.rec_steps_imgs.clear() for idx,(p,s) in enumerate(zip(prompts,seeds)): job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}" @@ -987,7 +911,7 @@ def process_images( while(torch.cuda.memory_allocated()/1e6 >= mem): time.sleep(1) - cur_variant_amount = variant_amount + cur_variant_amount = variant_amount if variant_amount == 0.0: # we manually generate all input noises because each one should have a specific seed x = create_random_tensors(shape, seeds=seeds) @@ -1010,78 +934,17 @@ def process_images( # finally, slerp base_x noise to target_x noise for creating a variant x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x) - - # If in optimized mode then make a CPU-copy of the model to generate preview images - if opt.optimized: - step_preview_model = copy.deepcopy(modelFS).to("cpu") - if not opt.no_half: - step_preview_model.float() - else: - step_preview_model = model - - def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): - ''' Called from the sampler every iteration ''' - if job_info: - job_info.active_iteration_cnt = iter_num - record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl) - if record_periodic_image or job_info.refresh_active_image_requested.is_set(): - preview_start_time = time.time() - if opt.optimized: - image_sample = image_sample.to("cpu") - - batch_ddim = step_preview_model.decode_first_stage(image_sample) - batch_ddim = torch.clamp((batch_ddim + 1.0) / 2.0, min=0.0, max=1.0) - preview_elapsed_timed = time.time() - preview_start_time - - if preview_elapsed_timed > 1: - print( - f"Warning: Preview generation is slow! It took {preview_elapsed_timed:.2f}s to generate one preview!") - - images: List[Image.Image] = [] - # Convert tensor to image (copied from code below) - for ddim in batch_ddim: - x_sample = 255. 
* rearrange(ddim.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - images.append(image) - - caption = f"Iter {iter_num}" - grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images)) - - # Save the images if recording steps, and append existing saved steps - if job_info.rec_steps_enabled: - gallery_img_size = tuple( int(0.25*dim) for dim in images[0].size) - job_info.rec_steps_imgs.append(grid.resize(gallery_img_size)) - - # Notify the requester that the image is updated - if job_info.refresh_active_image_requested.is_set(): - if job_info.rec_steps_enabled: - grid = image_grid(job_info.rec_steps_imgs, 1) - job_info.active_image = grid - job_info.refresh_active_image_done.set() - job_info.refresh_active_image_requested.clear() - - # Interrupt current iteration? - if job_info.stop_cur_iter.is_set(): - job_info.stop_cur_iter.clear() - raise StopIteration() - - try: - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback) - except StopIteration: - print("Skipping iteration") - job_info.job_status = "Skipping iteration" - continue + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) if opt.optimized: modelFS.to(device) + x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for i, x_sample in enumerate(x_samples_ddim): sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars}) - sanitized_prompt = sanitized_prompt.lower() if variant_seed != None and variant_seed != '': if variant_amount == 0.0: seed_used = f"{current_seeds[i]}-{variant_seed}" @@ -1106,17 +969,6 @@ def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): image = Image.fromarray(x_sample) original_sample = x_sample original_filename = filename - - if resize_mask: - scaled_img = image.resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGB") - scaled_mask = returned_info["cropped_mask"].resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGBA") - scaled_mask = scaled_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - returned_info["org_img"].paste(scaled_img, (returned_info["coords"][0], returned_info["coords"][1]), mask=scaled_mask) - image = returned_info["org_img"].copy() - original_sample = np.asarray(image).astype(np.uint8) - #returned_info["org_img"].save(sample_path_i+"\\"+filename+" test.png", format="PNG") - - if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN: skip_save = True # #287 >_> torch_gc() @@ -1124,12 +976,10 @@ def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): gfpgan_sample = restored_img[:,:,::-1] gfpgan_image = Image.fromarray(gfpgan_sample) gfpgan_filename = original_filename + '-gfpgan' - if save_each: - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) 
+ save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(gfpgan_image) #287 - # save_each = True # #287 >_> #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan" ) @@ -1140,30 +990,26 @@ def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): esrgan_filename = original_filename + '-esrgan4x' esrgan_sample = output[:,:,::-1] esrgan_image = Image.fromarray(esrgan_sample) - if save_each: - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(esrgan_image) #287 - # save_each = False # #287 >_> #if simple_templating: # grid_captions.append( captions[i] + "\nesrgan" ) if use_RealESRGAN and RealESRGAN is not None and use_GFPGAN and GFPGAN is not None: skip_save = True # #287 >_> torch_gc() - cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) gfpgan_sample = restored_img[:,:,::-1] output, img_mode = RealESRGAN.enhance(gfpgan_sample[:,:,::-1]) gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' gfpgan_esrgan_sample = output[:,:,::-1] gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - if save_each: - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(gfpgan_esrgan_image) #287 - # save_each = False # #287 >_> #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) @@ -1171,30 +1017,15 @@ def 
sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): if imgProcessorTask == True: output_images.append(image) - - if save_each: + if not skip_save: save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) if add_original_image or not simple_templating: output_images.append(image) if simple_templating: grid_captions.append( captions[i] ) - # Save the progress images? - if job_info: - if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery): - steps_grid = image_grid(job_info.rec_steps_imgs, 1) - if job_info.rec_steps_to_gallery: - gallery_img_size = tuple(2*dim for dim in image.size) - output_images.append( steps_grid.resize( gallery_img_size ) ) - if job_info.rec_steps_to_file: - steps_grid_filename = f"{original_filename}_step_grid" - save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) - - if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1263,7 +1094,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], seed = seed_to_int(seed) prompt_matrix = 0 in toggles normalize_prompt_weights = 1 in toggles - save_each = 2 in toggles + skip_save = 2 not in toggles skip_grid = 3 not in toggles sort_samples = 4 in toggles write_info_files = 5 in toggles @@ -1302,8 +1133,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], def init(): pass - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback) + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x) return samples_ddim try: @@ -1314,7 +1145,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, prompt=prompt, seed=seed, sampler_name=sampler_name, - save_each=save_each, + skip_save=skip_save, skip_grid=skip_grid, batch_size=batch_size, n_iter=n_iter, @@ -1393,9 +1224,14 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None): print("Logged:", filenames[0]) -def img2img(prompt: str, image_editor_mode: str, 
init_info: any, init_info_mask: any, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, +def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float, - seed: int, height: int, width: int, resize_mode: int, fp = None, job_info: JobInfo = None): + seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None): + print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode, + mask_blur_strength, ddim_steps, sampler_name, toggles, + realesrgan_model_name, n_iter, cfg_scale, + denoising_strength, seed, height, width, resize_mode, + fp]) outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples" err = False seed = seed_to_int(seed) @@ -1406,7 +1242,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask: normalize_prompt_weights = 1 in toggles loopback = 2 in toggles random_seed_loopback = 3 in toggles - save_each = 4 in toggles + skip_save = 4 not in toggles skip_grid = 5 not in toggles sort_samples = 6 in toggles write_info_files = 7 in toggles @@ -1441,44 +1277,35 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask: raise Exception("Unknown sampler: " + sampler_name) if image_editor_mode == 'Mask': - global returned_info init_img = init_info_mask["image"] init_img = init_img.convert("RGB") init_img = resize_image(resize_mode, init_img, width, height) - image = image.convert("RGB") + init_img = init_img.convert("RGB") init_mask = init_info_mask["mask"] + init_mask = init_mask.convert("RGB") init_mask = resize_image(resize_mode, init_mask, width, height) - resize_mask = mask_mode == 2 - - if resize_mask: - returned_info = crop_image(init_img, init_mask, width, height) - init_img = returned_info["cropped_img"] - init_mask = returned_info["cropped_mask"] - - keep_mask = mask_mode == 0 init_mask = init_mask.convert("RGB") + keep_mask = mask_mode == 0 init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) else: - init_img = init_info.convert("RGB") + init_img = init_info init_mask = None keep_mask = False - resize_mask = False assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' t_enc = int(denoising_strength * ddim_steps) def init(): image = init_img.convert("RGB") - if resize_mask: - image = resize_image(resize_mode, image, width, height) - #image = image.convert("RGB") #todo: mask mode -> ValueError: could not convert string to float: + image = resize_image(resize_mode, image, width, height) + #image = image.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask_channel = None if image_editor_mode == "Uncrop": - alpha = init_img.convert("RGB") + alpha = init_img.convert("RGBA") alpha = resize_image(resize_mode, alpha, width // 8, height // 8) mask_channel = alpha.split()[-1] mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4)) @@ -1486,7 +1313,7 @@ def init(): mask_channel[mask_channel >= 255] = 255 mask_channel[mask_channel < 255] = 0 mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2)) - elif init_mask is not None: + elif image_editor_mode == "Mask": alpha = init_mask.convert("RGBA") alpha = resize_image(resize_mode, alpha, width // 8, height // 8) mask_channel = alpha.split()[1] @@ -1505,7 +1332,7 @@ def init(): init_image = init_image.to(device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - + if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1514,7 +1341,7 @@ def init(): return init_latent, mask, - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): t_enc_steps = t_enc obliterate = False if ddim_steps == t_enc_steps: @@ -1536,7 +1363,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) else: x0, z_mask = init_data @@ -1563,7 +1390,17 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, history = [] initial_seed = None + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + print("Install scikit-image to perform color correction on loopback") + for i in range(n_iter): + if do_color_correction and i == 0: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + output_images, seed, info, stats = process_images( outpath=outpath, func_init=init, @@ -1571,7 +1408,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, prompt=prompt, seed=seed, sampler_name=sampler_name, - save_each=save_each, + skip_save=skip_save, skip_grid=skip_grid, 
batch_size=1, n_iter=1, @@ -1605,6 +1442,17 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, initial_seed = seed init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + if not random_seed_loopback: seed = seed + 1 else: @@ -1630,7 +1478,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, prompt=prompt, seed=seed, sampler_name=sampler_name, - save_each=save_each, + skip_save=skip_save, skip_grid=skip_grid, batch_size=batch_size, n_iter=n_iter, @@ -1655,7 +1503,6 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, write_info_files=write_info_files, write_sample_info_to_log_file=write_sample_info_to_log_file, jpg_sample=jpg_sample, - resize_mask=resize_mask, job_info=job_info ) @@ -1723,10 +1570,9 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to output = [] images = [] def processGFPGAN(image,strength): - cvimage = toImgOpenCV(image) - cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) - #save restored image - result = toImgPIL(restored_img) + image = image.convert("RGB") + cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) + result = Image.fromarray(restored_img) if strength < 1.0: result = Image.blend(image, result, strength) @@ -1764,7 +1610,7 @@ def processGoBig(image): height = int(imgproc_height) cfg_scale = float(imgproc_cfg) denoising_strength = float(imgproc_denoising) - save_each = True + skip_save = True skip_grid = True prompt = imgproc_prompt t_enc = int(denoising_strength * ddim_steps) @@ -1918,7 +1764,7 @@ def make_mask_image(r): prompt=prompt, seed=seed, sampler_name=sampler_name, - save_each=save_each, + skip_save=skip_save, skip_grid=skip_grid, batch_size=batch_size, n_iter=n_iter, @@ -1964,9 +1810,8 @@ def make_mask_image(r): return combined_image def processLDSR(image): result = LDSR.superResolution(image,int(imgproc_ldsr_steps),str(imgproc_ldsr_pre_downSample),str(imgproc_ldsr_post_downSample)) - return result - - + return result + if image_batch != None: if image != None: @@ -1993,7 +1838,7 @@ def processLDSR(image): if 1 in imgproc_toggles: if imgproc_upscale_toggles == 0: ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models - ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models + ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models elif imgproc_upscale_toggles == 1: ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models ModelLoader(['RealESGAN','model'],True,False) # Load used models @@ -2106,14 +1951,15 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re def run_GFPGAN(image, strength): ModelLoader(['LDSR','RealESRGAN'],False,True) ModelLoader(['GFPGAN'],True,False) - cvimage = toImgOpenCV(image) - cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) - #save restored image - result = toImgPIL(restored_img) + image = image.convert("RGB") + + cropped_faces, 
restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) + res = Image.fromarray(restored_img) + if strength < 1.0: - result = Image.blend(image, result, strength) + res = Image.blend(image, res, strength) - return result + return res def run_RealESRGAN(image, model_name: str): ModelLoader(['GFPGAN','LDSR'],False,True) @@ -2195,9 +2041,9 @@ def run_RealESRGAN(image, model_name: str): 'Upscale' ] -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - +#sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +#sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None +sample_img2img = None # make sure these indicies line up at the top of img2img() img2img_toggles = [ 'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', @@ -2226,7 +2072,6 @@ def run_RealESRGAN(image, model_name: str): "Just resize", "Crop and resize", "Resize and fill", - "Resize Masked Area" ] img2img_defaults = { @@ -2262,22 +2107,13 @@ def update_image_mask(cropped_image, resize_mode, width, height): resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None return gr.update(value=resized_cropped_image) -def copy_img_to_input(img): - try: - image_data = re.sub('^data:image/.+;base64,', '', img) - processed_image = Image.open(BytesIO(base64.b64decode(image_data))) - tab_update = gr.update(selected='img2img_tab') - img_update = gr.update(value=processed_image) - return {img2img_image_mask: processed_image, img2img_image_editor: img_update, tabs: tab_update} - except IndexError: - return [None, None] def copy_img_to_upscale_esrgan(img): update = gr.update(selected='realesrgan_tab') image_data = re.sub('^data:image/.+;base64,', '', img) processed_image = Image.open(BytesIO(base64.b64decode(image_data))) - return {realesrgan_source: processed_image, tabs: update} + return {'realesrgan_source': processed_image, 'tabs': update} help_text = """ @@ -2341,7 +2177,7 @@ def run(self): 'inbrowser': opt.inbrowser, 'server_name': '0.0.0.0', 'server_port': opt.port, - 'share': opt.share, + 'share': opt.share, 'show_error': True } if not opt.share:
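
Note on the job_manager.py changes above: the per-job progress-recording inputs are dropped, and only the hidden session key is appended to the wrapped function's gradio inputs, then popped back off at call time. The pattern, reduced to plain Python with illustrative names (this is a sketch of the idea, not the module's own API):

def wrap_with_session_key(func, inputs, session_key_component):
    # The hidden session-key component rides along as the last gradio input...
    inputs = inputs + [session_key_component]

    def wrapped_func(*values):
        # ...and is peeled off again before the real generate function runs.
        session_key = values[-1]
        job_inputs = values[:-1]
        return func(*job_inputs)

    return wrapped_func, inputs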
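
Note on the simplified mask handling in img2img: with the "Resize and regenerate only masked area" mode removed, only two mask modes remain, and the mask is inverted for the second one. Reduced to a standalone sketch (prepare_mask is an illustrative name, assuming PIL; mode 0 is "Keep masked area", mode 1 is "Regenerate only masked area"):

from PIL import Image, ImageOps

def prepare_mask(mask: Image.Image, mask_mode: int) -> Image.Image:
    # Normalize to RGB before any channel math, mirroring the patched code.
    mask = mask.convert("RGB")
    keep_mask = mask_mode == 0
    # Keeping the masked area means everything else regenerates, so the
    # "regenerate only masked area" mode uses the inverted mask instead.
    return mask if keep_mask else ImageOps.invert(mask)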
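
Note on the loopback path: it gains optional color correction. On the first iteration the init image is captured in LAB space, and each subsequent loopback input is histogram-matched against that target so colors do not drift across iterations. As a standalone sketch (assuming OpenCV and scikit-image >= 0.19 for the channel_axis argument; match_color is an illustrative name, not part of the codebase):

import cv2
import numpy as np
from PIL import Image
from skimage import exposure

def match_color(result: Image.Image, correction_target: np.ndarray) -> Image.Image:
    # correction_target is the LAB-space snapshot of the original init image:
    #   cv2.cvtColor(np.asarray(init_img), cv2.COLOR_RGB2LAB)
    result_lab = cv2.cvtColor(np.asarray(result), cv2.COLOR_RGB2LAB)
    matched = exposure.match_histograms(result_lab, correction_target, channel_axis=2)
    # Cast back to uint8 before converting LAB -> RGB for PIL.
    return Image.fromarray(cv2.cvtColor(matched.astype("uint8"), cv2.COLOR_LAB2RGB))

If scikit-image is missing, the patched code prints a warning and skips the correction rather than failing.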