diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index f42e6550..9e20ff03 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -83,8 +83,8 @@ def __init__(self, model_id: str): elif "8step" in model_id: unet_id = "sdxl_lightning_8step_unet" else: - # Default to 2step - unet_id = "sdxl_lightning_2step_unet" + # Default to 8step + unet_id = "sdxl_lightning_8step_unet" unet_config = UNet2DConditionModel.load_config( pretrained_model_name_or_path=base, @@ -219,8 +219,8 @@ def __call__( elif "8step" in self.model_id: kwargs["num_inference_steps"] = 8 else: - # Default to 2step - kwargs["num_inference_steps"] = 2 + # Default to 8step + kwargs["num_inference_steps"] = 8 output = self.ldm(prompt, image=image, **kwargs) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 84a2228e..69302614 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -93,8 +93,8 @@ def __init__(self, model_id: str): elif "8step" in model_id: unet_id = "sdxl_lightning_8step_unet" else: - # Default to 2step - unet_id = "sdxl_lightning_2step_unet" + # Default to 8step + unet_id = "sdxl_lightning_8step_unet" unet_config = UNet2DConditionModel.load_config( pretrained_model_name_or_path=base,