support for parallel inference with xfuser #176

Open · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion easyanimate/models/attention.py
@@ -1145,4 +1145,4 @@ def forward(
        norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
        hidden_states = hidden_states + gate_ff * norm_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
        return hidden_states, encoder_hidden_states
        return hidden_states, encoder_hidden_states
53 changes: 47 additions & 6 deletions easyanimate/models/processor.py
@@ -6,6 +6,21 @@
from diffusers.models.embeddings import apply_rotary_emb
from einops import rearrange, repeat

try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment
    )
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None


class HunyuanAttnProcessor2_0:
r"""
@@ -217,7 +232,14 @@ def __call__(

class EasyAnimateAttnProcessor2_0:
    def __init__(self):
        pass
        if xFuserLongContextAttention is not None:
            try:
                get_sequence_parallel_world_size()
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
            except Exception:
                self.hybrid_seq_parallel_attn = None
        else:
            self.hybrid_seq_parallel_attn = None

    def __call__(
        self,
@@ -284,11 +306,30 @@ def __call__(
        if not attn.is_cross_attention:
            key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        if self.hybrid_seq_parallel_attn is None:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states.transpose(1, 2)
        else:
            sp_world_rank = get_sequence_parallel_rank()
            sp_world_size = get_sequence_parallel_world_size()
Collaborator comment: sp_world_size and sp_world_rank are unused here.
            img_q = query[:, :, text_seq_length:].transpose(1, 2)
            txt_q = query[:, :, :text_seq_length].transpose(1, 2)
            img_k = key[:, :, text_seq_length:].transpose(1, 2)
            txt_k = key[:, :, :text_seq_length].transpose(1, 2)
            img_v = value[:, :, text_seq_length:].transpose(1, 2)
            txt_v = value[:, :, :text_seq_length].transpose(1, 2)

            hidden_states = self.hybrid_seq_parallel_attn(
                None,
                img_q, img_k, img_v, dropout_p=0.0, causal=False,
                joint_tensor_query=txt_q,
                joint_tensor_key=txt_k,
                joint_tensor_value=txt_v,
                joint_strategy='front',
            )

        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)

        if attn2 is None:
            # linear proj
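The else: branch above hands the attention computation to xfuser's xFuserLongContextAttention, passing the text tokens as "joint" tensors with joint_strategy='front'. Below is a minimal single-process sketch of the computation it stands in for: it simply prepends the text tokens to the (locally held) image tokens and runs plain scaled_dot_product_attention, mirroring the non-parallel branch. The helper name and shapes are illustrative and are not part of xfuser's API.

import torch
import torch.nn.functional as F

def joint_attention_reference(img_q, img_k, img_v, txt_q, txt_k, txt_v):
    """Illustrative single-GPU stand-in for the hybrid_seq_parallel_attn call.

    All inputs are (batch, seq, heads, head_dim), matching the
    .transpose(1, 2) calls in the processor. The text ("joint") tensors
    are placed in front, mirroring joint_strategy='front'.
    """
    q = torch.cat([txt_q, img_q], dim=1).transpose(1, 2)  # -> (B, heads, seq, head_dim)
    k = torch.cat([txt_k, img_k], dim=1).transpose(1, 2)
    v = torch.cat([txt_v, img_v], dim=1).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # back to (B, seq, heads, head_dim), which the processor then reshapes
    return out.transpose(1, 2)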
64 changes: 63 additions & 1 deletion easyanimate/models/transformer3d.py
@@ -51,6 +51,23 @@
from diffusers.models.embeddings import \
    CaptionProjection as PixArtAlphaTextProjection

try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment
    )
except Exception as ex:
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None


class CLIPProjection(nn.Module):
    """
@@ -1375,6 +1392,14 @@ def __init__(

        self.gradient_checkpointing = False

        try:
            self.sp_world_size = get_sequence_parallel_world_size()
            self.sp_world_rank = get_sequence_parallel_rank()
        except Exception:
            self.sp_world_size = 1
            self.sp_world_rank = 0
            xfuser = None

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

@@ -1399,6 +1424,40 @@ def forward(
    ):
        batch_size, channels, video_length, height, width = hidden_states.size()

        if xfuser is not None and self.sp_world_size > 1:
            if hidden_states.shape[-2] // self.patch_size % self.sp_world_size == 0:
                split_height = height // self.sp_world_size
                split_dim = -2
            elif hidden_states.shape[-1] // self.patch_size % self.sp_world_size == 0:
                split_width = width // self.sp_world_size
                split_dim = -1
            else:
                raise ValueError("Cannot split video sequence into ulysses_degree x ring_degree=%d parts evenly, hidden_states.shape=%s" % (self.sp_world_size, str(hidden_states.shape)))

            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=split_dim)[self.sp_world_rank]
            if inpaint_latents is not None:
                inpaint_latents = torch.chunk(inpaint_latents, self.sp_world_size, dim=split_dim)[self.sp_world_rank]

            if image_rotary_emb is not None:
                embed_dim = image_rotary_emb[0].shape[-1]
                freq_cos = image_rotary_emb[0].reshape(video_length, height // self.patch_size, width // self.patch_size, embed_dim)
                freq_sin = image_rotary_emb[1].reshape(video_length, height // self.patch_size, width // self.patch_size, embed_dim)

                freq_cos = torch.chunk(freq_cos, self.sp_world_size, dim=split_dim - 1)[self.sp_world_rank]
                freq_sin = torch.chunk(freq_sin, self.sp_world_size, dim=split_dim - 1)[self.sp_world_rank]

                freq_cos = freq_cos.reshape(-1, embed_dim)
                freq_sin = freq_sin.reshape(-1, embed_dim)

                image_rotary_emb = (freq_cos, freq_sin)

            if split_dim == -2:
                height = split_height
            elif split_dim == -1:
                width = split_width

        # 1. Time embedding
        temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
        temb = self.time_embedding(temb, timestep_cond)
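A shape-only sketch of why the rotary frequencies above are chunked along split_dim - 1: the latent tensor is laid out as (B, C, T, H, W), while the reshaped frequencies are (T, H/p, W/p, D), so the matching spatial axis sits one position earlier. The sketch also checks that concatenating the per-rank chunks recovers the full tensor, which is what the all_gather at the end of forward relies on. All sizes are illustrative.

import torch

B, C, T, H, W = 1, 16, 8, 32, 48   # illustrative latent shape (batch, channels, frames, height, width)
p, D = 2, 64                       # patch size and rotary embedding dim
sp_world_size = 2

hidden_states = torch.randn(B, C, T, H, W)
freq_cos = torch.randn(T * (H // p) * (W // p), D).reshape(T, H // p, W // p, D)

split_dim = -2  # split along height, as above when (H // p) is divisible by sp_world_size
chunks = torch.chunk(hidden_states, sp_world_size, dim=split_dim)
freq_chunks = torch.chunk(freq_cos, sp_world_size, dim=split_dim - 1)

for rank in range(sp_world_size):
    # each rank keeps H // sp_world_size latent rows and the matching rotary rows
    assert chunks[rank].shape[-2] == H // sp_world_size
    assert freq_chunks[rank].shape[1] == (H // p) // sp_world_size

# concatenating the per-rank chunks recovers the full tensor, which is what
# get_sp_group().all_gather(output, dim=split_dim) does after the final reshape
assert torch.equal(torch.cat(chunks, dim=split_dim), hidden_states)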
@@ -1486,6 +1545,9 @@ def custom_forward(*inputs):
        output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p)
        output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if xfuser is not None and self.sp_world_size > 1:
            output = get_sp_group().all_gather(output, dim=split_dim)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
@@ -1606,4 +1668,4 @@ def from_pretrained_2d(
print(f"### attn1 Parameters: {sum(params) / 1e6} M")

model = model.to(torch_dtype)
return model
return model
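For completeness, a hedged sketch of the distributed setup these changes expect before the model is built: the xfuser initialization helpers imported at the top of processor.py and transformer3d.py have to be called once per process so that get_sequence_parallel_world_size() and get_sequence_parallel_rank() succeed. The keyword names below follow xfuser's published examples and should be verified against the installed version; the launch command (e.g. torchrun --nproc_per_node=2 predict.py) and script name are illustrative.

import torch
import torch.distributed as dist
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Assumption: ulysses_degree * ring_degree must equal the world size, and the
# latent height or width (in patches) must be divisible by it, otherwise
# forward() raises the "Cannot split video sequence ..." ValueError above.
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
    sequence_parallel_degree=dist.get_world_size(),
    ring_degree=1,
    ulysses_degree=dist.get_world_size(),
)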