@@ -1209,6 +1209,7 @@ def _omni_get_input_positions_tensor(
12091209        video_token_id  =  thinker_config .video_token_index 
12101210        audio_start_token_id  =  thinker_config .audio_start_token_id 
12111211        audio_end_token_id  =  thinker_config .audio_end_token_id 
1212+         vision_start_token_id  =  thinker_config .vision_start_token_id 
12121213        vision_end_token_id  =  thinker_config .vision_end_token_id 
12131214        seconds_per_chunk  =  thinker_config .seconds_per_chunk 
12141215        spatial_merge_size  =  thinker_config .vision_config .spatial_merge_size 
@@ -1238,8 +1239,15 @@ def _omni_get_input_positions_tensor(
12381239            if  src_item [idx ] not  in 
12391240                    audio_token_id , video_token_id , image_token_id 
12401241            ]:
1241-                 if  src_item [idx ] ==  vision_end_token_id  and  use_audio_in_video :
1242-                     start_idx  -=  1 
1242+                 if  use_audio_in_video  and  idx  >  0 :
1243+                     if  src_item [idx ] ==  vision_end_token_id  and  \
1244+                         src_item [idx  -  1 ] ==  audio_end_token_id :
1245+                         # processing the <|audio_eos|> before <|vision_eos|> 
1246+                         start_idx  -=  1 
1247+                     elif  src_item [idx ] ==  audio_start_token_id  and  \
1248+                         src_item [idx  -  1 ] ==  vision_start_token_id :
1249+                         # processing the <|audio_bos|> after <|vision_bos|> 
1250+                         start_idx  -=  1 
12431251                new_src_item .append (src_item [idx ])
12441252                llm_pos_ids  =  torch .tensor ([start_idx ],
12451253                                           dtype = torch .long ).expand (3 , - 1 )
@@ -1297,11 +1305,6 @@ def _omni_get_input_positions_tensor(
12971305                           tokens_per_second ).long ()
12981306                t_index_split_chunk  =  cls ._split_list_into_ranges (
12991307                    t_index , t_ntoken_per_chunk )
1300-                 new_src_item .extend ([audio_start_token_id ])
1301-                 start_idx  -=  1 
1302-                 llm_pos_ids_list .extend ([
1303-                     torch .tensor ([start_idx ], dtype = torch .long ).expand (3 , - 1 )
1304-                 ] *  1 )
13051308                place_num  =  (((audio_seqlen  -  1 ) //  2  +  1  -  2 ) //  2  +  1 ) +  2 
13061309                pure_audio_len  =  place_num  -  2 
13071310                added_audio_len  =  0 
@@ -1312,21 +1315,21 @@ def _omni_get_input_positions_tensor(
13121315                    new_src_item .extend ([video_token_id ] * 
13131316                                        vision_ntoken_per_chunk )
13141317                    vision_llm_pos_ids_list  =  cls ._get_llm_pos_ids_for_vision (
1315-                         start_idx   +   1 , video_idx , spatial_merge_size , t_chunk ,
1318+                         start_idx , video_idx , spatial_merge_size , t_chunk ,
13161319                        grid_hs , grid_ws ).split (1 , dim = 1 )
13171320                    llm_pos_ids_list .extend (vision_llm_pos_ids_list )
13181321                    new_src_item .extend (
13191322                        min (t_ntoken_per_chunk , pure_audio_len  - 
13201323                            added_audio_len ) *  [audio_token_id ])
13211324                    audio_start_idx  =  start_idx  if  len (
13221325                        audio_llm_pos_ids_list 
1323-                     ) ==  0  else  audio_llm_pos_ids_list [- 1 ][0 ].item ()
1326+                     ) ==  0  else  audio_llm_pos_ids_list [- 1 ][0 ].item ()  +   1 
13241327                    if  min (t_ntoken_per_chunk ,
13251328                           pure_audio_len  -  added_audio_len ) >  0 :
13261329                        audio_llm_pos_ids_list  =  (torch .arange (
13271330                            min (t_ntoken_per_chunk , pure_audio_len  - 
13281331                                added_audio_len )).expand (3 , - 1 ) + 
1329-                                                   audio_start_idx   +   1 ).split (
1332+                                                   audio_start_idx ).split (
13301333                                                      1 , dim = 1 )
13311334                    else :
13321335                        audio_llm_pos_ids_list  =  []
@@ -1341,11 +1344,6 @@ def _omni_get_input_positions_tensor(
13411344                            3 , - 1 ) +  llm_pos_ids_list [- 1 ].max () +  1 ).split (
13421345                                1 , dim = 1 )
13431346                    llm_pos_ids_list .extend (audio_llm_pos_ids_list )
1344-                 llm_pos_ids_list .extend ([
1345-                     torch .tensor (
1346-                         [llm_pos_ids_list [- 1 ].max () +  1 ] *  3 ).unsqueeze (1 )
1347-                 ] *  1 )
1348-                 new_src_item .extend ([audio_end_token_id ])
13491347                audio_idx  +=  1 
13501348                video_idx  +=  1 
13511349            # move to the next token 