From 557631c0a98547c1fa2c15b69afe05943e7e5838 Mon Sep 17 00:00:00 2001 From: ad-astra-video <99882368+ad-astra-video@users.noreply.github.com> Date: Fri, 16 Aug 2024 21:10:58 -0500 Subject: [PATCH] feat: update ByteDance/SDXL-Lighting to default to 8step (#162) * update ByteDance/SDXL-Lightning to default to 8 step unet * update I2I to 8step default for ByteDance/SDXL-Lightning model --- runner/app/pipelines/image_to_image.py | 8 ++++---- runner/app/pipelines/text_to_image.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index f42e6550..9e20ff03 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -83,8 +83,8 @@ def __init__(self, model_id: str): elif "8step" in model_id: unet_id = "sdxl_lightning_8step_unet" else: - # Default to 2step - unet_id = "sdxl_lightning_2step_unet" + # Default to 8step + unet_id = "sdxl_lightning_8step_unet" unet_config = UNet2DConditionModel.load_config( pretrained_model_name_or_path=base, @@ -219,8 +219,8 @@ def __call__( elif "8step" in self.model_id: kwargs["num_inference_steps"] = 8 else: - # Default to 2step - kwargs["num_inference_steps"] = 2 + # Default to 8step + kwargs["num_inference_steps"] = 8 output = self.ldm(prompt, image=image, **kwargs) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 84a2228e..69302614 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -93,8 +93,8 @@ def __init__(self, model_id: str): elif "8step" in model_id: unet_id = "sdxl_lightning_8step_unet" else: - # Default to 2step - unet_id = "sdxl_lightning_2step_unet" + # Default to 8step + unet_id = "sdxl_lightning_8step_unet" unet_config = UNet2DConditionModel.load_config( pretrained_model_name_or_path=base,