diff --git a/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py b/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py index 57e9c3d65..f79864ed7 100644 --- a/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py @@ -366,7 +366,6 @@ def forward( """ # 1. Input batch_frames, _, height, width = hidden_states.shape - hidden_states_in = hidden_states num_frames = image_only_indicator.shape[-1] batch_size = batch_frames // num_frames @@ -382,9 +381,11 @@ def forward( # height * width, batch_size, 1, time_context.shape[-1] # ) # Rewrite for onediff SVD dynamic shape - broadcast_tensor = hidden_states.flatten(2, 3).permute(2, 0, 1) time_context = torch._C.broadcast_dim_like( - time_context_first_timestep[None, :], broadcast_tensor, dim=0 + time_context_first_timestep[None, :], + hidden_states.flatten(2, 3), + dim=0, + like_dim=2, ) # time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) # Rewrite for onediff SVD dynamic shape @@ -450,7 +451,9 @@ def forward( # hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() # Rewrite for onediff SVD dynamic shape hidden_states = ( - hidden_states.permute(0, 2, 1).reshape_as(hidden_states_in).contiguous() + hidden_states.reshape_as(residual.permute(0, 2, 3, 1)) + .permute(0, 3, 1, 2) + .contiguous() ) output = hidden_states + residual