Each pipe run generates different image (run with the same params and seed) #1087

Open
liorplaytika opened this issue Aug 18, 2024 · 2 comments
Labels
Request-bug Something isn't working

liorplaytika commented Aug 18, 2024

Your current environment information

onediff = "1.2.0"
oneflow = { version = "0.9.1.dev.20240510", source = "oneflow_source" }
onediffx = "1.2.0"

🐛 Describe the bug

Each pipeline run using onediffx generates a different image, even with exactly the same params and seed.

Example code:

import numpy as np
from PIL import Image
import torch
from diffusers import AutoPipelineForText2Image
from onediffx import compile_pipe
from onediffx.lora import load_and_fuse_lora
from onediffx.lora import unfuse_lora


def start(checkpoint):
    device = "cuda"
    dtype = torch.float16

    pipe = AutoPipelineForText2Image.from_pretrained(checkpoint,
                                                     variant='fp16',
                                                     torch_dtype=dtype
                                                     ).to(device)


    pipe = compile_pipe(pipe, ignores=("vae.encoder", "vae.decoder"))

    # warmup
    pipe(prompt="prompt",
         negative_prompt="negative_prompt",
         num_inference_steps=5,
         width=128,
         height=128,
         num_images_per_prompt=1)


    return pipe


def main(pipe, checkpoint, iteration_number):
    input_request = {
        'checkpoint': checkpoint,
        'prompt': 'Woman wearing pink pijamas, pixel art',
        'negative_prompt': 'Red, Dots',
        'generator': torch.Generator(device="cuda").manual_seed(0),
        'num_images_per_prompt': 2,
        'num_inference_steps': 30,
        'width': 1024,
        'height': 1024,
        'guidance_scale': 8.1,
    }


    load_and_fuse_lora(pipe, "nerijs/pixel-art-xl", lora_scale=1)
    res = pipe(**input_request).images[0]
    unfuse_lora(pipe)
    pipe.unload_lora_weights()
    torch.cuda.empty_cache()

    res.save(f'pixel_image_iteration_{iteration_number}.png')


if __name__ == "__main__":
    checkpoint = 'stabilityai/stable-diffusion-xl-base-1.0'

    pipe = start(checkpoint)

    for i in range(2):
        main(pipe, checkpoint, i)

    for i in range(1):
        Image.fromarray(np.asarray(Image.open(f'pixel_image_iteration_{i}.png')) - (
            Image.open(f'pixel_image_iteration_{i + 1}.png'))).save(f'pixel_diff_{i}_{i + 1}.png')

The images I got:
pixel_image_iteration_0

pixel_image_iteration_1

There are differences between these two images, which were expected to be identical. Another way to see the difference is to subtract one image from the other; the result isn't black as expected:

pixel_diff_0_1

I ran the same code without loading the LoRA, and the generated images were exactly the same. So as I see it the problem appears at "load_and_fuse_lora" function.
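A numeric comparison can complement the visual diff above. The following is a minimal sketch, not part of the original report: `image_difference` is a hypothetical helper name, and it assumes both images are already loaded as `uint8` arrays of the same shape (for example via `np.asarray(Image.open(path))`, as in the script). It widens the arrays before subtracting, because subtracting `uint8` arrays directly wraps around modulo 256.

```python
import numpy as np


def image_difference(a: np.ndarray, b: np.ndarray) -> dict:
    """Summarize the per-pixel difference between two equally shaped uint8 images."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Widen to a signed type first: uint8 - uint8 wraps modulo 256,
    # so a true difference of -1 would otherwise appear as 255.
    diff = np.abs(a.astype(np.int16) - b.astype(np.int16))
    # Collapse the channel axis (if any) so each pixel is counted once.
    per_pixel = diff.reshape(diff.shape[0], diff.shape[1], -1).any(axis=-1)
    return {
        "identical": bool((diff == 0).all()),
        "max_abs_diff": int(diff.max()),
        "changed_pixels": int(per_pixel.sum()),
    }
```

A report with `identical` set to `True` would correspond to the all-black diff image that was expected here.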

liorplaytika added the Request-bug label Aug 18, 2024
Collaborator

strint commented Sep 5, 2024

> So as I see it the problem appears at "load_and_fuse_lora" function.

Have you tried the native diffusers API to load and fuse the LoRA?
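For reference, a rough sketch of what the native diffusers counterpart of the `load_and_fuse_lora` / `unfuse_lora` pair might look like. The wrapper name `run_with_native_lora` is hypothetical; the pipeline methods it calls (`load_lora_weights`, `fuse_lora`, `unfuse_lora`, `unload_lora_weights`) are diffusers APIs. Whether these calls work on a pipeline that has already been passed through `compile_pipe` is not established by this thread.

```python
def run_with_native_lora(pipe, lora_id, lora_scale=1.0, **call_kwargs):
    """Load and fuse a LoRA with diffusers' own API, generate, then undo both steps.

    `pipe` is assumed to be a diffusers pipeline exposing the LoRA loader methods.
    """
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora(lora_scale=lora_scale)
    try:
        return pipe(**call_kwargs).images[0]
    finally:
        # Restore the base weights even if generation fails.
        pipe.unfuse_lora()
        pipe.unload_lora_weights()
```

With the script above this would be called as `run_with_native_lora(pipe, "nerijs/pixel-art-xl", lora_scale=1, **input_request)`.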

Amitg1 commented Sep 23, 2024

Yes, it crashes. @strint
diffusers 0.29.2
onediff 1.2.0
onediffx 1.2.0
torch 2.4.0

I changed the code to:

#load_and_fuse_lora(pipe, "nerijs/pixel-art-xl", lora_scale=1)
pipe.load_lora_weights("nerijs/pixel-art-xl")
res = pipe(**input_request).images[0]
# unfuse_lora(pipe)
pipe.unload_lora_weights()

Error:

RuntimeError Traceback (most recent call last)
File /api/.venv/lib/python3.10/site-packages/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py:229, in _.<locals>.proxy_getattr(self, attr)
228 try:
--> 229 return super().__getattribute__(attr)
230 except Exception as e:

RuntimeError: super(): class cell not found

During handling of the above exception, another exception occurred:

AttributeError Traceback (most recent call last)
Cell In[2], line 67
64 pipe = start(checkpoint)
66 for i in range(2):
---> 67 main(pipe, checkpoint, i)
69 for i in range(1):
70 Image.fromarray(np.asarray(Image.open(f'pixel_image_iteration_{i}.png')) - (
71 Image.open(f'pixel_image_iteration_{i + 1}.png'))).save(f'pixel_diff_{i}_{i + 1}.png')

Cell In[2], line 52, in main(pipe, checkpoint, iteration_number)
38 input_request = {
39 'checkpoint': checkpoint,
40 'prompt': 'Woman wearing pink pijamas, pixel art',
(...)
47 'guidance_scale': 8.1,
48 }
51 # load_and_fuse_lora(pipe, "nerijs/pixel-art-xl", lora_scale=1)
---> 52 pipe.load_lora_weights("nerijs/pixel-art-xl")
53 res = pipe(**input_request).images[0]
54 # unfuse_lora(pipe)

File /api/.venv/lib/python3.10/site-packages/diffusers/loaders/lora.py:1234, in StableDiffusionXLLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
1231 if not is_correct_format:
1232 raise ValueError("Invalid LoRA checkpoint.")
-> 1234 self.load_lora_into_unet(
1235 state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
1236 )
1237 text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1238 if len(text_encoder_state_dict) > 0:

File /api/.venv/lib/python3.10/site-packages/diffusers/loaders/lora.py:401, in LoraLoaderMixin.load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name, _pipeline)
398 if not only_text_encoder:
399 # Load the layers corresponding to UNet.
400 logger.info(f"Loading {cls.unet_name}.")
--> 401 unet.load_attn_procs(
402 state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
403 )

File /api/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
111 if check_use_auth_token:
112 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File /api/.venv/lib/python3.10/site-packages/diffusers/loaders/unet.py:217, in UNet2DConditionLoadersMixin.load_attn_procs(self, pretrained_model_name_or_path_or_dict, **kwargs)
215 attn_processors = self._process_custom_diffusion(state_dict=state_dict)
216 elif is_lora:
--> 217 is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
218 state_dict=state_dict,
219 unet_identifier_key=self.unet_name,
220 network_alphas=network_alphas,
221 adapter_name=adapter_name,
222 _pipeline=_pipeline,
223 )
224 else:
225 raise ValueError(
226 f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
227 )

File /api/.venv/lib/python3.10/site-packages/diffusers/loaders/unet.py:350, in UNet2DConditionLoadersMixin._process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline)
346 # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
347 # otherwise loading LoRA weights will lead to an error
348 is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
--> 350 inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
351 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
353 if incompatible_keys is not None:
354 # check only for unexpected keys

File /api/.venv/lib/python3.10/site-packages/peft/mapping.py:215, in inject_adapter_in_model(peft_config, model, adapter_name)
212 tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
214 # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
--> 215 peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)
217 return peft_model.model

File /api/.venv/lib/python3.10/site-packages/peft/tuners/lora/model.py:139, in LoraModel.__init__(self, model, config, adapter_name)
138 def __init__(self, model, config, adapter_name) -> None:
--> 139 super().__init__(model, config, adapter_name)

File /api/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:175, in BaseTuner.__init__(self, model, peft_config, adapter_name)
173 self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
174 if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
--> 175 self.inject_adapter(self.model, adapter_name)
177 # Copy the peft_config in the injected model.
178 self.model.peft_config = self.peft_config

File /api/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:431, in BaseTuner.inject_adapter(self, model, adapter_name, autocast_adapter_dtype)
429 is_target_modules_in_base_model = True
430 parent, target, target_name = _get_submodules(model, key)
--> 431 self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
433 # Handle X-LoRA case.
434 if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):

File /api/.venv/lib/python3.10/site-packages/peft/tuners/lora/model.py:214, in LoraModel._create_and_replace(self, lora_config, adapter_name, target, target_name, parent, current_key)
211 from peft.tuners.adalora import AdaLoraLayer
213 if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
--> 214 target.update_layer(
215 adapter_name,
216 r,
217 lora_alpha=alpha,
218 lora_dropout=lora_config.lora_dropout,
219 init_lora_weights=lora_config.init_lora_weights,
220 use_rslora=lora_config.use_rslora,
221 use_dora=lora_config.use_dora,
222 )
223 else:
224 new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)

File /api/.venv/lib/python3.10/site-packages/peft/tuners/lora/layer.py:109, in LoraLayer.update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora)
106 if r <= 0:
107 raise ValueError(f"r should be a positive integer value but the value passed is {r}")
--> 109 self.r[adapter_name] = r
110 self.lora_alpha[adapter_name] = lora_alpha
111 if lora_dropout > 0.0:

File /api/.venv/lib/python3.10/site-packages/onediff/infer_compiler/backends/oneflow/dual_module.py:89, in DualModule.__getattr__(self, name)
83 return super().__getattribute__(name)
85 torch_attr = getattr(self._torch_module, name)
86 oneflow_attr = (
87 None
88 if self._oneflow_module is None
---> 89 else getattr(self._oneflow_module, name)
90 )
92 if isinstance(torch_attr, torch.nn.ModuleList):
93 if oneflow_attr is None:

File /api/.venv/lib/python3.10/site-packages/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py:238, in _.<locals>.proxy_getattr(self, attr)
236 return self._buffers[attr]
237 else:
--> 238 return getattr(proxy_md, attr)

File /api/.venv/lib/python3.10/site-packages/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py:129, in ProxySubmodule.__getattribute__(self, attribute)
127 return "channels_first"
128 else:
--> 129 a = getattr(self._oflow_proxy_submod, attribute)
131 if isinstance(a, (torch.nn.parameter.Parameter, torch.Tensor)):
132 # TODO(oneflow): assert a.requires_grad == False
133 if attribute not in self._oflow_proxy_parameters:

File /api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1729, in Module.__getattr__(self, name)
1727 if name in modules:
1728 return modules[name]
-> 1729 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'Linear' object has no attribute 'r'
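One possible reading of the last frames: in `dual_module.py` the lookup on `self._torch_module` (line 85) succeeds, while the matching lookup on `self._oneflow_module` (line 89) fails, so the attribute `r` seems to exist only on the torch side. The sketch below reproduces that pattern in plain Python. All class names are hypothetical stand-ins, and this is not onediff's actual `DualModule` implementation.

```python
class TorchLinear:
    """Stand-in for the eager module after an adapter injector added `r`."""

    def __init__(self):
        self.r = {}  # attribute that appears only after compilation (assumption)


class CompiledLinear:
    """Stand-in for the compiled counterpart, which never received `r`."""


class DualProxy:
    """Hypothetical two-sided wrapper that reads each attribute from both modules."""

    def __init__(self, torch_module, compiled_module):
        self._torch_module = torch_module
        self._compiled_module = compiled_module

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails, so the two fields
        # assigned in __init__ are found directly and never reach this method.
        torch_attr = getattr(self._torch_module, name)
        # Mirrors the failing frame: this second lookup raises AttributeError
        # even though the first one succeeded.
        getattr(self._compiled_module, name)
        return torch_attr
```

Accessing `DualProxy(TorchLinear(), CompiledLinear()).r` raises `AttributeError`, analogous to the `'Linear' object has no attribute 'r'` message above.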
