diff --git a/extras/censor.py b/extras/censor.py index 2047db246..80d019eab 100644 --- a/extras/censor.py +++ b/extras/censor.py @@ -1,56 +1,58 @@ -# modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py -import numpy as np import os -from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker -from transformers import CLIPFeatureExtractor, CLIPConfig -from PIL import Image +import numpy as np +import torch +from transformers import CLIPConfig, CLIPImageProcessor + +import ldm_patched.modules.model_management as model_management import modules.config +from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker +from ldm_patched.modules.model_patcher import ModelPatcher safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker') config_path = os.path.join(safety_checker_repo_root, "configs", "config.json") preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json") -safety_feature_extractor = None -safety_checker = None - - -def numpy_to_pil(image): - image = (image * 255).round().astype("uint8") - pil_image = Image.fromarray(image) - - return pil_image - -# check and replace nsfw content -def check_safety(x_image): - global safety_feature_extractor, safety_checker +class Censor: + def __init__(self): + self.safety_checker_model: ModelPatcher | None = None + self.clip_image_processor: CLIPImageProcessor | None = None + self.load_device = torch.device('cpu') + self.offload_device = torch.device('cpu') - if safety_feature_extractor is None or safety_checker is None: - safety_checker_model = modules.config.downloading_safety_checker_model() - safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path) - clip_config = CLIPConfig.from_json_file(config_path) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config) + def init(self): + if self.safety_checker_model is None and self.clip_image_processor is None: + safety_checker_model = modules.config.downloading_safety_checker_model() + self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path) + clip_config = CLIPConfig.from_json_file(config_path) + model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config) + model.eval() - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + self.load_device = model_management.text_encoder_device() + self.offload_device = model_management.text_encoder_offload_device() - return x_checked_image, has_nsfw_concept + model.to(self.offload_device) + self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) -def censor_single(x): - x_checked_image, has_nsfw_concept = check_safety(x) + def censor(self, images: list | np.ndarray) -> list | np.ndarray: + self.init() + model_management.load_model_gpu(self.safety_checker_model) - # replace image with black pixels, keep dimensions - # workaround due to different numpy / pytorch image matrix format - if has_nsfw_concept[0]: - imageshape = x_checked_image.shape - x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8) + single = False + if not isinstance(images, list) or isinstance(images, np.ndarray): + images = [images] + single = True - return x_checked_image + safety_checker_input = self.clip_image_processor(images, return_tensors="pt") + safety_checker_input.to(device=self.load_device) + checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images, + clip_input=safety_checker_input.pixel_values) + if single: + checked_images = checked_images[0] + return checked_images -def censor_batch(images): - images = [censor_single(image) for image in images] - return images \ No newline at end of file +default_censor = Censor().censor diff --git a/modules/async_worker.py b/modules/async_worker.py index 892f99a7e..042ba49eb 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -44,7 +44,7 @@ def worker(): import fooocus_version import args_manager - from extras.censor import censor_batch, censor_single + from extras.censor import default_censor from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name from modules.private_logger import log from extras.expansion import safe_str @@ -78,7 +78,7 @@ def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_fini if censor and (modules.config.default_black_out_nsfw or black_out_nsfw): progressbar(async_task, progressbar_index, 'Checking for NSFW content ...') - imgs = censor_batch(imgs) + imgs = default_censor(imgs) async_task.results = async_task.results + imgs @@ -615,7 +615,7 @@ def handler(async_task): d = [('Upscale (Fast)', 'upscale_fast', '2x')] if modules.config.default_black_out_nsfw or black_out_nsfw: progressbar(async_task, 100, 'Checking for NSFW content ...') - uov_input_image = censor_single(uov_input_image) + uov_input_image = default_censor(uov_input_image) uov_input_image_path = log(uov_input_image, d, output_format=output_format) yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True) return @@ -883,12 +883,12 @@ def callback(step, x0, x, total_steps, y): imgs = [inpaint_worker.current_task.post_process(x) for x in imgs] img_paths = [] - + current_progress = int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)) if modules.config.default_black_out_nsfw or black_out_nsfw: - progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)), - 'Checking for NSFW content ...') - imgs = censor_batch(imgs) + progressbar(async_task, current_progress, 'Checking for NSFW content ...') + imgs = default_censor(imgs) + progressbar(async_task, current_progress, 'Saving image to system ...') for x in imgs: d = [('Prompt', 'prompt', task['log_positive_prompt']), ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),