Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mask_restore to restore images based on mask, fixing #665 #898

Merged
merged 8 commits into from
Sep 9, 2022
14 changes: 9 additions & 5 deletions frontend/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=img2img_mask_modes[img2img_defaults['mask_mode']],
visible=True)

img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image",
value=img2img_defaults['mask_restore'],
visible=True)

img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
label="How much blurry should the mask be? (to avoid hard edges)",
value=3, visible=False)
value=3, visible=True)

img2img_resize = gr.Radio(label="Resize mode",
choices=["Just resize", "Crop and resize",
Expand Down Expand Up @@ -294,7 +298,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
img2img_height
],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore]
)

# img2img_image_editor_mode.change(
Expand Down Expand Up @@ -335,8 +339,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
)

img2img_func = img2img
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_image_editor, img2img_image_mask, img2img_embeddings]
Expand Down
4 changes: 2 additions & 2 deletions frontend/ui_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
if choice == "Mask":
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]

update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]

def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
Expand Down
47 changes: 27 additions & 20 deletions scripts/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
Expand Down Expand Up @@ -1045,6 +1045,26 @@ def process_images(
if imgProcessorTask == True:
output_images.append(image)

if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')

if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')

output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')

image = Image.composite(init_img, image, init_mask)

if not skip_save:
save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
Expand Down Expand Up @@ -1259,7 +1279,7 @@ def blurArr(a,r=8):



def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
Expand Down Expand Up @@ -1504,6 +1524,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
Expand Down Expand Up @@ -1574,6 +1595,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=sort_samples,
Expand Down Expand Up @@ -1722,6 +1744,7 @@ def processGoBig(image):
init_img = result
init_mask = None
keep_mask = False
mask_restore = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

def init():
Expand Down Expand Up @@ -1868,6 +1891,7 @@ def make_mask_image(r):
keep_mask=False,
mask_blur_strength=None,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=False,
sort_samples=True,
Expand Down Expand Up @@ -2186,6 +2210,7 @@ def run_RealESRGAN(image, model_name: str):
'cfg_scale': 5.0,
'denoising_strength': 0.75,
'mask_mode': 0,
'mask_restore': False,
'resize_mode': 0,
'seed': '',
'height': 512,
Expand All @@ -2199,24 +2224,6 @@ def run_RealESRGAN(image, model_name: str):
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
img2img_image_mode = 'sketch'

def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
if choice == "Mask":
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]

def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
return gr.update(value=resized_cropped_image)



def copy_img_to_upscale_esrgan(img):
update = gr.update(selected='realesrgan_tab')
image_data = re.sub('^data:image/.+;base64,', '', img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
return {'realesrgan_source': processed_image, 'tabs': update}


help_text = """
## Mask/Crop
* The masking/cropping is very temperamental.
Expand Down
32 changes: 28 additions & 4 deletions scripts/webui_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
Expand Down Expand Up @@ -1156,7 +1156,29 @@ def process_images(

if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan_esrgan" )


if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')

if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)

output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')

output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')

image = Image.composite(init_img, image, init_mask)

if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
Expand Down Expand Up @@ -1257,7 +1279,7 @@ def resize_image(resize_mode, im, width, height):
return res

def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
ddim_steps: int = 50, sampler_name: str = 'DDIM',
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
Expand Down Expand Up @@ -1426,6 +1448,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
Expand Down Expand Up @@ -1486,8 +1509,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=2,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=group_by_prompt,
Expand Down