Skip to content

Commit

Permalink
add img2img option for color correction. (#936)
Browse files Browse the repository at this point in the history
color correction is already used for loopback to prevent color drift with the first image as correction target.
the option allows to use the color correction even without loopback mode.
it helps keeping the colors similar to the input image.
  • Loading branch information
xaedes authored Sep 9, 2022
1 parent 32a3c05 commit 0706294
Showing 1 changed file with 49 additions and 24 deletions.
73 changes: 49 additions & 24 deletions scripts/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,15 +784,35 @@ def classToArrays( items, seed, n_iter ):

return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows


def perform_color_correction(img_rgb, correction_target_lab, do_color_correction):
try:
from skimage import exposure
except:
print("Install scikit-image to perform color correction")
return img_rgb

if not do_color_correction: return img_rgb
if correction_target_lab is None: return img_rgb

return (
Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(img_rgb),
cv2.COLOR_RGB2LAB
),
correction_target_lab,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8")
)
)

def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None, do_color_correction=False, correction_target=None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
prompt = prompt or ''
torch_gc()
Expand Down Expand Up @@ -991,6 +1011,7 @@ def process_images(
cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength,
GFPGAN=use_GFPGAN )
image = Image.fromarray(x_sample)
image = perform_color_correction(image, correction_target, do_color_correction)
ImageMetadata.set_on_image(image, metadata)

original_sample = x_sample
Expand All @@ -1001,6 +1022,7 @@ def process_images(
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
gfpgan_image = perform_color_correction(gfpgan_image, correction_target, do_color_correction)
gfpgan_metadata = copy.copy(metadata)
gfpgan_metadata.GFPGAN = True
ImageMetadata.set_on_image( gfpgan_image, gfpgan_metadata )
Expand All @@ -1018,6 +1040,7 @@ def process_images(
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
esrgan_image = perform_color_correction(esrgan_image, correction_target, do_color_correction)
ImageMetadata.set_on_image( esrgan_image, metadata )
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
Expand All @@ -1034,6 +1057,7 @@ def process_images(
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
gfpgan_esrgan_image = perform_color_correction(gfpgan_esrgan_image, correction_target, do_color_correction)
ImageMetadata.set_on_image(gfpgan_esrgan_image, metadata)
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
Expand Down Expand Up @@ -1149,6 +1173,10 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
jpg_sample = 7 in toggles
use_GFPGAN = 8 in toggles
use_RealESRGAN = 9 in toggles

do_color_correction = False
correction_target = None

ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
Expand Down Expand Up @@ -1214,6 +1242,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
variant_amount=variant_amount,
variant_seed=variant_seed,
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

del sampler
Expand Down Expand Up @@ -1303,8 +1333,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
write_info_files = 7 in toggles
write_sample_info_to_log_file = 8 in toggles
jpg_sample = 9 in toggles
use_GFPGAN = 10 in toggles
use_RealESRGAN = 11 in toggles
do_color_correction = 10 in toggles
use_GFPGAN = 11 in toggles
use_RealESRGAN = 12 in toggles
ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
Expand Down Expand Up @@ -1481,19 +1512,15 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
return samples_ddim



correction_target = None
if loopback:
output_images, info = None, None
history = []
initial_seed = None

do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")

# turn on color correction for loopback to prevent known issue of color drift
do_color_correction = True

for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
Expand Down Expand Up @@ -1533,24 +1560,16 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
job_info=job_info
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

if initial_seed is None:
initial_seed = seed

init_img = output_images[0]

if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))

if not random_seed_loopback:
seed = seed + 1
else:
Expand All @@ -1569,6 +1588,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
seed = initial_seed

else:
if do_color_correction:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)

output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
Expand Down Expand Up @@ -1602,7 +1624,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
job_info=job_info
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

del sampler
Expand Down Expand Up @@ -2181,6 +2205,7 @@ def run_RealESRGAN(image, model_name: str):
'Write sample info files',
'Write sample info to one file',
'jpg samples',
'Color correction (always enabled on loopback mode)'
]
# removed for now becuase of Image Lab implementation
if GFPGAN is not None:
Expand Down

0 comments on commit 0706294

Please sign in to comment.