@@ -617,7 +617,7 @@ def dtype(self) -> torch.dtype:
617617 def device (self ) -> torch .device :
618618 return self .patch_embed .patchifier .proj .weight .device
619619
620- def get_pos_ids_by_grid (self , grid_thw ) :
620+ def get_pos_ids_by_grid (self , grid_thw : list [ list [ int ]]) -> list [ torch . Tensor ] :
621621 pos_ids = []
622622 for t , h , w in grid_thw :
623623 hpos_ids = torch .arange (h ).unsqueeze (1 ).expand (- 1 , w )
@@ -643,10 +643,10 @@ def get_pos_ids_by_grid(self, grid_thw):
643643
644644 return pos_ids
645645
646- def rot_pos_emb (self , grid_thw ) :
646+ def rot_pos_emb (self , grid_thw : list [ list [ int ]]) -> torch . Tensor :
647647 pos_ids = self .get_pos_ids_by_grid (grid_thw )
648648 pos_ids = torch .cat (pos_ids , dim = 0 )
649- max_grid_size = grid_thw [:, 1 :]. max ()
649+ max_grid_size = max (max ( h , w ) for _ , h , w in grid_thw )
650650 rotary_pos_emb_full = self .rotary_pos_emb (max_grid_size )
651651 rotary_pos_emb = rotary_pos_emb_full [pos_ids ].flatten (1 )
652652 return rotary_pos_emb
@@ -667,13 +667,13 @@ def compute_attn_mask_seqlen(
667667 def forward (
668668 self , hidden_states : torch .Tensor , grid_thw : list [list [int ]]
669669 ) -> torch .Tensor :
670+ rotary_pos_emb = self .rot_pos_emb (grid_thw )
671+
670672 # Convert grid_thw to tensor (always expecting list format now)
671673 grid_thw = torch .tensor (grid_thw , device = hidden_states .device , dtype = torch .long )
672674 hidden_states = hidden_states .to (self .dtype )
673675 hidden_states = self .patch_embed (hidden_states , grid_thw )
674676
675- rotary_pos_emb = self .rot_pos_emb (grid_thw )
676-
677677 cu_seqlens = torch .repeat_interleave (
678678 grid_thw [:, 1 ] * grid_thw [:, 2 ], grid_thw [:, 0 ]
679679 ).cumsum (
@@ -807,7 +807,7 @@ def _process_image_input(
807807 rope_type = "rope_3d" ,
808808 )
809809 else :
810- image_embeds = self .vision_tower (pixel_values , grid_thw )[
810+ image_embeds = self .vision_tower (pixel_values , grid_thw_list )[
811811 :, : self .config .hidden_size
812812 ]
813813
0 commit comments