Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove oudated args for load_checkpoint #962

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from comfy.cli_args import args

from onediff.utils.import_utils import is_onediff_quant_available
from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version
from onediff.infer_compiler.backends.oneflow.utils.version_util import (
is_community_version,
)


from ..modules import BoosterScheduler
Expand Down Expand Up @@ -316,8 +318,6 @@ def onediff_load_checkpoint(
self,
ckpt_name,
vae_speedup,
output_vae=True,
output_clip=True,
static_mode="enable",
cache_interval=3,
cache_layer_id=0,
Expand All @@ -326,9 +326,7 @@ def onediff_load_checkpoint(
end_step=1000,
):
# CheckpointLoaderSimple.load_checkpoint
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
booster = BoosterScheduler(
DeepcacheBoosterExecutor(
cache_interval=cache_interval,
Expand Down Expand Up @@ -618,12 +616,8 @@ def INPUT_TYPES(s):
CATEGORY = "OneDiff/Loaders"
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self, ckpt_name, vae_speedup, output_vae=True, output_clip=True
):
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
def onediff_load_checkpoint(self, ckpt_name, vae_speedup):
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
booster = BoosterScheduler(
OnelineQuantizationBoosterExecutor(
conv_percentage=100,
Expand Down Expand Up @@ -671,19 +665,11 @@ def INPUT_TYPES(s):
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self,
ckpt_name,
model_path,
compile,
vae_speedup,
output_vae=True,
output_clip=True,
self, ckpt_name, model_path, compile, vae_speedup,
):
need_compile = compile == "enable"

modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)
modelpatcher, clip, vae = self.load_checkpoint(ckpt_name)
# TODO fix by op.compile
from ..modules.oneflow.utils.onediff_load_utils import (
onediff_load_quant_checkpoint_advanced,
Expand Down