
multi-step model for sdv1.5 #36

Open
kmpartner opened this issue Jul 20, 2024 · 5 comments
@kmpartner

Is it possible to train a 4-step (multi-step) model for SD v1.5 in this repo?
I see it covered in the SDXL experiment docs, but not in the SD v1.5 experiment docs.

@tianweiy
Owner

It is possible, but it requires some coding.

@kmpartner
Author

kmpartner commented Jul 20, 2024

Thank you for the response.
Which part of the code needs modification to enable multi-step training for SD v1.5?
Is it the denoising part?

@tianweiy
Owner

I think you would need to modify this function:

assert self.sdxl, "Denoising is only supported for SDXL"

First remove this assert, then adapt the text encoder and the backward-simulation code (if you need the latter). You can see this line for how to modify the text encoder:

real_text_embedding_output = self.text_encoder(real_train_dict["text_input_ids_one"].squeeze(1))
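For reference, the two call conventions differ in shape: the SDXL wrapper returns a `(text_embedding, pooled_text_embedding)` tuple, while a plain SD v1.5 CLIP text encoder returns a dict-like output with `last_hidden_state` and `pooler_output` fields. A minimal sketch of normalizing between them (the helper name and stand-in output here are illustrative, not the repo's actual API):

```python
# Illustrative helper, not part of the repo: normalize either text-encoder
# output convention to a (text_embedding, pooled_text_embedding) tuple.
def unpack_text_embeddings(output, sdxl):
    if sdxl:
        # SDXL wrapper already returns (text_embedding, pooled_text_embedding)
        return output
    # SD v1.5 CLIP encoder: dict-like output with separate fields
    return output["last_hidden_state"], output["pooler_output"]

# Stand-in SD v1.5 output for demonstration only
fake_output = {"last_hidden_state": "emb", "pooler_output": "pooled"}
print(unpack_text_embeddings(fake_output, sdxl=False))  # ('emb', 'pooled')
```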

@tianweiy
Owner

I can help review your changes if you are interested in opening a pull request.

@kmpartner
Author

kmpartner commented Jul 21, 2024

I updated text_embedding and pooled_text_embedding using the text encoder in the prepare_denoising_data function.
Is this the right way to enable multi-step training for SD v1.5?

Is denoising_timestep 250 in the SDXL case? Is there any reason to use 250 other than to limit it to 4 steps?
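On the 250 question, a back-of-the-envelope sketch (variable names illustrative, not the repo's actual schedule construction): with 1000 training timesteps, splitting the schedule into 4 evenly spaced denoising steps gives a spacing of 1000 / 4 = 250, which is presumably where the value comes from.

```python
# Sketch: evenly spaced 4-step schedule over 1000 training timesteps.
# Names are illustrative; check the repo for the actual construction.
num_train_timesteps = 1000
num_denoising_step = 4
spacing = num_train_timesteps // num_denoising_step  # 250
denoising_step_list = [num_train_timesteps - 1 - i * spacing
                       for i in range(num_denoising_step)]
print(denoising_step_list)  # [999, 749, 499, 249]
```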

def prepare_denoising_data(self, denoising_dict, real_train_dict, noise):
    # assert self.sdxl, "Denoising is only supported for SDXL"

    indices = torch.randint(
        0, self.num_denoising_step, (noise.shape[0],), device=noise.device, dtype=torch.long
    )
    timesteps = self.denoising_step_list.to(noise.device)[indices]

    if self.sdxl:
        text_embedding, pooled_text_embedding = self.text_encoder(denoising_dict)
    else:
        text_embedding_dict = self.text_encoder(denoising_dict["text_input_ids_one"].squeeze(1))
        text_embedding = text_embedding_dict["last_hidden_state"]
        pooled_text_embedding = text_embedding_dict["pooler_output"]

    if real_train_dict is not None:
        if self.sdxl:
            real_text_embedding, real_pooled_text_embedding = self.text_encoder(real_train_dict)
            real_unet_added_conditions = {
                "time_ids": self.add_time_ids.repeat(len(real_text_embedding), 1),
                "text_embeds": real_pooled_text_embedding
            }
            real_train_dict['unet_added_conditions'] = real_unet_added_conditions
        else:
            # SD v1.5: single CLIP encoder, no SDXL-style added conditions
            real_text_embedding_dict = self.text_encoder(real_train_dict["text_input_ids_one"].squeeze(1))
            real_text_embedding = real_text_embedding_dict["last_hidden_state"]

        real_train_dict['text_embedding'] = real_text_embedding

    if self.backward_simulation:
        # we overwrite the denoising timesteps 
        # note: we also use uncorrelated noise 
        clean_images, timesteps = self.sample_backward(torch.randn_like(noise), text_embedding, pooled_text_embedding) 
    else:
        clean_images = denoising_dict['images'].to(noise.device)

    noisy_image = self.noise_scheduler.add_noise(
        clean_images, noise, timesteps
    )

    # set last timestep to pure noise
    pure_noise_mask = (timesteps == (self.num_train_timesteps-1))
    noisy_image[pure_noise_mask] = noise[pure_noise_mask]

    return timesteps, text_embedding, pooled_text_embedding, real_train_dict, noisy_image
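As a side note on the pure-noise mask at the end of the function, a torch-free sketch of the intended behavior (values illustrative): samples drawn at the highest timestep, num_train_timesteps - 1, should start from pure noise rather than from a noised clean image.

```python
# Plain-Python sketch of the pure-noise mask; the real code does the same
# thing with boolean tensor indexing. All values here are illustrative.
num_train_timesteps = 1000
timesteps = [999, 499, 249, 999]
noisy_image = ["noised0", "noised1", "noised2", "noised3"]
noise = ["pure0", "pure1", "pure2", "pure3"]
for i, t in enumerate(timesteps):
    if t == num_train_timesteps - 1:
        noisy_image[i] = noise[i]  # replace with pure noise at the last step
print(noisy_image)  # ['pure0', 'noised1', 'noised2', 'pure3']
```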
