
Finetuning and Config used for HPS #17

Open
anonymous-atom opened this issue Nov 9, 2024 · 9 comments

@anonymous-atom

anonymous-atom commented Nov 9, 2024

@mihirp1998
I was trying to fine-tune Stable Diffusion 1.5 using your HPS reward function and the hps.sh training script. I used a batch size of 1, but training still seems to complete very quickly: 50 epochs took just 2-4 minutes.

def from_file(path, low=None, high=None):
    prompts = _load_lines(path)[low:high]
    return random.choice(prompts), {}  # one randomly drawn prompt (plus empty metadata) per call

And here you seem to use only batch_size prompts per sampling call? I am using a batch_size of 2 on 1 A100 GPU to test the script.

    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
        """
        Generate samples from the model

        Args:
            batch_size (int): Batch size to use for sampling
            with_grad (bool): Whether the generated RGBs should have gradients attached to them.
            prompts (optional): Prompts to use for sampling; if None, they are drawn from `prompt_fn`.

        Returns:
            prompt_image_pairs (Dict[Any])
        """
        prompt_image_pairs = {}

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        if prompts is None:
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
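
As far as I can tell, each self.prompt_fn() call maps to from_file() above and returns a single randomly chosen prompt, so that zip line just collects batch_size independently sampled prompts. A tiny repro of what I mean, with a toy list standing in for the prompt file:

    import random

    prompts_in_file = ["a cat", "a dog", "a house"]  # toy stand-in for the real prompt file

    def prompt_fn():
        # mirrors from_file(): one randomly chosen prompt plus empty metadata
        return random.choice(prompts_in_file), {}

    batch_size = 2
    prompts, prompt_metadata = zip(*[prompt_fn() for _ in range(batch_size)])
    print(prompts)  # only batch_size prompts, sampled independently with replacement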

Your help will mean a lot!

@mihirp1998
Owner

Can you check whether the training is actually running? If not, why is it skipping the training loop?

@anonymous-atom
Author

It's working, but I think it's only using a few examples from the prompt file.

@mihirp1998
Copy link
Owner

I think random.choice is uniform over all prompts, so I'm not sure what the bug is here. If you find it, let me know.
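
To be clear, that's just stock random.choice behaviour rather than anything specific to this repo. A quick coverage check, with a toy list standing in for the prompt file:

    import random

    prompts = [f"prompt_{i}" for i in range(750)]            # toy stand-in for the prompt file
    draws = [random.choice(prompts) for _ in range(10_000)]  # lots of draws, as across many epochs
    print(len(set(draws)), "distinct prompts covered out of", len(prompts))

With enough draws essentially every prompt gets covered, so if only a few show up in your run, the number of prompts drawn per epoch is the more likely culprit than the sampling itself.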

@anonymous-atom
Author

Yeah Sure!

@anonymous-atom
Author

Also, while I was trying to train with a custom loss function, the models seem to collapse very early unless I adjust the learning rate. Is this the expected behaviour?
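
For reference, a minimal sketch of the kind of adjustment I mean: lowering the learning rate, plus gradient clipping as an extra standard stabiliser. This is plain PyTorch with a toy model and loss as placeholders, not the repo's code:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)                                     # placeholder for the finetuned UNet
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # lower lr than I started with

    x = torch.randn(2, 4)
    loss = -model(x).mean()                                     # placeholder for the custom reward loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # keep update sizes bounded
    optimizer.step()
    optimizer.zero_grad()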

@anonymous-atom
Author

@mihirp1998 Sorry to tag you again, but could you let me know how much time it took per epoch on your 4 A100 GPUs?

Clear Skies!

@mihirp1998
Owner

mihirp1998 commented Nov 9, 2024 via email

@anonymous-atom
Author

I wanted to confirm: are you using all 750 prompts in the hps_v2_all.txt file for a single epoch?
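
(For my own sanity check I counted the prompts in the file like this, assuming one prompt per line:)

    with open("hps_v2_all.txt") as f:
        prompts = [line.strip() for line in f if line.strip()]
    print(len(prompts))  # expecting 750 if the whole file is meant to be used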

@anonymous-atom
Author

Hi @mihirp1998, this is where I am confused:

So you used the step() function to do one training step per epoch?

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

And here, in the step() function, each epoch only seems to fine-tune on num_gpus * batch_size * train_gradient_accumulation_steps images; am I missing something? What if someone trains on just 1 GPU? (I've put some example numbers after the snippet below.)

    def step(self, epoch: int, global_step: int):
        info = defaultdict(list)
        print(f"Epoch: {epoch}, Global Step: {global_step}")

        self.sd_pipeline.unet.train()

        for _ in range(self.config.train_gradient_accumulation_steps):
            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
                prompt_image_pairs = self._generate_samples(
                    batch_size=self.config.train_batch_size,
                )
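
To spell out my concern with concrete, made-up numbers for my single-GPU setup (I don't know the exact defaults, so these are placeholders):

    # Placeholder values for illustration only, not the repo's defaults.
    num_gpus = 1
    train_batch_size = 2
    train_gradient_accumulation_steps = 4
    num_epochs = 50

    images_per_epoch = num_gpus * train_batch_size * train_gradient_accumulation_steps
    print(images_per_epoch)               # 8 images per "epoch"
    print(images_per_epoch * num_epochs)  # only 400 images across the whole run

If that's right, it would also explain why 50 epochs finish in a couple of minutes.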
                
