Add StableDiffusion repaint pipeline #1341


Closed
wants to merge 18 commits into from

Conversation

nathanielherman

No description provided.

@@ -190,6 +190,7 @@ def set_timesteps(

        timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.config.steps_offset
Author

The RePaint scheduler wasn't doing this, but the other schedulers do, so I assume this step is supposed to be here? (It doesn't seem to affect the output much.)
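(For illustration, a minimal numpy sketch of what the offset does to the schedule; the values below — 1000 train timesteps, 50 inference steps, offset 1 — are assumed defaults, not taken from this PR:)

```python
import numpy as np

# Toy sketch of the schedule computed above (assumed values:
# num_train_timesteps=1000, num_inference_steps=50, steps_offset=1;
# this is not the actual scheduler code).
num_train_timesteps = 1000
num_inference_steps = 50
steps_offset = 1

timesteps = np.arange(0, num_inference_steps)[::-1].copy()
timesteps = timesteps * (num_train_timesteps // num_inference_steps)
with_offset = timesteps + steps_offset

print(timesteps[:3])    # [980 960 940]
print(with_offset[:3])  # [981 961 941]
```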

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
# TODO: steps_offset is usually 1, so this effectively cuts the first step out when strength=1.0, is that desired? (for inpaint/img2img)
Author

Is this a bug in inpaint_legacy or intended? (i.e. inpaint_legacy will remove the first step when steps_offset is set to its default of 1)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Contributor

@patrickvonplaten left a comment

I'm actually ok to leave it here given that the code uses a lot of "Copied from" statements - @anton-l what do you think?

@nathanielherman
Author

Bump on this PR! Also @patrickvonplaten wdym by "leave it here"?

Member

@anton-l left a comment

Hi @nathanielherman, the PR looks great already! We've had to make some changes to the Stable Diffusion pipelines last week to accommodate SD 2.0, so we'll need to do some tweaking here as well, hope that's ok :)

Most of the updates will be copied over when you run `python utils/check_copies.py --fix_and_overwrite`, thanks to the `# Copied from` comments!

Also, could you add one integration test similar to https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py#L352 so that we have the same reference? 🙏

prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
num_inference_steps: Optional[int] = 50,
Contributor

Sorry to jump in, but the strength argument is missing here.

Author

Ah, true. I updated it to remove strength entirely since I don't actually use it anymore (RePaint just initializes the latents to random noise).

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

if t >= t_last:
Contributor

And sorry to jump in again, but this line actually causes a bug: in the original RePaint pipeline, t >= t_last is a condition that wraps the main denoise logic. If you instead put the check here, then the first time t >= t_last is satisfied the latents are doubled (for classifier-free guidance) but the unet forward pass that restores their shape is skipped, which causes an error on the next iteration.
An easy way to reproduce this is to set jump_n_sample to 2 or anything larger than 1.

Contributor

Just want to quickly suggest a candidate version of this for loop that avoids the shape-doubling bug mentioned above:

        for i, t in enumerate(self.progress_bar(timesteps)):
            if t < t_last:
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, init_latents_orig, mask, generator).prev_sample

                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
            else:
                # compute the reverse: x_t-1 -> x_t
                latents = self.scheduler.undo_step(latents, t_last, generator)
                
            t_last = t

Author

Ah, I see. I think I can just move the if t >= t_last check two lines earlier, before the latent_model_input = line, to achieve the same effect.
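To make that reordering concrete, here is a toy loop (numpy stand-ins, made-up shapes and jump schedule; none of this is the actual pipeline code) showing that taking the t >= t_last branch before the CFG doubling keeps the batch size stable across undo steps:

```python
import numpy as np

# Toy model of the fixed loop order: the t >= t_last branch is checked
# *before* the classifier-free-guidance doubling, so undo steps never
# see a doubled batch. Shapes and the schedule are illustrative only.
do_classifier_free_guidance = True
latents = np.zeros((1, 4, 8, 8))
timesteps = [3, 2, 1, 2, 1, 0]  # a jump schedule revisits earlier timesteps
t_last = timesteps[0] + 1

for t in timesteps:
    if t < t_last:
        # denoise branch: double for CFG, run the (stand-in) unet, then the
        # scheduler step brings the batch back to size 1
        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
        latents = latent_model_input[:1]  # stand-in for unet + scheduler.step
    else:
        # undo branch (x_t-1 -> x_t): works on `latents` directly, no doubling
        latents = latents.copy()          # stand-in for scheduler.undo_step
    t_last = t

print(latents.shape)  # (1, 4, 8, 8)
```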

@anton-l
Member

anton-l commented Dec 8, 2022

Thanks for catching the issues @Randolph-zeng!

@nathanielherman let me know if you don't have bandwidth this week, I'd be happy to help get the PR ready for merging :)

nathanielherman and others added 8 commits December 8, 2022 10:46
…sion_repaint.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
…sion_repaint.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
…sion_repaint.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
…sion_repaint.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
@nathanielherman
Author

Hey! I did most of the updates and will look at adding an integration test. On that note, AFAICT from https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py#L352, the code is actually loading the non-legacy pipeline rather than the legacy one? I'm a bit confused how that's not breaking, though, since it's initializing with a non-inpainting model ("CompVis/stable-diffusion-v1-4")

@Randolph-zeng
Contributor

Just curious, am I the only one who has experienced the DDIM degradation here? When I use the code in this PR, DDIM almost completely fails to produce any meaningful inpainted image corresponding to the prompt.
@nathanielherman are you troubled by the same issue (#1602), or does it work fine for you? Thanks a lot if you can share some insight :)

@nathanielherman
Author

@Randolph-zeng hmm, I'm confused by the linked issue: do you only get the problem for CFG outside of 6-7, or for any CFG? I only really use the default CFG of 7.5, but for that I get pretty reasonable outputs. (Though I wouldn't say repaint gives obviously better results than the default inpaint_legacy.)

@nathanielherman
Author

@anton-l bump on my question about the unit test; I just want to make sure I'm understanding correctly before I add my own.

@anton-l
Member

anton-l commented Dec 12, 2022

@nathanielherman regarding your question

the code is actually loading the non-legacy pipeline rather than the legacy one

The tests there were carried over from the time when the legacy inpainting pipeline wasn't yet Legacy :) The pipeline loader substitutes the appropriate class for now:

if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
    version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
    from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

    pipeline_class = StableDiffusionInpaintPipelineLegacy

So you can safely assume that those tests are actually using StableDiffusionInpaintPipelineLegacy (we should probably update them, thanks for bringing it up!)
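(For illustration, a minimal sketch of how that version gate behaves, using packaging.version; the version strings below are made up:)

```python
from packaging import version

def resolves_to_legacy(diffusers_version: str) -> bool:
    # Mirrors the gate above: configs saved with diffusers <= 0.5.1
    # get routed to StableDiffusionInpaintPipelineLegacy.
    saved = version.parse(version.parse(diffusers_version).base_version)
    return saved <= version.parse("0.5.1")

print(resolves_to_legacy("0.5.1"))       # True
print(resolves_to_legacy("0.6.0.dev0"))  # False (base_version is 0.6.0)
```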

@nathanielherman
Author

nathanielherman commented Dec 12, 2022

@nathanielherman
Author

@anton-l makes sense! I've added the test and attached the npy file as a comment on this PR — IIUC from the docs, someone would need to upload that npy file and then I can update the test to download it from the url?

@anton-l
Member

anton-l commented Dec 15, 2022

@nathanielherman uploaded it to the repo: https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/red_cat_sitting_on_a_park_bench_repaint.npy
But also feel free to set up a personal repository on the hub to link the files, we can adapt later! :)

Great progress on the PR, let me know if it's ready for the final review!

@nathanielherman
Author

@anton-l perfect, it should be good for final review now!

@patrickvonplaten
Contributor

Gentle ping @anton-l for a final review

):
super().__init__()

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
Contributor

Can we remove all those deprecation messages? We should not add new models with deprecation messages :-)

new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)

if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
Contributor

remove

        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
Contributor

remove

@patrickvonplaten
Contributor

Thanks for the nice PR @nathanielherman!

Three things from my side:

Thanks!

Comment on lines +30 to +32
@slow
@require_torch_gpu
class StableDiffusionRepaintPipelineIntegrationTests(unittest.TestCase):
Member

Now that we're trying to move all of the slow integration tests to nightly runs (reference PR: #1664), this can be moved as well:

Suggested change
@slow
@require_torch_gpu
class StableDiffusionRepaintPipelineIntegrationTests(unittest.TestCase):
@nightly
@require_torch_gpu
class StableDiffusionRepaintPipelineNightlyTests(unittest.TestCase):

Then the tests can be launched locally with `RUN_NIGHTLY=1 pytest <your usual path and args>`


from diffusers import RePaintScheduler, StableDiffusionRepaintPipeline
from diffusers.utils import load_image, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu
Member

@anton-l commented Jan 3, 2023

As Patrick mentioned above, most of the models are now getting covered by common tests from PipelineTesterMixin that check API compatibility, common functionality, etc.
What we need here is just a test class similar to RepaintPipelineFastTests:

class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

with pipeline_class = StableDiffusionRepaintPipeline and slightly adapted get_dummy_components() and get_dummy_inputs(), which you can probably borrow without many changes from StableDiffusionInpaintPipelineFastTests:

Member

These added tests will probably uncover some missing pieces in the pipeline, so feel free to ping us if something is tough to fix! :)

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 26, 2023
@github-actions github-actions bot closed this Mar 7, 2023
@Markus-Pobitzer
Contributor

Good Morning

Thanks for the great work. I am wondering why this pull request has been closed and how one can help.

@anton-l anton-l removed the stale Issues that haven't received updates label Mar 9, 2023
@anton-l anton-l reopened this Mar 9, 2023
@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 2, 2023
@huggingface huggingface deleted a comment from github-actions bot Apr 4, 2023
@vlordier

vlordier commented Apr 9, 2023

@anton-l is there something left we can do to merge this pipeline?

@anton-l
Member

anton-l commented Apr 12, 2023

@vlordier the TODO is mostly just to update the tests as per the comments above (and fix any API issues uncovered by the common tests), and resolve the merge conflicts.

@github-actions
Contributor

github-actions bot commented May 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this May 14, 2023