Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修复paddlenlp develop版本适配错误_10-11 #735

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ppdiffusers/deploy/controlnet/scripts/export.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

export USE_PPXFORMERS=False
export FLAGS_set_to_1d=1
python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --controlnet_pretrained_model_name_or_path lllyasviel/sd-controlnet-canny --output_path static_model/stable-diffusion-v1-5-canny
python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --controlnet_pretrained_model_name_or_path lllyasviel/sd-controlnet-canny --output_path static_model/stable-diffusion-v1-5-canny --width 512 --height 512
2 changes: 1 addition & 1 deletion ppdiffusers/deploy/controlnet/scripts/tune_and_tensorrt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
export USE_PPXFORMERS=False
export FLAGS_set_to_1d=1
# 1. export the model to static_model.
python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --controlnet_pretrained_model_name_or_path lllyasviel/sd-controlnet-canny --output_path static_model/stable-diffusion-v1-5-canny
python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --controlnet_pretrained_model_name_or_path lllyasviel/sd-controlnet-canny --output_path static_model/stable-diffusion-v1-5-canny --width 512 --height 512

# 2. tune the shapes of the model for tensorrt
python infer.py --model_dir static_model/stable-diffusion-v1-5-canny/ --scheduler "ddim" --backend paddle --device gpu --task_name all --width 512 --height 512 --inference_steps 50 --tune True --use_fp16 False
Expand Down
25 changes: 13 additions & 12 deletions ppdiffusers/ppdiffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,20 @@ def forward(self, latent):
latent = self.norm(latent)

# Interpolate or crop positional embeddings as needed
if self.pos_embed_max_size:
pos_embed = self.cropped_pos_embed(height, width)
else:
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = paddle.to_tensor(pos_embed).astype(paddle.float32).unsqueeze(0)
if self.add_pos_embed:
if self.pos_embed_max_size:
pos_embed = self.cropped_pos_embed(height, width)
else:
pos_embed = self.pos_embed
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = paddle.to_tensor(pos_embed).astype(paddle.float32).unsqueeze(0)
else:
pos_embed = self.pos_embed

# NOTE: newly added for unidiffusers.
if self.add_pos_embed:
Expand Down
12 changes: 9 additions & 3 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def forward(
else:
batch, height, width, _ = hidden_states.shape
residual = hidden_states

shape = paddle.shape(hidden_states)
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
Expand Down Expand Up @@ -441,7 +441,10 @@ def custom_forward(*inputs):
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape([batch, height, width, self.inner_dim])
if self.data_format == "NCHW":
hidden_states = hidden_states.reshape([shape[0], shape[2], shape[3], self.inner_dim])
else:
hidden_states = hidden_states.reshape([shape[0], shape[1], shape[2], self.inner_dim])
if self.data_format == "NCHW":
hidden_states = hidden_states.transpose([0, 3, 1, 2])
hidden_states = (
Expand All @@ -455,7 +458,10 @@ def custom_forward(*inputs):
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape([batch, height, width, self.inner_dim])
if self.data_format == "NCHW":
hidden_states = hidden_states.reshape([shape[0], shape[2], shape[3], self.inner_dim])
else:
hidden_states = hidden_states.reshape([shape[0], shape[1], shape[2], self.inner_dim])
if self.data_format == "NCHW":
hidden_states = hidden_states.transpose([0, 3, 1, 2])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def generate_beam(
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
is_stopped_tensor_int32 = paddle.cast(~is_stopped, dtype='int32')
seq_lengths += is_stopped_tensor_int32
scores_sum_average = scores_sum / seq_lengths[:, None].cast(scores_sum.dtype)
scores_sum_average, next_tokens = scores_sum_average.reshape([-1]).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
Expand Down
2 changes: 1 addition & 1 deletion ppdiffusers/ppdiffusers/transformers/clip/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CL

return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

def to_dict(self):
def to_dict(self, *args, ** kwargs):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,42 +479,3 @@ def test_xformers_attention_forwardGenerator_pass(self):
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")

enable_full_determinism()


@slow
@require_paddle_gpu
class StableVideoDiffusionPipelineSlowTests(unittest.TestCase):
    """Slow, GPU-only regression test for the Stable Video Diffusion pipeline."""

    def tearDown(self):
        """Free cached memory so consecutive tests do not accumulate VRAM."""
        super().tearDown()
        gc.collect()
        paddle.device.cuda.empty_cache()

    def test_sd_video(self):
        # Build the fp16 image-to-video pipeline from the pretrained weights.
        svd_pipeline = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            variant="fp16",
            paddle_dtype=paddle.float16,
        )
        svd_pipeline.set_progress_bar_config(disable=None)

        # Conditioning image and a fixed seed for the run.
        conditioning_image = load_image(
            "https://paddlenlp.bj.bcebos.com/models/community/hf-internal-testing/diffusers-images/cat_6.png"
        )
        seeded_generator = paddle.Generator().manual_seed(0)
        num_frames = 3

        result = svd_pipeline(
            image=conditioning_image,
            num_frames=num_frames,
            generator=seeded_generator,
            num_inference_steps=3,
            output_type="np",
        )

        # Only the first generated video is inspected.
        video = result.frames[0]
        assert video.shape == (num_frames, 576, 1024, 3)

        # Compare the bottom-right 3x3 patch of the last channel of frame 0
        # against the stored reference values.
        reference_patch = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
        observed_patch = video[0, -3:, -3:, -1]
        distance = numpy_cosine_similarity_distance(observed_patch.flatten(), reference_patch.flatten())
        assert distance < 1e-3
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import gc
import random
import tempfile
import unittest

import numpy as np
import paddle

import ppdiffusers
from ppdiffusers import (
StableVideoDiffusionPipeline,
)
from ppdiffusers.utils import (
is_accelerate_available,
is_accelerate_version,
load_image,
logging,
)
from ppdiffusers.utils.testing_utils import (
paddle_device,
slow,
require_paddle_gpu,
numpy_cosine_similarity_distance,
)

@slow
@require_paddle_gpu
class StableVideoDiffusionPipelineSlowTests(unittest.TestCase):
    """End-to-end slow test of StableVideoDiffusionPipeline on a Paddle GPU."""

    def tearDown(self):
        """Clean up the VRAM after each test."""
        super().tearDown()
        gc.collect()
        paddle.device.cuda.empty_cache()

    def test_sd_video(self):
        # Number of frames requested from the pipeline; also used in the
        # shape assertion below.
        frame_count = 3

        pipeline = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            variant="fp16",
            paddle_dtype=paddle.float16,
        )
        pipeline.set_progress_bar_config(disable=None)

        source_image = load_image(
            "https://paddlenlp.bj.bcebos.com/models/community/hf-internal-testing/diffusers-images/cat_6.png"
        )

        # Seeded generator passed to the pipeline call.
        rng = paddle.Generator().manual_seed(0)

        call_kwargs = dict(
            image=source_image,
            num_frames=frame_count,
            generator=rng,
            num_inference_steps=25,
            output_type="np",
        )
        frames = pipeline(**call_kwargs).frames[0]
        assert frames.shape == (frame_count, 576, 1024, 3)

        # Reference values for the trailing 3x3 patch of the first frame's
        # last channel.
        expected = np.array([0.8592, 0.8645, 0.8499, 0.8722, 0.8769, 0.8421, 0.8557, 0.8528, 0.8285])
        actual = frames[0, -3:, -3:, -1]
        assert numpy_cosine_similarity_distance(actual.flatten(), expected.flatten()) < 1e-3