to_support_animatediff #1009

Merged · 4 commits · Jul 12, 2024
Changes from all commits
@@ -20,9 +20,58 @@


class TemporalTransformer3DModel_OF(TemporalTransformer3DModel_OF_CLS):
    # previous signature (before this change):
    # def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options=None):
    def get_cameractrl_effect(self, hidden_states: torch.Tensor):
        # if no raw camera_ctrl effect is set, return 1.0 (default full strength)
        if self.raw_cameractrl_effect is None:
            return 1.0
        # if raw_cameractrl is not a Tensor, return it directly (should be a float)
        if type(self.raw_cameractrl_effect) != torch.Tensor:
            return self.raw_cameractrl_effect
        shape = hidden_states.shape
        batch, channel, height, width = shape
        # if temp_cameractrl is already calculated, return it
        if self.temp_cameractrl_effect is not None:
            # check if hidden_states batch matches
            if batch == self.prev_cameractrl_hidden_states_batch:
                if self.sub_idxs is not None:
                    return self.temp_cameractrl_effect[:, self.sub_idxs, :]
                return self.temp_cameractrl_effect
            # if it does not match, reset cached temp_cameractrl and recalculate it
            del self.temp_cameractrl_effect
            self.temp_cameractrl_effect = None
        # otherwise, calculate temp_cameractrl
        self.prev_cameractrl_hidden_states_batch = batch
        mask = prepare_mask_batch(self.raw_cameractrl_effect, shape=(self.full_length, 1, height, width))
        mask = repeat_to_batch_size(mask, self.full_length)
        # if mask length does not match full_length, make it match
        if self.full_length != mask.shape[0]:
            mask = broadcast_image_to(mask, self.full_length, 1)
        # reshape mask to attention K shape (h*w, latent_count, 1)
        batch, channel, height, width = mask.shape
        # first, perform same operations as on hidden_states,
        # turning (b, c, h, w) -> (b, h*w, c)
        mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
        # then, make it the same shape as attention's k, (h*w, b, c)
        mask = mask.permute(1, 0, 2)
        # make masks match the expected length of h*w
        batched_number = shape[0] // self.video_length
        if batched_number > 1:
            mask = torch.cat([mask] * batched_number, dim=0)
        # cache mask and set to proper device
        self.temp_cameractrl_effect = mask
        # move temp_cameractrl to proper dtype + device
        self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device)
        # return subset of masks, if needed
        if self.sub_idxs is not None:
            return self.temp_cameractrl_effect[:, self.sub_idxs, :]
        return self.temp_cameractrl_effect


    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options=None, mm_kwargs: dict[str]=None):
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states
        cameractrl_effect = self.get_cameractrl_effect(hidden_states)

        scale_mask = self.get_scale_mask(hidden_states)
        # add some casts for fp8 purposes - does not affect speed otherwise
        hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
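The mask handling in `get_cameractrl_effect` above mirrors what `get_scale_mask` already does for scale masks: flatten the per-frame mask to token order, transpose it so it lines up with the attention keys, and tile it when the latent batch holds several copies of the video (e.g. cond/uncond). A standalone sketch of just that reshape path, with toy shapes that are assumptions for illustration rather than values from this PR:

```python
import torch

# toy dimensions (assumed for illustration only)
frames, height, width = 16, 4, 6
mask = torch.rand(frames, 1, height, width)                          # (b, c, h, w)
mask = mask.permute(0, 2, 3, 1).reshape(frames, height * width, 1)   # (b, h*w, c)
mask = mask.permute(1, 0, 2)                                         # (h*w, b, c)
batched_number = 2                                                   # e.g. cond + uncond in one latent batch
mask = torch.cat([mask] * batched_number, dim=0)
print(mask.shape)  # torch.Size([48, 16, 1])
```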
@@ -41,7 +90,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
                attention_mask=attention_mask,
                video_length=self.video_length,
                scale_mask=scale_mask,
                cameractrl_effect=cameractrl_effect,
                view_options=view_options,
                mm_kwargs=mm_kwargs
            )

        # output
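The call-site change above only threads the two new values through to each transformer block; nothing is computed at this level. A minimal sketch of that pass-through pattern, using hypothetical toy modules (the real classes are the patched AnimateDiff blocks, not these):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # stand-in for a temporal transformer block that accepts the extra kwargs
    def forward(self, hidden_states, cameractrl_effect=1.0, mm_kwargs=None):
        if mm_kwargs is not None and "camera_feature" in mm_kwargs:
            # weight the extra conditioning by the effect strength
            hidden_states = hidden_states + cameractrl_effect * mm_kwargs["camera_feature"]
        return hidden_states

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

    def forward(self, hidden_states, mm_kwargs=None):
        effect = 0.5  # in the real module this comes from get_cameractrl_effect()
        for block in self.blocks:
            hidden_states = block(hidden_states, cameractrl_effect=effect, mm_kwargs=mm_kwargs)
        return hidden_states

x = torch.randn(2, 16, 320)
out = ToyTransformer()(x, mm_kwargs={"camera_feature": torch.randn(2, 16, 320)})
```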
@@ -67,6 +118,8 @@ def forward(
        attention_mask=None,
        video_length=None,
        scale_mask=None,
        cameractrl_effect=1.0,
        mm_kwargs: dict[str]={},
    ):
        if self.attention_mode != "Temporal":
            raise NotImplementedError
@@ -89,6 +142,9 @@
            if encoder_hidden_states is not None
            else encoder_hidden_states
        )
        if self.camera_feature_enabled and self.qkv_merge is not None and mm_kwargs is not None and "camera_feature" in mm_kwargs:
            camera_feature: torch.Tensor = mm_kwargs["camera_feature"]
            hidden_states = (self.qkv_merge(hidden_states + camera_feature) + hidden_states) * cameractrl_effect + hidden_states * (1. - cameractrl_effect)

        # hidden_states = super().forward(
        #     hidden_states,
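The functional core of this PR is the last hunk: when a `camera_feature` is supplied via `mm_kwargs`, the hidden states are passed through `qkv_merge` together with that feature, and the result is blended with the untouched states by `cameractrl_effect`, so an effect of 0 leaves the module's behaviour unchanged and 1 applies the merge fully. A small self-contained check of that blend (the `nn.Linear` here is only a stand-in for the real learned `qkv_merge` layer, and the shapes are assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = torch.randn(2, 16, 320)     # (batch*h*w, frames, channels) after the temporal rearrange
camera = torch.randn(2, 16, 320)     # camera_feature, assumed to match hidden's shape
qkv_merge = nn.Linear(320, 320)      # stand-in for the learned merge projection

def blend(effect: float) -> torch.Tensor:
    merged = qkv_merge(hidden + camera) + hidden
    return merged * effect + hidden * (1.0 - effect)

assert torch.allclose(blend(0.0), hidden)                                # effect 0: identity
assert torch.allclose(blend(1.0), qkv_merge(hidden + camera) + hidden)   # effect 1: full merge
```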