[PPDiffusers] add photomaker & InstantID model #401

Merged
ppdiffusers/examples/InstantID/attention_processor.py (73 additions, 0 deletions)
@@ -0,0 +1,73 @@
import paddle
from typing import Optional
from ppdiffusers.models.attention_processor import Attention
from ppdiffusers.utils import USE_PEFT_BACKEND

class AttnProcessor(paddle.nn.Layer):

Reviewer (Member): There should be no need to rewrite this here; just reference the processor used for IP-Adapter training, since future versions may change and the __call__ inside will change with them. (A reuse sketch follows the file below.)

r"""
Default processor for performing attention-related computations.
"""

def __call__(
self,
attn: Attention,
hidden_states: paddle.Tensor,
encoder_hidden_states: Optional[paddle.Tensor] = None,
attention_mask: Optional[paddle.Tensor] = None,
temb: Optional[paddle.Tensor] = None,
scale: float = 1.0,
**kwargs,
) -> paddle.Tensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.reshape([batch_size, channel, height * width]).transpose([0, 2, 1])

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = paddle.matmul(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose([0, 2, 1]).reshape([batch_size, channel, height, width])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
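
Following the reviewer's suggestion above, here is a minimal sketch of reusing the processor that already ships with ppdiffusers instead of redefining it in this file. It assumes AttnProcessor is exported from ppdiffusers.models.attention_processor alongside Attention (as in diffusers), and that the ppdiffusers UNet exposes a diffusers-style set_attn_processor method; both are assumptions, not something verified against this PR.

# Reuse the processor shipped with ppdiffusers so that future changes to its
# __call__ are picked up automatically (assumes the same export as diffusers).
from ppdiffusers.models.attention_processor import AttnProcessor

# Hypothetical usage: assign it to every attention layer of a UNet.
# set_attn_processor mirrors the diffusers API and is assumed to exist here.
# unet.set_attn_processor(AttnProcessor())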