Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in webui when switching from quantized to non-quantized #830

Merged
merged 2 commits into from
Apr 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions onediff_sd_webui_extensions/scripts/onediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

"""oneflow_compiled UNetModel"""
compiled_unet = None
is_compiled_unet_quantized = False
compiled_ckpt_name = None

def generate_graph_path(ckpt_name: str, model_name: str) -> str:
Expand Down Expand Up @@ -74,7 +75,8 @@ def compile_unet(
RuntimeWarning,
)
compiled_unet = unet_model
if quantization:
# In OneDiff Community, quantization can be True when called by api
if quantization and varify_can_use_quantization():
calibrate_info = get_calibrate_info(
f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
)
Expand Down Expand Up @@ -142,7 +144,7 @@ def ui(self, is_img2img):
def show(self, is_img2img):
return True

def check_model_structure_change(self, model):
def check_model_change(self, model):
is_changed = False

def get_model_type(model):
Expand Down Expand Up @@ -170,23 +172,25 @@ def run(self, p, quantization=False):
if isinstance(quantization, str):
quantization = False

global compiled_unet, compiled_ckpt_name
global compiled_unet, compiled_ckpt_name, is_unet_quantized
current_checkpoint = shared.opts.sd_model_checkpoint
original_diffusion_model = shared.sd_model.model.diffusion_model

ckpt_name = (
current_checkpoint + "_quantized" if quantization else current_checkpoint
ckpt_changed = current_checkpoint != compiled_ckpt_name
model_changed = self.check_model_change(shared.sd_model)
quantization_changed = quantization != is_compiled_unet_quantized
need_recompile = (
(quantization and ckpt_changed) # always recompile when switching ckpt with 'int8 speed model' enabled
or model_changed # always recompile when switching model to another structure
or quantization_changed # always recompile when switching model from non-quantized to quantized (and vice versa)
)

model_changed = ckpt_name != compiled_ckpt_name
model_structure_changed = self.check_model_structure_change(shared.sd_model)
need_recompile = (quantization and model_changed) or model_structure_changed

is_unet_quantized = quantization
compiled_ckpt_name = current_checkpoint
if need_recompile:
compiled_unet = compile_unet(
original_diffusion_model, quantization=quantization
)
compiled_ckpt_name = ckpt_name
else:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
Expand Down