-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add StableDiffusion repaint pipeline #1341
Changes from all commits
5a478cf
17fd219
a696c14
80737f4
e9890e9
a01b16a
b4dc538
269bcb1
9cb5d44
4cdec74
41833a5
ce924ec
7996688
7f728b0
33e37eb
1fabaf9
3f0ffc6
3984383
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,77 @@ | ||||||||||||||
# coding=utf-8 | ||||||||||||||
# Copyright 2022 HuggingFace Inc. | ||||||||||||||
# | ||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||
# | ||||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||
# | ||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||
# limitations under the License. | ||||||||||||||
|
||||||||||||||
import gc | ||||||||||||||
import unittest | ||||||||||||||
|
||||||||||||||
import numpy as np | ||||||||||||||
import torch | ||||||||||||||
|
||||||||||||||
from diffusers import RePaintScheduler, StableDiffusionRepaintPipeline | ||||||||||||||
from diffusers.utils import load_image, slow, torch_device | ||||||||||||||
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As Patrick mentioned above, most of the models are now getting covered by common tests from
with pipeline_class = StableDiffusionRepaintPipeline and slightly adapted get_dummy_components() and get_dummy_inputs() which you can probably borrow without many changes from StableDiffusionInpaintPipelineFastTests :
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These added tests will probably uncover some missing pieces in the pipeline, so feel free to ping us if something is tough to fix! :) |
||||||||||||||
|
||||||||||||||
|
||||||||||||||
torch.backends.cuda.matmul.allow_tf32 = False | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@slow | ||||||||||||||
@require_torch_gpu | ||||||||||||||
class StableDiffusionRepaintPipelineIntegrationTests(unittest.TestCase): | ||||||||||||||
Comment on lines
+30
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that we're trying to move all of the slow integration tests to nightly runs (reference PR: #1664), this cab be moved as well:
Suggested change
Then the tests can be launched locally with |
||||||||||||||
def tearDown(self): | ||||||||||||||
# clean up the VRAM after each test | ||||||||||||||
super().tearDown() | ||||||||||||||
gc.collect() | ||||||||||||||
torch.cuda.empty_cache() | ||||||||||||||
|
||||||||||||||
def test_stable_diffusion_repaint_pipeline(self): | ||||||||||||||
init_image = load_image( | ||||||||||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" | ||||||||||||||
"/in_paint/overture-creations-5sI6fQgYIuo.png" | ||||||||||||||
) | ||||||||||||||
mask_image = load_image( | ||||||||||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" | ||||||||||||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png" | ||||||||||||||
) | ||||||||||||||
expected_image = load_numpy( | ||||||||||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint" | ||||||||||||||
"/red_cat_sitting_on_a_park_bench_repaint.npy" | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
model_id = "CompVis/stable-diffusion-v1-4" | ||||||||||||||
pipe = StableDiffusionRepaintPipeline.from_pretrained(model_id, safety_checker=None) | ||||||||||||||
pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config) | ||||||||||||||
pipe.to(torch_device) | ||||||||||||||
pipe.set_progress_bar_config(disable=None) | ||||||||||||||
pipe.enable_attention_slicing() | ||||||||||||||
|
||||||||||||||
prompt = "A red cat sitting on a park bench" | ||||||||||||||
|
||||||||||||||
generator = torch.Generator(device=torch_device).manual_seed(0) | ||||||||||||||
output = pipe( | ||||||||||||||
prompt=prompt, | ||||||||||||||
image=init_image, | ||||||||||||||
mask_image=mask_image, | ||||||||||||||
jump_length=3, | ||||||||||||||
jump_n_sample=3, | ||||||||||||||
num_inference_steps=50, | ||||||||||||||
guidance_scale=7.5, | ||||||||||||||
generator=generator, | ||||||||||||||
output_type="np", | ||||||||||||||
) | ||||||||||||||
image = output.images[0] | ||||||||||||||
|
||||||||||||||
assert image.shape == (512, 512, 3) | ||||||||||||||
assert np.abs(expected_image - image).max() < 1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
repaint scheduler wasn't doing this but other schedulers do, I assume this step is supposed to be here? (it doesn't seem to affect output much)