
class-preservation target loss for LoRA / LyCORIS #1031

Closed
bghira opened this issue Oct 5, 2024 · 7 comments · Fixed by #1050
Labels
difficult-feature, documentation (Improvements or additions to documentation), work-in-progress (This issue relates to some currently in-progress work.)

Comments

@bghira
Owner

bghira commented Oct 5, 2024

the idea is based on this pastebin entry: https://pastebin.com/3eRwcAJD

snippet:

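```python
# Excerpt from a OneTrainer training step (not standalone: relies on torch and on the
# trainer's self.model, self.model_setup, self.config, batch and train_progress).
if batch['prompt'][0] == "woman":  # regularisation batch, detected by its caption
    with torch.no_grad():
        # Unhook the LoRA so this prediction comes from the base model alone.
        self.model.transformer_lora.remove_hook_from_module()
        try:
            regmodel_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
        finally:
            # Re-attach the LoRA even if the prediction raises.
            self.model.transformer_lora.hook_to_module()

    # Prediction with the LoRA active; its target becomes the base model's prediction.
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
    model_output_data['target'] = regmodel_output_data['predicted']
    loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)
    loss *= 1.0  # no-op multiplier, presumably a placeholder for a preservation-loss weight
    print("\nregmodel loss:", loss)
else:
    # Ordinary training batch: the batch's usual target is kept.
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
    loss = self.model_setup.calculate_loss(self.model, batch, model_output_data, self.config)
```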

the idea is that we can set a flag inside the multidatabackend.json for a dataset that contains our regularisation data.

instead of training on this data as we currently do, we will:

  • temporarily disable the lora/lycoris adapter
  • run a prediction using the regularisation data on the (probably quantised) base model network
  • re-enable the lora/lycoris adapter
  • run the same prediction with the adapter active
  • update the loss target from the clean latent to the base model prediction
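The five steps above can be illustrated with a toy, framework-free sketch. Everything in it is hypothetical: `ToyAdapterModel` stands in for a base network plus a LoRA/LyCORIS adapter, the `adapter_enabled` toggle plays the role of unhooking and re-hooking the adapter, and `mse` stands in for the trainer's real loss. It is not the snippet's or any trainer's actual implementation.

```python
from dataclasses import dataclass


@dataclass
class ToyAdapterModel:
    """Hypothetical stand-in for base model + adapter: prediction = (w + delta) * x."""
    w: float                      # frozen base weight
    delta: float = 0.0            # trainable adapter weight (the "LoRA")
    adapter_enabled: bool = True  # toggled instead of unhooking a real adapter

    def predict(self, x: float) -> float:
        weight = self.w + self.delta if self.adapter_enabled else self.w
        return weight * x


def mse(pred: float, target: float) -> float:
    return (pred - target) ** 2


def training_loss(model: ToyAdapterModel, x: float, target: float, preserve: bool) -> float:
    """Ordinary loss, or class-preservation loss when `preserve` is set."""
    if preserve:
        # Disable the adapter and predict with the base weights only.
        # (A real trainer would run this under torch.no_grad().)
        model.adapter_enabled = False
        try:
            base_pred = model.predict(x)
        finally:
            model.adapter_enabled = True  # always re-enable the adapter
        # The base model's prediction replaces the dataset target.
        target = base_pred
    # Prediction with the adapter active, scored against the (possibly swapped) target.
    return mse(model.predict(x), target)


model = ToyAdapterModel(w=2.0, delta=0.5)
print(training_loss(model, x=2.0, target=5.0, preserve=False))  # 0.0
print(training_loss(model, x=2.0, target=5.0, preserve=True))   # 1.0
```

With `delta` at zero the preservation loss is zero. Any change the adapter makes to the output on regularisation inputs is penalised, which is the behaviour the dataset flag is meant to switch on.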

instead of checking for "woman" in the first element's caption, the batch will carry a flag, set somehow from multidatabackend.json, that enables this behaviour.
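One way the flag could travel from the config to the batch is sketched below. The dataset entries and the `preserve_class` key are illustrative assumptions, not existing multidatabackend.json options, and `load_preservation_flags` / `tag_batch` are hypothetical helpers.

```python
import json

# Hypothetical multidatabackend.json content. "preserve_class" is an assumed
# key name used for illustration only; it is not an existing option.
CONFIG_TEXT = """
[
  {"id": "subject-images"},
  {"id": "woman-regularisation", "preserve_class": true}
]
"""


def load_preservation_flags(config_text: str) -> dict:
    """Map each dataset id to whether its batches use the base-model target."""
    backends = json.loads(config_text)
    return {b["id"]: bool(b.get("preserve_class", False)) for b in backends}


def tag_batch(batch: dict, dataset_id: str, flags: dict) -> dict:
    """Return a copy of the batch carrying the flag, so the training step can
    branch on batch["preserve_class"] instead of inspecting caption text."""
    tagged = dict(batch)
    tagged["preserve_class"] = flags.get(dataset_id, False)
    return tagged


flags = load_preservation_flags(CONFIG_TEXT)
reg_batch = tag_batch({"prompt": ["woman"]}, "woman-regularisation", flags)
print(reg_batch["preserve_class"])  # True
```

Datasets without the key default to ordinary training, so existing configs would keep their current behaviour.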

this will run more slowly, since every batch from the regularisation dataset needs two forward passes, but it has the intended effect of maintaining the original model's outputs for those inputs, which substantially helps prevent subject bleed.
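The same idea written as a loss, in notation assumed here rather than taken from the thread: $f_\theta$ is the base network, $f_{\theta,\Delta}$ the same network with the adapter $\Delta$ active, $x_t$ the noised latent at timestep $t$, $c$ the conditioning, $y$ the usual training target, and $\operatorname{sg}[\cdot]$ a stop-gradient.

$$
\mathcal{L}_{\text{train}} = \big\| f_{\theta,\Delta}(x_t, t, c) - y \big\|^2,
\qquad
\mathcal{L}_{\text{preserve}} = \big\| f_{\theta,\Delta}(x_t, t, c) - \operatorname{sg}\!\left[ f_{\theta}(x_t, t, c) \right] \big\|^2
$$

$\mathcal{L}_{\text{preserve}}$ is zero exactly when the adapter leaves the base model's output unchanged on the regularisation inputs, and evaluating $f_\theta(x_t, t, c)$ is the extra forward pass that makes these batches slower.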

note: i don't know who wrote the code snippet, but i would love to credit whoever created it.

example that came with the snippet:

image

requested by a user on the terminus research discord.

bghira added the documentation, work-in-progress and difficult-feature labels Oct 5, 2024
@dxqbYD

dxqbYD commented Oct 6, 2024

I'm the author of this. I'm not yet entirely convinced myself that this is a useful feature. It seems to somewhat limit training of the concept you do want to change ("ohwx woman" in this sample) by insisting that the concept "woman" stays exactly the same during training.

This was an experiment I first ran yesterday, so I have limited test data myself. Training the text encoder (TE), or training additional embeddings, might overcome the issue above by separating the concepts in TE space? I am currently trying embeddings.

Happy to help with your implementation of this!

@AmericanPresidentJimmyCarter

TIPO with random seeds and temperatures can be used to generate random prompts for related concepts. It can turn tags into a natural-language prompt, or a short prompt into a long one.

https://huggingface.co/KBlueLeaf/TIPO-500M

Screenshot_2024-10-06_20-00-26

@AmericanPresidentJimmyCarter

> this was an experiment I first ran yesterday, so I have limited test data myself. Training TE or training additional embeddings might overcome the issue mentioned above by separating the concepts in TE space? I am currently trying embeddings.

There is no need to train the text encoder for Flux models, as the model is partially a large text encoder aligned to image space.

@dxqbYD

dxqbYD commented Oct 8, 2024

> as the model is partially a large text encoder aligned to image space.

Source? More info?

@bghira
Owner Author

bghira commented Oct 8, 2024

that's what MM-DiT is.

@dxqbYD
Copy link

dxqbYD commented Oct 11, 2024

After running some more tests, I now do think this is worth implementing.
It even works well with an empty prompt and no external regularisation image set: just reuse the training dataset and check `if batch['prompt'][0] == "":`

That makes this a feature that needs no extra data, captions or other configuration. Since no prompt is provided, it can potentially preserve multiple classes and whatever you train on.
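A minimal sketch of the empty-prompt variant: the training samples are duplicated with blank captions, and a batch is routed to the preservation path by the same first-prompt check quoted above. The helper names and file names are hypothetical. Because only the first prompt is inspected, each batch has to come entirely from one of the two sets.

```python
def make_empty_prompt_copy(samples: list) -> list:
    """Duplicate the training samples with blank captions. The copy serves as the
    regularisation set, so no separate images or captions are needed."""
    return [{**s, "caption": ""} for s in samples]


def to_batch(samples: list) -> dict:
    """Collate samples into the batch layout the snippet reads from."""
    return {
        "image": [s["image"] for s in samples],
        "prompt": [s["caption"] for s in samples],
    }


def is_preservation_batch(batch: dict) -> bool:
    # Same check as in the comment: an empty first prompt marks the batch.
    return batch["prompt"][0] == ""


training_set = [
    {"image": "img_001.png", "caption": "ohwx woman smiling"},
    {"image": "img_002.png", "caption": "ohwx woman outdoors"},
]
regularisation_set = make_empty_prompt_copy(training_set)

print(is_preservation_batch(to_batch(training_set)))        # False
print(is_preservation_batch(to_batch(regularisation_set)))  # True
```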

@dxqbYD

dxqbYD commented Oct 11, 2024

Branch here for anyone who wants to try it: https://github.com/dxqbYD/OneTrainer/tree/prior_reg
It's the same code as above, though.

3 participants