Skip to content

Commit

Permalink
fix bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
faychu committed Oct 18, 2024
1 parent 2eb71ac commit 5dff6ea
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
5 changes: 3 additions & 2 deletions examples/offline_inference_audio_language.py
Original file line number Diff line number Diff line change
@@ -51,10 +51,11 @@ def run_qwen2_audio(question, audio_count):
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})

audio_in_prompt = "".join([f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)])

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
f"{question}<|im_end|>\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
@@ -92,13 +92,13 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,

audio_token_index = ctx.model_config.hf_config.audio_token_index

dummy_seqdata = SequenceData.from_token_counts(
dummy_seqdata = SequenceData.from_prompt_token_counts(
(audio_token_index, max_llm_audio_tokens),
(0, seq_len - max_llm_audio_tokens),
)
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return dummy_seqdata, {
"audio": dummy_audio if num_audios == 1 else [dummy_audio] * num_audios
"audio": [(dummy_audio, 16000)] * num_audios
}


@@ -165,11 +165,12 @@ def input_processor_for_qwen2_audio(ctx: InputContext,
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
if len(multi_modal_data["audio"]) == 0:
return llm_inputs
assert isinstance(multi_modal_data['audio'], list) and isinstance(multi_modal_data['audio'][0], tuple)

audios = multi_modal_data['audio']
processor = cached_get_processor(ctx.model_config.model)
if len(audios) == 0:
return llm_inputs
audios = [_[0] for _ in multi_modal_data['audio']]

audio_inputs = processor.feature_extractor(audios,
sampling_rate=16000,
@@ -227,14 +228,15 @@ def input_mapper_for_qwen2_audio(
}
return batch_data
try:
batch_data = audio_feature_extractor(multi_modal_data,
audios = [_[0] for _ in multi_modal_data]
batch_data = audio_feature_extractor(audios,
sampling_rate=16000,
return_attention_mask=True,
padding="max_length",
return_tensors="pt").data
batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
except Exception:
logger.error("Failed to process audio (%s)", multi_modal_data)
logger.error("Failed to process audio (%s)", audios)
raise

return MultiModalInputs(batch_data)
Expand Down

0 comments on commit 5dff6ea

Please sign in to comment.