Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev quant tools and fix graph file management #495

Merged
merged 29 commits into the base branch from the source branch
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions onediff_comfy_nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,37 @@
"OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff",
}


if _USE_UNET_INT8:
from ._nodes import UNETLoaderInt8, Quant8Model
from ._nodes import UNETLoaderInt8, Quant8Model, OneDiffQuantCheckpointLoaderSimple
from ._quant_tools import (
UnetQuantKSampler,
FineTuneCalibrateInfo,
SaveQuantizedConfig,
LoadQuantizedConfig,
)

NODE_CLASS_MAPPINGS.update(
{"UNETLoaderInt8": UNETLoaderInt8, "Quant8Model": Quant8Model}
{
"UNETLoaderInt8": UNETLoaderInt8,
"Quant8Model": Quant8Model,
"OneDiffQuantCheckpointLoaderSimple": OneDiffQuantCheckpointLoaderSimple,
"FineTuneCalibrateInfo": FineTuneCalibrateInfo,
"SaveCalibrateInfo": SaveQuantizedConfig,
"LoadCalibrateInfo": LoadQuantizedConfig,
}
)

NODE_DISPLAY_NAME_MAPPINGS.update(
{
"UNETLoaderInt8": "UNET Loader Int8",
"Quant8Model": "Model Quantization(int8)",
"OneDiffQuantCheckpointLoaderSimple": "Load Checkpoint - OneDiff Quant",
"FineTuneCalibrateInfo": "Fine Tune Calibrate Info",
"SaveCalibrateInfo": "Save Calibrate Info",
"LoadCalibrateInfo": "Load Calibrate Info",
}
)

NODE_CLASS_MAPPINGS.update({"UnetQuantKSampler": UnetQuantKSampler})
NODE_DISPLAY_NAME_MAPPINGS.update({"UnetQuantKSampler": "Unet Quant K Sampler"})
ccssu marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 9 additions & 1 deletion onediff_comfy_nodes/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import sys
from pathlib import Path

ONEDIFF_QUANTIZED_OPTIMIZED_MODELS = "onediff_fastquant_models"
_USE_UNET_INT8 = True

COMFYUI_ROOT = Path(os.path.abspath(__file__)).parents[2]
COMFYUI_SPEEDUP_ROOT = Path(os.path.abspath(__file__)).parents[0]
INFER_COMPILER_REGISTRY = Path(COMFYUI_SPEEDUP_ROOT) / "infer_compiler_registry"
Expand All @@ -30,3 +30,11 @@
[str(unet_int8_model_dir)],
supported_pt_extensions,
)

opt_models_dir = Path(models_dir) / ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
opt_models_dir.mkdir(parents=True, exist_ok=True)

folder_names_and_paths[ONEDIFF_QUANTIZED_OPTIMIZED_MODELS] = (
[str(opt_models_dir)],
supported_pt_extensions,
)
104 changes: 97 additions & 7 deletions onediff_comfy_nodes/_nodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial
from onediff.infer_compiler.transform import torch2oflow
from onediff.infer_compiler.with_oneflow_compile import oneflow_compile
from ._config import _USE_UNET_INT8
from ._config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS

