feat: optimize model management of nsfw image censoring #2960

Merged
extras/censor.py: 76 changes (39 additions, 37 deletions)
@@ -1,56 +1,58 @@
 # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
-import numpy as np
 import os
 
-from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
-from transformers import CLIPFeatureExtractor, CLIPConfig
-from PIL import Image
+import numpy as np
+import torch
+from transformers import CLIPConfig, CLIPImageProcessor
 
+import ldm_patched.modules.model_management as model_management
 import modules.config
+from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
+from ldm_patched.modules.model_patcher import ModelPatcher
 
 safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
 config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
 preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
 
-safety_feature_extractor = None
-safety_checker = None
-
-
-def numpy_to_pil(image):
-    image = (image * 255).round().astype("uint8")
-    pil_image = Image.fromarray(image)
-
-    return pil_image
-
-
-# check and replace nsfw content
-def check_safety(x_image):
-    global safety_feature_extractor, safety_checker
+class Censor:
+    def __init__(self):
+        self.safety_checker_model: ModelPatcher | None = None
+        self.clip_image_processor: CLIPImageProcessor | None = None
+        self.load_device = torch.device('cpu')
+        self.offload_device = torch.device('cpu')
 
-    if safety_feature_extractor is None or safety_checker is None:
-        safety_checker_model = modules.config.downloading_safety_checker_model()
-        safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path)
-        clip_config = CLIPConfig.from_json_file(config_path)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
+    def init(self):
+        if self.safety_checker_model is None and self.clip_image_processor is None:
+            safety_checker_model = modules.config.downloading_safety_checker_model()
+            self.clip_image_processor = CLIPImageProcessor.from_json_file(preprocessor_config_path)
+            clip_config = CLIPConfig.from_json_file(config_path)
+            model = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
+            model.eval()
 
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+            self.load_device = model_management.text_encoder_device()
+            self.offload_device = model_management.text_encoder_offload_device()
 
-    return x_checked_image, has_nsfw_concept
+            model.to(self.offload_device)
 
+            self.safety_checker_model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)
 
-def censor_single(x):
-    x_checked_image, has_nsfw_concept = check_safety(x)
+    def censor(self, images: list | np.ndarray) -> list | np.ndarray:
+        self.init()
+        model_management.load_model_gpu(self.safety_checker_model)
 
-    # replace image with black pixels, keep dimensions
-    # workaround due to different numpy / pytorch image matrix format
-    if has_nsfw_concept[0]:
-        imageshape = x_checked_image.shape
-        x_checked_image = np.zeros((imageshape[0], imageshape[1], 3), dtype = np.uint8)
+        single = False
+        if not isinstance(images, list) or isinstance(images, np.ndarray):
+            images = [images]
+            single = True
 
-    return x_checked_image
+        safety_checker_input = self.clip_image_processor(images, return_tensors="pt")
+        safety_checker_input.to(device=self.load_device)
+        checked_images, has_nsfw_concept = self.safety_checker_model.model(images=images,
+                                                                           clip_input=safety_checker_input.pixel_values)
+        if single:
+            checked_images = checked_images[0]
 
+        return checked_images
 
-def censor_batch(images):
-    images = [censor_single(image) for image in images]
 
-    return images
+default_censor = Censor().censor
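
For reviewers, a minimal usage sketch of the new interface, based only on the code above. It assumes a Fooocus checkout on the import path; the first call downloads the safety-checker weights via modules.config.downloading_safety_checker_model() and loads them through ldm_patched model management. The example image is synthetic.

import numpy as np

from extras.censor import default_censor

# A single HWC uint8 image in, a single checked image back; the bundled
# StableDiffusionSafetyChecker decides whether the pixels are replaced.
image = np.zeros((512, 512, 3), dtype=np.uint8)
checked = default_censor(image)

# A list of images is processed as one batch and returned as a list,
# replacing the old censor_batch() helper.
checked_batch = default_censor([image, image])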
modules/async_worker.py: 14 changes (7 additions, 7 deletions)
@@ -44,7 +44,7 @@ def worker():
     import fooocus_version
     import args_manager
 
-    from extras.censor import censor_batch, censor_single
+    from extras.censor import default_censor
     from modules.sdxl_styles import apply_style, get_random_style, fooocus_expansion, apply_arrays, random_style_name
     from modules.private_logger import log
     from extras.expansion import safe_str
@@ -78,7 +78,7 @@ def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_fini
 
         if censor and (modules.config.default_black_out_nsfw or black_out_nsfw):
            progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
-            imgs = censor_batch(imgs)
+            imgs = default_censor(imgs)
 
         async_task.results = async_task.results + imgs
 
@@ -615,7 +615,7 @@ def handler(async_task):
                 d = [('Upscale (Fast)', 'upscale_fast', '2x')]
                 if modules.config.default_black_out_nsfw or black_out_nsfw:
                     progressbar(async_task, 100, 'Checking for NSFW content ...')
-                    uov_input_image = censor_single(uov_input_image)
+                    uov_input_image = default_censor(uov_input_image)
                 uov_input_image_path = log(uov_input_image, d, output_format=output_format)
                 yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True)
                 return
@@ -883,12 +883,12 @@ def callback(step, x0, x, total_steps, y):
                     imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
 
                 img_paths = []
-
+                current_progress = int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps))
                 if modules.config.default_black_out_nsfw or black_out_nsfw:
-                    progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)),
-                                'Checking for NSFW content ...')
-                    imgs = censor_batch(imgs)
+                    progressbar(async_task, current_progress, 'Checking for NSFW content ...')
+                    imgs = default_censor(imgs)
 
                 progressbar(async_task, current_progress, 'Saving image to system ...')
                 for x in imgs:
                     d = [('Prompt', 'prompt', task['log_positive_prompt']),
                          ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
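
The point of these call-site changes is that the safety checker now goes through the same ldm_patched model management as the rest of the pipeline instead of living permanently in memory. A rough, hypothetical sketch of the pattern Censor.init() and Censor.censor() follow, with a stand-in torch module in place of the safety checker; ModelPatcher, load_model_gpu and the text-encoder device helpers are the APIs used in the diff, everything else here is illustrative.

import torch

import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher

# Stand-in for the eval-mode safety checker; any torch module can be wrapped.
model = torch.nn.Linear(4, 4).eval()

load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
model.to(offload_device)  # start on the offload device, as Censor.init() does

patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model_management.load_model_gpu(patcher)  # asks model management to place the model on load_device
output = patcher.model(torch.zeros(1, 4, device=load_device))  # run inference through the managed module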