 # See the License for the specific language governing permissions and
 # limitations under the License.

-import itertools
 import math
 from typing import List, Optional, Tuple


 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ..attention_dispatch import dispatch_attention_fn
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...utils.import_utils import is_flash_attn_available
 from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention_dispatch import dispatch_attention_fn


 if is_flash_attn_available():
@@ -99,7 +98,6 @@ def __call__(
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
@@ -586,11 +584,7 @@ def forward(
             dtype=x_freqs_cis[0].dtype,
             device=device,
         )
-        x_attn_mask = torch.ones(
-            (bsz, x_max_item_seqlen),
-            dtype=torch.bool,
-            device=device
-        )
+        x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
             seq_len = x_item_seqlens[i]
             pad_len = x_max_item_seqlen - seq_len
@@ -629,11 +623,7 @@ def forward(
             dtype=cap_freqs_cis[0].dtype,
             device=device,
         )
-        cap_attn_mask = torch.ones(
-            (bsz, cap_max_item_seqlen),
-            dtype=torch.bool,
-            device=device
-        )
+        cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
             seq_len = cap_item_seqlens[i]
             pad_len = cap_max_item_seqlen - seq_len
@@ -672,11 +662,7 @@ def forward(
             dtype=unified_freqs_cis[0].dtype,
             device=device,
         )
-        unified_attn_mask = torch.ones(
-            (bsz, unified_max_item_seqlen),
-            dtype=torch.bool,
-            device=device
-        )
+        unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
         for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
             seq_len = unified_item_seqlens[i]
             pad_len = unified_max_item_seqlen - seq_len
@@ -694,9 +680,7 @@ def forward(
                 adaln_input,
             )

-        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
-            unified, adaln_input
-        )
+        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

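Note: the hunks above repeatedly build a right-padded boolean attention mask for variable-length items. Below is a minimal standalone sketch of that pattern with hypothetical sizes and names (not taken from this diff), assuming the padded tail of each row is later marked False so attention ignores those positions.

import torch

bsz = 2                                # hypothetical batch size
item_seqlens = [5, 3]                  # hypothetical true length of each item
max_item_seqlen = max(item_seqlens)    # items are right-padded to this length

# Start from an all-True (bsz, max_item_seqlen) mask, as in the diff.
attn_mask = torch.ones((bsz, max_item_seqlen), dtype=torch.bool)

# One common completion of the pattern: clear the padded tail of each row.
for i, seq_len in enumerate(item_seqlens):
    attn_mask[i, seq_len:] = False

print(attn_mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])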