Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CLIP Skip #236

Merged
merged 1 commit into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
action="store_true",
help="Use offline model",
)
parser.add_argument(
"--clip_skip",
type=int,
help="CLIP Skip (1-12), default : 1 (disabled) ",
default=1,
)
parser.add_argument(
"--use_safety_checker",
action="store_true",
Expand Down Expand Up @@ -334,6 +340,7 @@
else:
config.lcm_diffusion_setting.use_seed = False
config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
config.lcm_diffusion_setting.clip_skip = args.clip_skip
config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker

# Read custom settings from JSON file
Expand Down
8 changes: 8 additions & 0 deletions src/backend/lcm_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ def generate(
if self.is_openvino_init:
self.is_openvino_init = False

pipeline_extra_args = {}
if lcm_diffusion_setting.clip_skip > 1:
# We follow the convention that "CLIP Skip == 2" means "skip
# the last layer", so "CLIP Skip == 1" means "no skipping"
pipeline_extra_args['clip_skip'] = lcm_diffusion_setting.clip_skip - 1

if not lcm_diffusion_setting.use_safety_checker:
self.pipeline.safety_checker = None
if (
Expand Down Expand Up @@ -369,6 +375,7 @@ def generate(
height=lcm_diffusion_setting.image_height,
num_images_per_prompt=lcm_diffusion_setting.number_of_images,
timesteps=self._get_timesteps(),
**pipeline_extra_args,
**controlnet_args,
).images

Expand All @@ -386,6 +393,7 @@ def generate(
width=lcm_diffusion_setting.image_width,
height=lcm_diffusion_setting.image_height,
num_images_per_prompt=lcm_diffusion_setting.number_of_images,
**pipeline_extra_args,
**controlnet_args,
).images
return result_images
1 change: 1 addition & 0 deletions src/backend/models/lcmdiffusion_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class LCMDiffusionSetting(BaseModel):
image_width: Optional[int] = 512
inference_steps: Optional[int] = 1
guidance_scale: Optional[float] = 1
clip_skip: Optional[int] = 1
number_of_images: Optional[int] = 1
seed: Optional[int] = 123123
use_seed: bool = False
Expand Down
17 changes: 17 additions & 0 deletions src/frontend/gui/app_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def init_ui_values(self):
self.inference_steps.setValue(
int(self.config.settings.lcm_diffusion_setting.inference_steps)
)
self.clip_skip.setValue(
int(self.config.settings.lcm_diffusion_setting.clip_skip)
)
self.seed_check.setChecked(self.config.settings.lcm_diffusion_setting.use_seed)
self.seed_value.setText(str(self.config.settings.lcm_diffusion_setting.seed))
self.use_local_model_folder.setChecked(
Expand Down Expand Up @@ -246,6 +249,13 @@ def create_settings_tab(self):
self.guidance.setValue(10)
self.guidance.valueChanged.connect(self.update_guidance_label)

self.clip_skip_value = QLabel("CLIP Skip: 1")
self.clip_skip = QSlider(orientation=Qt.Orientation.Horizontal)
self.clip_skip.setMaximum(12)
self.clip_skip.setMinimum(1)
self.clip_skip.setValue(1)
self.clip_skip.valueChanged.connect(self.update_clip_skip_label)

self.width_value = QLabel("Width :")
self.width = QComboBox(self)
self.width.addItem("256")
Expand Down Expand Up @@ -340,6 +350,8 @@ def create_settings_tab(self):
vlayout.addWidget(self.height)
vlayout.addWidget(self.guidance_value)
vlayout.addWidget(self.guidance)
vlayout.addWidget(self.clip_skip_value)
vlayout.addWidget(self.clip_skip)
vlayout.addLayout(hlayout)
vlayout.addWidget(self.safety_checker)

Expand Down Expand Up @@ -491,6 +503,10 @@ def use_lcm_lora_changed(self, state):
self.neg_prompt.setEnabled(False)
self.config.settings.lcm_diffusion_setting.use_lcm_lora = False

def update_clip_skip_label(self, value):
self.clip_skip_value.setText(f"CLIP Skip: {value}")
self.config.settings.lcm_diffusion_setting.clip_skip = value

def use_safety_checker_changed(self, state):
if state == 2:
self.config.settings.lcm_diffusion_setting.use_safety_checker = True
Expand Down Expand Up @@ -604,6 +620,7 @@ def reset_all_settings(self):
self.height.setCurrentText("512")
self.inference_steps.setValue(4)
self.guidance.setValue(10)
self.clip_skip.setValue(1)
self.use_openvino_check.setChecked(False)
self.seed_check.setChecked(False)
self.safety_checker.setChecked(False)
Expand Down
14 changes: 14 additions & 0 deletions src/frontend/webui/generation_settings_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def on_change_guidance_scale(guidance_scale):
app_settings.settings.lcm_diffusion_setting.guidance_scale = guidance_scale


def on_change_clip_skip(clip_skip):
app_settings.settings.lcm_diffusion_setting.clip_skip = clip_skip


def on_change_seed_value(seed):
app_settings.settings.lcm_diffusion_setting.seed = seed

Expand Down Expand Up @@ -103,6 +107,15 @@ def get_generation_settings_ui() -> None:
label="Guidance Scale",
interactive=True,
)
clip_skip = gr.Slider(
1,
12,
value=app_settings.settings.lcm_diffusion_setting.clip_skip,
step=1,
label="CLIP Skip",
interactive=True,
)


seed = gr.Slider(
value=app_settings.settings.lcm_diffusion_setting.seed,
Expand Down Expand Up @@ -145,6 +158,7 @@ def get_generation_settings_ui() -> None:
image_width.change(on_change_image_width, image_width)
num_images.change(on_change_num_images, num_images)
guidance_scale.change(on_change_guidance_scale, guidance_scale)
clip_skip.change(on_change_clip_skip, clip_skip)
seed.change(on_change_seed_value, seed)
seed_checkbox.change(on_change_seed_checkbox, seed_checkbox)
safety_checker_checkbox.change(
Expand Down