import os
import re
Expand Down Expand Up @@ -420,6 +420,7 @@ def INPUT_TYPES(s):
"required": {
"model": ("MODEL",),
"static_mode": (["enable", "disable"],),
"no_compile": (["disable", "enable"],),
"cache_interval": (
"INT",
{
Expand Down Expand Up @@ -468,13 +469,15 @@ def deep_cache_convert(
self,
model,
static_mode,
no_compile,
cache_interval,
cache_layer_id,
cache_block_id,
start_step,
end_step,
):
use_graph = static_mode == "enable"
no_compile = (no_compile == "enable") or (not use_graph)
ccssu marked this conversation as resolved.
Show resolved Hide resolved

offload_device = model_management.unet_offload_device()
oneflow_model = OneFlowDeepCacheSpeedUpModelPatcher(
Expand All @@ -484,6 +487,7 @@ def deep_cache_convert(
cache_layer_id=cache_layer_id,
cache_block_id=cache_block_id,
use_graph=use_graph,
no_compile=no_compile,
)

current_t = -1
Expand Down Expand Up @@ -603,13 +607,9 @@ def INPUT_TYPES(s):
}

CATEGORY = "OneDiff"
ccssu marked this conversation as resolved.
Show resolved Hide resolved
FUNCTION = "onediff_load_checkpoint"

def load_checkpoint(
self, ckpt_name, output_vae=True, output_clip=True, vae_speedup="disable"
):
modelpatcher, clip, vae = super().load_checkpoint(
ckpt_name, output_vae, output_clip
)
def speedup_unet(self, ckpt_name, modelpatcher):
offload_device = model_management.unet_offload_device()
load_device = model_management.get_torch_device()

Expand Down Expand Up @@ -639,6 +639,17 @@ def load_checkpoint(
modelpatcher.model.diffusion_model = diffusion_model
modelpatcher.model._register_state_dict_hook(state_dict_hook)

return modelpatcher

def onediff_load_checkpoint(
self, ckpt_name, output_vae=True, output_clip=True, vae_speedup="disable"
):
modelpatcher, clip, vae = CheckpointLoaderSimple.load_checkpoint(
ckpt_name, output_vae, output_clip
)

modelpatcher = self.speedup_unet(ckpt_name, modelpatcher)

if vae_speedup == "enable":
file_path = generate_graph_path(ckpt_name, vae.first_stage_model)
vae.first_stage_model = oneflow_compile(
Expand All @@ -653,3 +664,82 @@ def load_checkpoint(
# set inplace update
modelpatcher.weight_inplace_update = True
return modelpatcher, clip, vae


class OneDiffQuantCheckpointLoaderSimple(OneDiffCheckpointLoaderSimple):
    """ComfyUI loader node: load a checkpoint, quantize its UNet to int8 using a
    pre-computed quantization config file, and optionally compile the quantized
    UNet via the inherited ``speedup_unet``.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Collect quantization-config files (*.pt) found directly inside each
        # registered ONEDIFF_QUANTIZED_OPTIMIZED_MODELS search folder; entries
        # are made relative to their folder so the UI dropdown shows short names.
        # NOTE(review): glob("*.pt") is non-recursive — configs placed in
        # subdirectories will not be listed; confirm that is intended.
        paths = []
        for search_path in folder_paths.get_folder_paths(
            ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
        ):
            if os.path.exists(search_path):
                search_path = Path(search_path)
                paths.extend(
                    [
                        os.path.relpath(p, start=search_path)
                        for p in search_path.glob("*.pt")
                    ]
                )

        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model_path": (paths,),
                # "enable" skips OneDiff compilation (quantize only).
                "no_compile": (["disable", "enable"],),
                "compute_density_threshold": (
                    "INT",
                    {
                        "default": 600,
                        "min": 1,
                        "max": 10000,
                        "step": 1,
                        "display": "number",
                    },
                ),
            }
        }

    # Node placement in the ComfyUI menu and the method ComfyUI invokes.
    CATEGORY = "OneDiff/Loaders"
    FUNCTION = "onediff_load_checkpoint"

    def onediff_load_checkpoint(
        self,
        ckpt_name,
        no_compile,
        compute_density_threshold,
        model_path,
        output_vae=True,
        output_clip=True,
    ):
        """Load `ckpt_name`, quantize its diffusion model with the config at
        `model_path`, and return ``(modelpatcher, clip, vae)``.

        `no_compile` arrives as the UI string "enable"/"disable"; it is folded
        to a bool below.  `model_path` is a path relative to the
        ONEDIFF_QUANTIZED_OPTIMIZED_MODELS models folder (see INPUT_TYPES).
        """
        no_compile = no_compile == "enable"
        # Deferred import: keeps the quantization dependency out of module load.
        from onediff.optimization.quant_optimizer import quantize_model

        # presumably resolves to CheckpointLoaderSimple.load_checkpoint via the
        # parent class — confirm the MRO still provides it after the rename.
        modelpatcher, clip, vae = self.load_checkpoint(
            ckpt_name, output_vae, output_clip
        )

        # Suffix the name so compiled-graph cache files for the quantized model
        # do not collide with those of the unquantized checkpoint.
        ckpt_name = f"{ckpt_name}_quant"
        # Rebuild the absolute path to the quantization config under the
        # models dir (model_path from the UI is relative — see INPUT_TYPES).
        model_path = (
            Path(folder_paths.models_dir)
            / ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
            / model_path
        )

        # Quantize in place; the reassignment keeps the patcher pointing at the
        # (possibly wrapped) object quantize_model returns.
        diffusion_model = modelpatcher.model.diffusion_model
        diffusion_model = quantize_model(
            model=diffusion_model,
            inplace=True,
            quant_config_file=str(model_path),
            compute_density_threshold=compute_density_threshold,
            conv_ssim_threshold=0.98,
            linear_ssim_threshold=0.98,
        )
        modelpatcher.model.diffusion_model = diffusion_model

        if not no_compile:
            # Inherited from OneDiffCheckpointLoaderSimple: OneDiff-compiles
            # the (now quantized) UNet.
            modelpatcher = self.speedup_unet(ckpt_name, modelpatcher)

        # set inplace update
        modelpatcher.weight_inplace_update = True
        return modelpatcher, clip, vae
Loading