@@ -1209,6 +1209,7 @@ def _omni_get_input_positions_tensor(
         video_token_id = thinker_config.video_token_index
         audio_start_token_id = thinker_config.audio_start_token_id
         audio_end_token_id = thinker_config.audio_end_token_id
+        vision_start_token_id = thinker_config.vision_start_token_id
         vision_end_token_id = thinker_config.vision_end_token_id
         seconds_per_chunk = thinker_config.seconds_per_chunk
         spatial_merge_size = thinker_config.vision_config.spatial_merge_size
@@ -1238,8 +1239,15 @@ def _omni_get_input_positions_tensor(
             if src_item[idx] not in [
                     audio_token_id, video_token_id, image_token_id
             ]:
-                if src_item[idx] == vision_end_token_id and use_audio_in_video:
-                    start_idx -= 1
+                if use_audio_in_video and idx > 0:
+                    if src_item[idx] == vision_end_token_id and \
+                            src_item[idx - 1] == audio_end_token_id:
+                        # processing the <|audio_eos|> before <|vision_eos|>
+                        start_idx -= 1
+                    elif src_item[idx] == audio_start_token_id and \
+                            src_item[idx - 1] == vision_start_token_id:
+                        # processing the <|audio_bos|> after <|vision_bos|>
+                        start_idx -= 1
                 new_src_item.append(src_item[idx])
                 llm_pos_ids = torch.tensor([start_idx],
                                            dtype=torch.long).expand(3, -1)
@@ -1297,11 +1305,6 @@ def _omni_get_input_positions_tensor(
                     tokens_per_second).long()
                 t_index_split_chunk = cls._split_list_into_ranges(
                     t_index, t_ntoken_per_chunk)
-                new_src_item.extend([audio_start_token_id])
-                start_idx -= 1
-                llm_pos_ids_list.extend([
-                    torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
-                ] * 1)
                 place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                 pure_audio_len = place_num - 2
                 added_audio_len = 0
@@ -1312,21 +1315,21 @@ def _omni_get_input_positions_tensor(
                     new_src_item.extend([video_token_id] *
                                         vision_ntoken_per_chunk)
                     vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
-                        start_idx + 1, video_idx, spatial_merge_size, t_chunk,
+                        start_idx, video_idx, spatial_merge_size, t_chunk,
                         grid_hs, grid_ws).split(1, dim=1)
                     llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                     new_src_item.extend(
                         min(t_ntoken_per_chunk, pure_audio_len -
                             added_audio_len) * [audio_token_id])
                     audio_start_idx = start_idx if len(
                         audio_llm_pos_ids_list
-                    ) == 0 else audio_llm_pos_ids_list[-1][0].item()
+                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                     if min(t_ntoken_per_chunk,
                            pure_audio_len - added_audio_len) > 0:
                         audio_llm_pos_ids_list = (torch.arange(
                             min(t_ntoken_per_chunk, pure_audio_len -
                                 added_audio_len)).expand(3, -1) +
-                                                  audio_start_idx + 1).split(
+                                                  audio_start_idx).split(
                                                       1, dim=1)
                     else:
                         audio_llm_pos_ids_list = []
@@ -1341,11 +1344,6 @@ def _omni_get_input_positions_tensor(
                             3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                 1, dim=1)
                     llm_pos_ids_list.extend(audio_llm_pos_ids_list)
-                llm_pos_ids_list.extend([
-                    torch.tensor(
-                        [llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
-                ] * 1)
-                new_src_item.extend([audio_end_token_id])
                 audio_idx += 1
                 video_idx += 1
             # move to the next token
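
Note on the boundary handling in the second hunk (commentary, not part of the
patch): with use_audio_in_video the thinker interleaves the audio stream into
the video span, so the delimiters arrive as
<|vision_bos|><|audio_bos|> ... <|audio_eos|><|vision_eos|>. The two
start_idx -= 1 branches let each audio delimiter share one position index with
its adjacent vision delimiter instead of advancing the counter twice. A
minimal single-row sketch of that rule (the real code tracks 3-row M-RoPE
positions; the ids below are hypothetical stand-ins for the values read from
thinker_config):

    # Hypothetical ids; the real ones come from thinker_config.
    VISION_BOS, AUDIO_BOS, AUDIO_EOS, VISION_EOS = 1, 2, 3, 4

    def boundary_positions(tokens: list[int]) -> list[int]:
        """Assign one position per token, letting paired delimiters
        share a position the way the patch does."""
        positions, next_pos = [], 0
        for idx, tok in enumerate(tokens):
            start = next_pos
            if idx > 0 and (
                    (tok == VISION_EOS and tokens[idx - 1] == AUDIO_EOS) or
                    (tok == AUDIO_BOS and tokens[idx - 1] == VISION_BOS)):
                start -= 1  # reuse the previous delimiter's position
            positions.append(start)
            next_pos = start + 1
        return positions

    # <|vision_bos|><|audio_bos|><|audio_eos|><|vision_eos|> -> [0, 0, 1, 1]
    print(boundary_positions([VISION_BOS, AUDIO_BOS, AUDIO_EOS, VISION_EOS]))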
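Why the third and fifth hunks only delete code (commentary, not part of the
patch): the video-with-audio branch previously emitted <|audio_bos|> and
<|audio_eos|> itself with hand-offset positions; those tokens now fall through
to the generic non-placeholder branch above, where the shared-position rule
handles them, so the explicit emission and the matching "+ 1" offsets in the
chunk loop drop out. The retained place_num context line appears to be two
conv-like halvings of the audio feature length plus the two delimiters; a
small self-check of that arithmetic:

    def audio_place_num(audio_seqlen: int) -> int:
        # Mirrors the retained context line: two halvings, then +2
        # for the <|audio_bos|>/<|audio_eos|> delimiters.
        return (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2

    assert audio_place_num(100) == 27  # pure_audio_len = 27 - 2 = 25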