diff --git a/setup.cfg b/setup.cfg index 0eba51f86f..a9a62a798f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -278,13 +278,8 @@ audiolm = soundfile~=0.12 librosa~=0.10 - # For OpenFlamingo - einops~=0.7.0 - einops-exts~=0.0.4 - open-clip-torch~=2.24 - - # For IDEFICS - torch~=2.1 + # For LLaMA-Omni + openai-whisper==20240930 # For Qwen2-Audio transformers~=4.45.1 diff --git a/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py b/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py index 2005417f43..d535fbe19a 100644 --- a/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py @@ -1,6 +1,7 @@ """Scenarios for audio models""" from typing import List +import os from helm.benchmark.scenarios.scenario import ( Scenario, @@ -11,7 +12,11 @@ Input, Output, ) +from tqdm import tqdm +from datasets import load_dataset from helm.common.media_object import MediaObject, MultimediaObject +from helm.common.general import ensure_directory_exists +from helm.common.audio_utils import ensure_audio_file_exists class AudioMNISTScenario(Scenario): @@ -37,28 +42,21 @@ class AudioMNISTScenario(Scenario): } """ # noqa: E501 - NUM_SPEAKERS = 60 - NUM_TRIALS = 50 - WAV_URL_TEMPLATE = r"https://github.com/soerenab/AudioMNIST/raw/544b0f4bc65227e54332e665d5e02c24be6732c2/data/{speaker_id}/{digit}_{speaker_id}_{trial_index}.wav" # noqa: E501 - name = "audio_mnist" description = "Classify an audio sample of a spoken digit ([Becker et al, 2023](https://arxiv.org/abs/1807.03418))." tags = ["audio", "classification"] def get_instances(self, output_path: str) -> List[Instance]: instances: List[Instance] = [] - for digit in range(10): - for speaker_index in range(AudioMNISTScenario.NUM_SPEAKERS): - speaker_id = str(speaker_index).zfill(2) - for trial_index in range(AudioMNISTScenario.NUM_TRIALS): - wav_url = AudioMNISTScenario.WAV_URL_TEMPLATE.format( - digit=digit, speaker_id=speaker_id, trial_index=trial_index - ) - input = Input( - multimedia_content=MultimediaObject([MediaObject(content_type="audio/wav", location=wav_url)]) - ) - references = [Reference(Output(text=str(digit)), tags=[CORRECT_TAG])] - # Don't need train split because we're using zero-shot - instance = Instance(input=input, references=references, split=TEST_SPLIT) - instances.append(instance) + wav_save_dir: str = os.path.join(output_path, "wav_files") + ensure_directory_exists(wav_save_dir) + for row in tqdm(load_dataset("flexthink/audiomnist", cache_dir=output_path, split=TEST_SPLIT)): + local_audio_path = os.path.join(wav_save_dir, row["audio"]["path"]) + audio_array = row["audio"]["array"] + ensure_audio_file_exists(local_audio_path, audio_array, row["audio"]["sampling_rate"]) + input = Input( + multimedia_content=MultimediaObject([MediaObject(content_type="audio/mpeg", location=local_audio_path)]) + ) + references = [Reference(Output(text=str(row["digit"])), tags=[CORRECT_TAG])] + instances.append(Instance(input=input, references=references, split=TEST_SPLIT)) return instances diff --git a/src/helm/clients/audio_language/llama_omni/arguments.py b/src/helm/clients/audio_language/llama_omni/arguments.py new file mode 100644 index 0000000000..0a0aee1650 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/arguments.py @@ -0,0 +1,61 @@ +import transformers + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = 
field(default="facebook/opt-125m")
+    version: Optional[str] = field(default="v0")
+    freeze_backbone: bool = field(default=False)
+    tune_speech_projector: bool = field(default=False)
+    tune_speech_encoder: bool = field(default=False)
+    tune_speech_generator_only: bool = field(default=False)
+    speech_encoder_type: Optional[str] = field(default=None)
+    speech_encoder: Optional[str] = field(default=None)
+    pretrain_speech_projector: Optional[str] = field(default=None)
+    speech_projector_type: Optional[str] = field(default="linear")
+    speech_generator_type: Optional[str] = field(default="ctc")
+    ctc_decoder_config: str = "(2,4096,32,11008)"
+    ctc_upsample_factor: int = 1
+    ctc_loss_weight: float = 1.0
+    unit_vocab_size: int = 1000
+    speech_encoder_ds_rate: int = 5
+    speech_encoder_hidden_size: int = 1280
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(default="", metadata={"help": "Path to the training data."})
+    is_multimodal: bool = False
+    input_type: str = field(default="mel")
+    speech_normalize: bool = False
+    mel_size: int = 128
+    has_tgt_units: bool = False
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    freeze_speech_projector: bool = field(default=False)
+    model_max_length: int = field(
+        default=512,
+        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    double_quant: bool = field(
+        default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
+    quant_type: str = field(
+        default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+    )
+    bits: int = field(default=16, metadata={"help": "How many bits to use."})
+    lora_enable: bool = False
+    lora_r: int = 64
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_weight_path: str = ""
+    lora_bias: str = "none"
+    speech_projector_lr: Optional[float] = None
+    group_by_modality_length: bool = field(default=False)
diff --git a/src/helm/clients/audio_language/llama_omni/constants.py b/src/helm/clients/audio_language/llama_omni/constants.py
new file mode 100644
index 0000000000..7ef790617f
--- /dev/null
+++ b/src/helm/clients/audio_language/llama_omni/constants.py
@@ -0,0 +1,9 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+SPEECH_TOKEN_INDEX = -200
+DEFAULT_SPEECH_TOKEN = "<speech>"
diff --git a/src/helm/clients/audio_language/llama_omni/conversation.py b/src/helm/clients/audio_language/llama_omni/conversation.py
new file mode 100644
index 0000000000..72596e6448
--- /dev/null
+++ b/src/helm/clients/audio_language/llama_omni/conversation.py
@@ -0,0 +1,213 @@
+# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
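+
+# Conversation prompt templates for LLaMA-Omni (adapted from LLaVA); conv_llama_3 below is the default template.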
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Union, Optional
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    TWO = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    LLAMA_3 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
+    sep: str = "###"
+    sep2: str = ""
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Optional[Union[str, List[str]]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: Optional[List[int]] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.LLAMA_3:
+            wrap_sys = lambda msg: (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
+            )
+            ret = "<|begin_of_text|>" + wrap_sys(self.system)
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += message.strip() + self.sep2
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    if i == 0:
+                        message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message = message[0]
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg = msg[0]
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version,
+        )
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v1 = Conversation(
+    system="A chat between a curious user and an artificial 
intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=["USER", "ASSISTANT"], + version="v1", + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="You are a helpful language and speech assistant. " + "You are able to understand the speech content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=["USER", "ASSISTANT"], + version="llama_v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llama_3 = Conversation( + system="You are a helpful language and speech assistant. " + "You are able to understand the speech content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=["user", "assistant"], + version="llama_v3", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="", + sep2="<|eot_id|>", +) + +conv_plain = Conversation( + system="", + roles=["", ""], + messages=[], + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + + +default_conversation = conv_llama_3 +conv_templates = { + "v1": conv_vicuna_v1, + "plain": conv_plain, + "llama_2": conv_llama_2, + "llama_3": conv_llama_3, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/src/helm/clients/audio_language/llama_omni/model/__init__.py b/src/helm/clients/audio_language/llama_omni/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/helm/clients/audio_language/llama_omni/model/builder.py b/src/helm/clients/audio_language/llama_omni/model/builder.py new file mode 100644 index 0000000000..ddff036006 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/builder.py @@ -0,0 +1,88 @@ +import os + +from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig +import torch +from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM +from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM +from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder + + +def load_pretrained_model( + model_path, + model_base, + is_lora=False, + s2s=False, + load_8bit=False, + load_4bit=False, + device="cuda", + use_flash_attn=False, + **kwargs +): + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + else: + kwargs["torch_dtype"] = torch.float16 + + if use_flash_attn: + kwargs["attn_implementation"] = "flash_attention_2" + + model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM + + # Load OmniSpeech model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." 
+ from language_model.omni_speech_llama import OmniSpeechConfig + + lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print("Loading OmniSpeech from base model...") + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) + print("Loading additional OmniSpeech weights...") + if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu") + non_lora_trainables = { + (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items() + } + if any(k.startswith("model.model.") for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + + print("Loading LoRA weights...") + model = PeftModel.from_pretrained(model, model_path) + print("Merging LoRA weights...") + model = model.merge_and_unload() + print("Model is loaded...") + elif model_base is not None: + print("Loading OmniSpeech from base model...") + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) + + speech_projector_weights = torch.load(os.path.join(model_path, "speech_projector.bin"), map_location="cpu") + speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} + model.load_state_dict(speech_projector_weights, strict=False) + model = model.to(device=device) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=False, **kwargs) + model = model.to(device=device) + + model.get_model().speech_encoder = build_speech_encoder(model.config) + model.get_model().speech_encoder.to(device=device, dtype=torch.float16) + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, context_len diff --git a/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py b/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py new file mode 100644 index 0000000000..c1e5471da2 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech2s_llama.py @@ -0,0 +1,190 @@ +from typing import List, Optional, Tuple, Union, Callable + +import torch + +from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig + +from transformers import PreTrainedModel +from transformers.generation.streamers import BaseStreamer +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import ( + GenerationConfig, + LogitsProcessorList, + StoppingCriteriaList, +) + +from helm.clients.audio_language.llama_omni.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM +from helm.clients.audio_language.llama_omni.model.speech_generator.builder import build_speech_generator +from helm.clients.audio_language.llama_omni.model.speech_generator.generation import GenerationWithCTC + + +class OmniSpeech2SConfig(LlamaConfig): + model_type = "omni_speech2s_llama" + + +class 
OmniSpeech2SLlamaForCausalLM(OmniSpeechLlamaForCausalLM, GenerationWithCTC): + config_class = OmniSpeech2SConfig + + def __init__(self, config): + super().__init__(config) + + # Initialize weights and apply final processing + self.post_init() + if hasattr(config, "speech_generator_type"): + self.speech_generator = build_speech_generator(config) + + def initialize_speech_generator(self, model_args): + self.config.speech_generator_type = getattr(model_args, "speech_generator_type", "ctc") + self.config.ctc_decoder_config = getattr(model_args, "ctc_decoder_config", "(4,4096,32,11008)") + self.config.ctc_upsample_factor = getattr(model_args, "ctc_upsample_factor", 1) + self.config.ctc_loss_weight = getattr(model_args, "ctc_loss_weight", 1.0) + self.config.unit_vocab_size = getattr(model_args, "unit_vocab_size", 1000) + self.tune_speech_generator_only = getattr(model_args, "tune_speech_generator_only", False) + if getattr(self, "speech_generator", None) is None: + self.speech_generator = build_speech_generator(self.config) + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + tgt_units: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = ( + self.prepare_inputs_labels_for_speech_and_text( + input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths + ) + ) + + if self.training: + if self.tune_speech_generator_only: + with torch.no_grad(): + llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units) + else: + llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + lm_loss = llama_output.loss + ctc_loss = self.speech_generator(llama_output["hidden_states"][-1], labels, tgt_units) + loss = lm_loss + ctc_loss * self.config.ctc_loss_weight + else: + llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + loss = llama_output.loss + + return CausalLMOutputWithPast( + loss=loss, + 
logits=llama_output.logits, + past_key_values=llama_output.past_key_values, + hidden_states=llama_output.hidden_states, + attentions=llama_output.attentions, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + streamer_unit: Optional["BaseStreamer"] = None, + streaming_unit_gen=False, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if speech is not None: + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( + self.prepare_inputs_labels_for_speech_and_text( + inputs, position_ids, attention_mask, None, None, speech, speech_lengths + ) + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + outputs = GenerationWithCTC.generate( + self, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=True, + return_dict_in_generate=True, + streaming_unit_gen=streaming_unit_gen, + **kwargs, + ) + + hidden_states = outputs["hidden_states"] + hidden_states = torch.cat( + [hidden_states[0][-1][:, -1:, :]] + [hidden_states[i][-1] for i in range(1, len(hidden_states))], dim=1 + ) + ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0)) + + return outputs.sequences, ctc_pred + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if speech is not None: + inputs["speech"] = speech + inputs["speech_lengths"] = speech_lengths + return inputs + + +AutoConfig.register("omni_speech2s_llama", OmniSpeech2SConfig) +AutoModelForCausalLM.register(OmniSpeech2SConfig, OmniSpeech2SLlamaForCausalLM) diff --git a/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py b/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py new file mode 100644 index 0000000000..2bd5e60601 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/language_model/omni_speech_llama.py @@ -0,0 +1,118 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from helm.clients.audio_language.llama_omni.model.omni_speech_arch import OmniSpeechMetaModel, OmniSpeechMetaForCausalLM + + +class OmniSpeechConfig(LlamaConfig): + model_type = "omni_speech_llama" + + +class OmniSpeechLlamaModel(OmniSpeechMetaModel, LlamaModel): + config_class = 
OmniSpeechConfig + + def __init__(self, config: LlamaConfig): + super(OmniSpeechLlamaModel, self).__init__(config) + + +class OmniSpeechLlamaForCausalLM(LlamaForCausalLM, OmniSpeechMetaForCausalLM): + config_class = OmniSpeechConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = OmniSpeechLlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + tgt_units: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = ( + self.prepare_inputs_labels_for_speech_and_text( + input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths + ) + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if speech is not None: + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( + self.prepare_inputs_labels_for_speech_and_text( + inputs, position_ids, attention_mask, None, None, speech, speech_lengths + ) + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if speech is not None: + inputs["speech"] = speech + inputs["speech_lengths"] = speech_lengths + return inputs + + +AutoConfig.register("omni_speech_llama", OmniSpeechConfig) +AutoModelForCausalLM.register(OmniSpeechConfig, OmniSpeechLlamaForCausalLM) diff --git a/src/helm/clients/audio_language/llama_omni/model/omni_speech_arch.py 
b/src/helm/clients/audio_language/llama_omni/model/omni_speech_arch.py new file mode 100644 index 0000000000..d0260f84be --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/omni_speech_arch.py @@ -0,0 +1,249 @@ +from abc import ABC, abstractmethod + +import torch +from torch import nn + +from helm.clients.audio_language.llama_omni.model.speech_encoder.builder import build_speech_encoder +from helm.clients.audio_language.llama_omni.model.speech_projector.builder import build_speech_projector +from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX + + +class OmniSpeechMetaModel(nn.Module): + + def __init__(self, config): + super(OmniSpeechMetaModel, self).__init__(config) + self.config = config + + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + + def get_speech_encoder(self): + speech_encoder = getattr(self, "speech_encoder", None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) + self.config.speech_projector_type = getattr(model_args, "speech_projector_type", "linear") + self.config.speech_encoder_ds_rate = getattr(model_args, "speech_encoder_ds_rate", 5) + self.config.speech_encoder_hidden_size = getattr(model_args, "speech_encoder_hidden_size", 1280) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + + if getattr(self, "speech_projector", None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location="cpu") + + def get_w(weights, keyword): + return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k} + + self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, "speech_projector")) + + +class OmniSpeechMetaForCausalLM(ABC): + def __init__(self, config): + self.config = config + + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def encode_speech(self, speech, speech_lengths): + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + raise ValueError(f"Unknown speech encoder: {speech_encoder}") + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f"Unknown speech projector: {speech_projector_type}") + speech_features = [encoder_outs[i, : speech_lengths[i]] for i in range(len(encoder_outs))] + 
return speech_features + + def prepare_inputs_labels_for_speech_and_text( + self, input_ids, position_ids, attention_mask, past_key_values, labels, speech, speech_lengths + ): + # input_ids = input_ids.unsqueeze(0) + speech_encoder = self.get_speech_encoder() + if speech_encoder is None or speech is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + speech_features = self.encode_speech(speech, speech_lengths) + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + # _input_ids = input_ids + input_ids = [ + cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + if num_speech == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + continue + + speech_token_indices = ( + [-1] + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + ) + cur_input_ids_nospeech = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech = [] + for i in range(len(speech_token_indices) - 1): + cur_input_ids_nospeech.append(cur_input_ids[speech_token_indices[i] + 1 : speech_token_indices[i + 1]]) + cur_labels_nospeech.append(cur_labels[speech_token_indices[i] + 1 : speech_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_nospeech] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech)) + cur_input_embeds_no_speech = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech[i]) + cur_new_labels.append(cur_labels_nospeech[i]) + if i < num_speech: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append( + torch.full( + (cur_speech_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds_stack = [x.to(input_ids[0].device) for x in cur_new_input_embeds] + + cur_new_input_embeds_tensor = torch.cat(cur_new_input_embeds_stack) + cur_new_labels_tensor = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds_tensor) + new_labels.append(cur_new_labels_tensor) + + # Truncate sequences to max length as speech features can make the sequence 
longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device + ) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels_loop) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels_loop + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels_loop + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds_tensor = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels_new = None + else: + new_labels_new = new_labels_padded + + if _attention_mask is None: + attention_mask_new = None + else: + attention_mask_new = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask_new, past_key_values, new_input_embeds_tensor, new_labels_new diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py b/src/helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py new file mode 100644 index 0000000000..9c9aea37dc --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_encoder/builder.py @@ -0,0 +1,9 @@ +from helm.clients.audio_language.llama_omni.model.speech_encoder.speech_encoder import WhisperWrappedEncoder + + +def build_speech_encoder(config): + speech_encoder_type = getattr(config, "speech_encoder_type", "none") + if "whisper" in speech_encoder_type.lower(): + return WhisperWrappedEncoder.load(config) + + raise ValueError(f"Unknown speech encoder: {speech_encoder_type}") diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py b/src/helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py new file mode 100644 index 0000000000..d436e32e64 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_encoder/speech_encoder.py @@ -0,0 +1,27 @@ +# Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py +import torch.nn as nn +import whisper + + +class 
WhisperWrappedEncoder: + + @classmethod + def load(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm( + child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine + ) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + encoder = whisper.load_model(name="large-v3", device="cpu").encoder + replace_layer_norm(encoder) + return encoder diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_generator/builder.py b/src/helm/clients/audio_language/llama_omni/model/speech_generator/builder.py new file mode 100644 index 0000000000..85e407b849 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_generator/builder.py @@ -0,0 +1,9 @@ +from helm.clients.audio_language.llama_omni.model.speech_generator.speech_generator import SpeechGeneratorCTC + + +def build_speech_generator(config): + generator_type = getattr(config, "speech_generator_type", "ctc") + if generator_type == "ctc": + return SpeechGeneratorCTC(config) + + raise ValueError(f"Unknown generator type: {generator_type}") diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_generator/generation.py b/src/helm/clients/audio_language/llama_omni/model/speech_generator/generation.py new file mode 100644 index 0000000000..d02b14a836 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_generator/generation.py @@ -0,0 +1,622 @@ +import torch +import inspect +import warnings +import torch.nn as nn +from typing import Optional, Union, List, Callable +import torch.distributed as dist + +from transformers import PreTrainedModel +from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import ( + GenerationConfig, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, + GenerationMixin, + GenerateEncoderDecoderOutput, + GenerateDecoderOnlyOutput, + GenerateNonBeamOutput, + is_deepspeed_zero3_enabled, + is_torchdynamo_compiling, + NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, + is_hqq_available, + QuantizedCacheConfig, + is_quanto_available, + DynamicCache, + EncoderDecoderCache, + logging, +) + +logger = logging.get_logger(__name__) + + +class GenerationWithCTC(GenerationMixin): + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + streamer_unit: Optional["BaseStreamer"] = None, + streaming_unit_gen=False, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: + synced_gpus = True + else: + synced_gpus = False + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config._pad_token_tensor is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + # 4. Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. 
Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # use_dynamic_cache_by_default = False + if "mamba" in self.__class__.__name__.lower(): + cache_name = "cache_params" + else: + cache_name = "past_key_values" + if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + generation_config.max_length, + model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + past = model_kwargs.get(cache_name, None) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + if past is None: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + # use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(past) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(past) + ) + # use_dynamic_cache_by_default = True + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if (streamer is not None or streamer_unit is not None) and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + + # 8. prepare distribution pre_processing samplers + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 9. prepare stopping criteria + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + ) + # 10. go into different generation modes + + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. prepare logits warper + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + if streaming_unit_gen: + return self._sample_streaming_unit( + input_ids, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + streamer_unit=streamer_unit, + **model_kwargs, + ) + else: + return self._sample( + input_ids, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + else: + raise NotImplementedError + + def _sample( + self, + input_ids: torch.Tensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + # scores = () if (return_dict_in_generate and output_scores) else None + # raw_logits = () if (return_dict_in_generate and output_logits) else None + # decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + # cross_attentions = () if (return_dict_in_generate and output_attentions) else None + # decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + scores: tuple = () + raw_logits: tuple = () + decoder_attentions: tuple = () + cross_attentions: tuple = () + decoder_hidden_states: tuple = () + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be + # very large for first iteration (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample and logits_warper is not None: + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], 
dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = bool(int(unfinished_sequences.max()) == 0) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _sample_streaming_unit( + self, + input_ids: torch.Tensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + streamer_unit: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + # scores = () if (return_dict_in_generate and output_scores) else None + # raw_logits = () if (return_dict_in_generate and output_logits) else None + # decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + # cross_attentions = () if (return_dict_in_generate and output_attentions) else None + # decoder_hidden_states: tuple = () if (return_dict_in_generate and output_hidden_states) else None + + scores: tuple = () + raw_logits: tuple = () + decoder_attentions: tuple = () + cross_attentions: tuple = () + decoder_hidden_states: tuple = () + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + generated_units = torch.tensor([]) + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Clone is needed to avoid keeping a hanging ref to outputs.logits + # which may be very large for first iteration (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample and logits_warper is not None: + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores and scores is not None and next_token_scores is not None: + scores += (next_token_scores,) + if output_logits and raw_logits is not None and next_token_logits is not None: + raw_logits += (next_token_logits,) + if output_attentions and decoder_attentions is not None: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder and cross_attentions is not None: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states and decoder_hidden_states is not None: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # speechgen + hidden_states = torch.cat( + 
[decoder_hidden_states[0][-1][:, -1:, :]] + + [decoder_hidden_states[i][-1] for i in range(1, len(decoder_hidden_states))], + dim=1, + ) + ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0)) + cur_units = ctc_postprocess(ctc_pred, blank=self.model.config.unit_vocab_size) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + if streamer_unit is not None: + for i in range(len(generated_units), len(cur_units)): + streamer_unit.put(cur_units[i].unsqueeze(0)) + generated_units = cur_units + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = bool(int(unfinished_sequences.max()) == 0) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + +def ctc_postprocess(tokens, blank): + _toks = tokens.squeeze(0).tolist() + deduplicated_toks = [v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]] + hyp = torch.tensor([v for v in deduplicated_toks if v != blank]) + return hyp diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py b/src/helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py new file mode 100644 index 0000000000..63061e7b53 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_generator/speech_generator.py @@ -0,0 +1,104 @@ +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX + + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view([1, int(max_lens)]) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def _uniform_assignment(src_lens, tgt_lens): + tgt_max_len = torch.max(tgt_lens).item() + tgt_indices = torch.arange(tgt_max_len).expand(len(tgt_lens), -1).to(tgt_lens.device) + ratio = tgt_lens / src_lens + index_t = (tgt_indices / ratio.view(-1, 1)).long() + return index_t + + +class SpeechGeneratorCTC(nn.Module): + def __init__(self, config): + 
super().__init__() + n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(","))) + _config = copy.deepcopy(config) + _config.hidden_size = n_dims + _config.num_hidden_layers = n_layers + _config.num_attention_heads = n_heads + _config.num_key_value_heads = n_heads + _config.intermediate_size = n_inter_dims + _config._attn_implementation = "flash_attention_2" + self.upsample_factor = config.ctc_upsample_factor + self.input_proj = nn.Linear(config.hidden_size, n_dims) + self.layers = nn.ModuleList([LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]) + self.unit_vocab_size = config.unit_vocab_size + self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) + + def upsample(self, reps, tgt_units=None): + src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) + up_lens = src_lens * self.upsample_factor + if tgt_units is not None: + tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + up_lens = torch.max(up_lens, tgt_lens) + reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) + padding_mask = lengths_to_padding_mask(up_lens) + mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0) + copied_reps = torch.gather( + reps, + 1, + mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)), + ) + copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) + position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device) + return copied_reps, ~padding_mask, position_ids + + def forward(self, tgt_reps, labels, tgt_units): + tgt_label_reps = [] + for tgt_rep, label in zip(tgt_reps, labels): + tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) + hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units) + hidden_states = self.input_proj(hidden_states) + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_lens = attention_mask.long().sum(dim=-1) + ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) + ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens) + ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask) + ctc_loss = F.ctc_loss( + ctc_lprobs.transpose(0, 1), + ctc_tgt_flat, + ctc_lens, + ctc_tgt_lens, + reduction="sum", + zero_infinity=True, + blank=self.unit_vocab_size, + ) + ctc_loss /= ctc_tgt_lens.sum().item() + return ctc_loss + + def predict(self, tgt_reps): + hidden_states, attention_mask, position_ids = self.upsample([tgt_reps]) + hidden_states = self.input_proj(hidden_states) + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = layer_outputs[0] + ctc_logits = self.output_proj(hidden_states) + ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) + ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size) + return ctc_pred diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_projector/builder.py b/src/helm/clients/audio_language/llama_omni/model/speech_projector/builder.py new file mode 100644 index 0000000000..3b44e4bd81 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_projector/builder.py @@ -0,0 +1,9 @@ +from 
helm.clients.audio_language.llama_omni.model.speech_projector.speech_projector import EncoderProjectorConcat + + +def build_speech_projector(config): + projector_type = getattr(config, "speech_projector_type", "linear") + if projector_type == "linear": + return EncoderProjectorConcat(config) + + raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/src/helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py b/src/helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py new file mode 100644 index 0000000000..be747eed03 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/model/speech_projector/speech_projector.py @@ -0,0 +1,27 @@ +# Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/projector.py +import torch.nn as nn + + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.speech_encoder_ds_rate + self.encoder_dim = config.speech_encoder_hidden_size + self.llm_dim = config.hidden_size + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.hidden_size) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x diff --git a/src/helm/clients/audio_language/llama_omni/preprocess.py b/src/helm/clients/audio_language/llama_omni/preprocess.py new file mode 100644 index 0000000000..ff46d53d77 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/preprocess.py @@ -0,0 +1,295 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
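+#
+# Note: prompts handled by this module are expected to carry a "<speech>" placeholder.
+# tokenizer_speech_token() splits the prompt on that placeholder and splices
+# SPEECH_TOKEN_INDEX between the tokenized text chunks, so the model's input
+# preparation can later substitute the encoded audio features at that position.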
+
+import torch
+import transformers
+
+from typing import Dict, Sequence
+
+from helm.clients.audio_language.llama_omni.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
+import helm.clients.audio_language.llama_omni.conversation as conversation_lib
+
+
+def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<speech>")]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == "pt":
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f"Unsupported tensor type: {return_tensors}")
+    return input_ids
+
+
+def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False) -> Dict:
+    conv = conversation_lib.default_conversation.copy()
+    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+    # Apply prompt templates
+    conversations = []
+    for i, source in enumerate(sources):
+        if roles[source[0]["from"]] != conv.roles[0]:
+            # Skip the first one if it is not from human
+            source = source[1:]
+
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role = roles[sentence["from"]]
+            assert role == conv.roles[j % 2], f"{i}"
+            conv.append_message(role, sentence["value"])
+        conversations.append(conv.get_prompt())
+
+    # Tokenize conversations
+
+    if has_speech:
+        input_ids = torch.stack(
+            [tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0
+        )
+    else:
+        input_ids = tokenizer(
+            conversations,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ).input_ids
+
+    targets = input_ids.clone()
+
+    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+    # Mask targets
+    sep = "[/INST] "
+    for conversation, target in zip(conversations, targets):
+        total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+        rounds = conversation.split(conv.sep2)
+        cur_len = 1
+        target[:cur_len] = IGNORE_INDEX
+        for i, rou in enumerate(rounds):
+            if rou == "":
+                break
+
+            parts = rou.split(sep)
+            if len(parts) != 2:
+                break
+            parts[0] += sep
+
+            if has_speech:
+                round_len = len(tokenizer_speech_token(rou, tokenizer))
+                instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2
+            else:
+                round_len = len(tokenizer(rou).input_ids)
+                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+            cur_len += round_len
+        target[cur_len:] = IGNORE_INDEX
+
+        if cur_len < tokenizer.model_max_length:
+            if cur_len != total_len:
+                target[:] = IGNORE_INDEX
+                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_3(sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + assert len(source) == 2, "now only support single-turn conversation" + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack( + [tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0 + ) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3 + + # Mask targets + sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n" + for conversation, target in zip(conversations, targets): + + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + parts = conversation.split(sep) + parts[0] += sep + + if has_speech: + conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 + else: + conversation_len = len(tokenizer(conversation).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += conversation_len + target[cur_len:] = IGNORE_INDEX + + # if cur_len < tokenizer.model_max_length: + # if cur_len != total_len: + # target[:] = IGNORE_INDEX + # print( + # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ # f" (ignored)" + # ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack( + [tokenizer_speech_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0 + ) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # FIXME: tokenizer bug + if i != 0 and not tokenizer.legacy: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_speech: bool = False) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
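+
+    The masking routine is picked from the default conversation template: the LLAMA_2
+    separator style dispatches to preprocess_llama_2, versions starting with "v1"
+    dispatch to preprocess_v1, and the LLAMA_3 separator style dispatches to
+    preprocess_llama_3; anything else raises NotImplementedError.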
+ """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_speech=has_speech) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_speech=has_speech) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3: + return preprocess_llama_3(sources, tokenizer, has_speech=has_speech) + raise NotImplementedError diff --git a/src/helm/clients/audio_language/llama_omni/utils.py b/src/helm/clients/audio_language/llama_omni/utils.py new file mode 100644 index 0000000000..5b680db3d5 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni/utils.py @@ -0,0 +1,202 @@ +# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import torch +import logging +import logging.handlers +import transformers + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +handler = None + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. 
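+            # Only complete lines (ending in "\n") are logged immediately; a trailing
+            # partial line is buffered in self.linebuf until the next write() or flush().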
+ if line[-1] == "\n": + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = "" + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + speech_keywords = ["speech_projector", "speech_encoder"] + for name, module in model.named_modules(): + if any(speech_keyword in name for speech_keyword in speech_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_speech_projector", False): + # Only save projector + keys_to_match = ["speech_projector"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_speech_projector_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + speech_projector_folder = os.path.join(parent_folder, "speech_projector") + 
os.makedirs(speech_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(speech_projector_folder, f"{current_folder}.bin")) + else: + torch.save(weight_to_save, os.path.join(output_dir, "speech_projector.bin")) + return + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view([1, int(max_lens)]) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" diff --git a/src/helm/clients/audio_language/llama_omni_client.py b/src/helm/clients/audio_language/llama_omni_client.py new file mode 100644 index 0000000000..04fe8621a7 --- /dev/null +++ b/src/helm/clients/audio_language/llama_omni_client.py @@ -0,0 +1,198 @@ +from threading import Lock +import torch +from typing import Any, Dict, List, Optional + +from dataclasses import dataclass +from transformers import AutoTokenizer +import whisper +from helm.clients.audio_language.llama_omni.model.builder import load_pretrained_model as load_llama_omni +from helm.clients.audio_language.llama_omni.model.language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM +from helm.clients.audio_language.llama_omni.conversation import conv_templates, Conversation +from helm.clients.audio_language.llama_omni.preprocess import tokenizer_speech_token + +from helm.common.cache import CacheConfig +from helm.common.gpu_utils import get_torch_device_name +from helm.common.hierarchical_logger import hlog, htrack_block +from helm.common.media_object import TEXT_TYPE +from helm.common.request import Request, RequestResult, GeneratedOutput, Token +from helm.common.request import wrap_request_time +from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt + + +@dataclass(frozen=True) +class LoadedLlamaOmniModelProcessor: + """Loaded model and processor for Qwen.""" + + model: OmniSpeech2SLlamaForCausalLM + tokenizer: AutoTokenizer + + +_models_lock: Lock = Lock() +_models: Dict[str, Optional[LoadedLlamaOmniModelProcessor]] = { + "ICTNLP/Llama-3.1-8B-Omni": None, +} + + +class LlamaOmniAudioLMClient(CachingClient): + """ + From https://github.com/ictnlp/LLaMA-Omni, + LLaMA-Omni is the audio multimodal version based on the LLaMA-3.1-8B large language model, + developed by ICTNLP group. LLaMA-Omni accepts audio, text as inputs, and outputs text. 
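+    The underlying checkpoint can additionally emit discrete speech units for speech
+    synthesis, but this client only decodes and returns the generated text.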
+
+    Paper: https://arxiv.org/abs/2409.06666
+    """
+
+    END_OF_TEXT_TOKEN: str = "<|im_end|>"
+    CONV_MODE: str = "llama_3"
+    PAD_ID: int = 128004
+    MEL_NUM: int = 128
+
+    def __init__(self, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, helm_model_name: str) -> LoadedLlamaOmniModelProcessor:
+        global _models_lock
+        global _models
+
+        model_name: str
+        if helm_model_name == "llama-3.1-8b-omni":
+            model_name = "ICTNLP/Llama-3.1-8B-Omni"
+        else:
+            raise ValueError(f"Unhandled model name: {helm_model_name}")
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            loaded_model_processor = _models[model_name]
+            if loaded_model_processor is None:
+                hlog(f"Loading model {model_name} and caching in memory...")
+                # Follow the official LLaMA-Omni model loading pattern:
+                # https://github.com/ictnlp/LLaMA-Omni/blob/main/omni_speech/infer/run.sh
+                tokenizer, model, _ = load_llama_omni(model_name, None, s2s=True)
+                _models[model_name] = LoadedLlamaOmniModelProcessor(model, tokenizer)
+                loaded_model_processor = _models[model_name]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def _load_local_audio(self, media_object) -> torch.Tensor:
+        assert media_object.is_local_file, "LLaMA-Omni only supports local audio file input"
+        audio_media = whisper.load_audio(media_object.location)
+        audio_media = whisper.pad_or_trim(audio_media)
+        audio_media = whisper.log_mel_spectrogram(audio_media, n_mels=self.MEL_NUM).permute(1, 0)
+        return audio_media
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: LoadedLlamaOmniModelProcessor = self._get_model(request.model_engine)
+        model = loaded_model_processor.model
+        tokenizer = loaded_model_processor.tokenizer
+
+        # The generation configs are taken from the official LLaMA-Omni repository
+        # https://github.com/ictnlp/LLaMA-Omni/blob/main/omni_speech/infer/infer.py#L116
+        generation_args = {
+            "max_new_tokens": 25,
+            "do_sample": False,
+            "use_cache": False,
+            "pad_token_id": self.PAD_ID,
+            "streaming_unit_gen": False,
+            "top_p": None,
+        }
+
+        input_text_query: Dict[str, str]
+        input_audio_query: Dict[str, Any]
+        prompt_text: str = ""
+
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("audio") and media_object.location:
+                input_audio_query = {"audio": self._load_local_audio(media_object)}
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+                input_text_query = {"text": "<speech>\n" + media_object.text}
+                prompt_text += media_object.text
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        completions: List[GeneratedOutput] = []
+        request_time: float = 0
+        request_datetime: Optional[int] = None
+        all_cached: bool = True
+
+        with htrack_block(f"Generating for prompt: {prompt_text}"):
+            for completion_index in range(request.num_completions):
+                try:
+
+                    def do_it() -> Dict[str, Any]:
+                        conv: Conversation = conv_templates[self.CONV_MODE].copy()
+                        conv.append_message(conv.roles[0], input_text_query["text"])
+                        conv.append_message(conv.roles[1], None)
+                        query: str = conv.get_prompt()
+                        # LLaMA-Omni requires a batch input
+                        text_inputs = (
+                            tokenizer_speech_token(query, tokenizer, return_tensors="pt").unsqueeze(0).to(self._device)
+                        )
+                        audio_inputs
= ( + input_audio_query["audio"].to(dtype=torch.float16, device=self._device).unsqueeze(0) + ) + speech_length = torch.LongTensor([audio_inputs.shape[1]]) + pred, _ = model.generate( + text_inputs, + audio_inputs, + speech_length, + None, + None, + None, + None, + None, + None, + None, + None, + False, + None, + None, + **generation_args, + ) + completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True) + tokens: List[str] = tokenizer.tokenize(completion) + return {"output": (completion, tokens)} + + # Include the prompt and model name in the cache key + cache_key = CachingClient.make_cache_key( + raw_request={ + "completion_index": completion_index, + "model": request.model, + "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt), + **generation_args, + }, + request=request, + ) + result, cached = self.cache.get(cache_key, wrap_request_time(do_it)) + except RuntimeError as model_error: + return RequestResult( + success=False, cached=False, error=str(model_error), completions=[], embedding=[] + ) + + text, tokens = result["output"] + + # Tokenize truncated text to get the list of tokens + completions.append( + GeneratedOutput( + text=text, logprob=0, tokens=[Token(text=str(token), logprob=0) for token in tokens] + ) + ) + + request_time += result["request_time"] + # Use the datetime from the first completion because that's when the request was fired + request_datetime = request_datetime or result.get("request_datetime") + all_cached = all_cached and cached + + return RequestResult( + success=True, + cached=all_cached, + request_time=request_time, + request_datetime=request_datetime, + completions=completions, + embedding=[], + ) diff --git a/src/helm/clients/audio_language/qwen2_audiolm_client.py b/src/helm/clients/audio_language/qwen2_audiolm_client.py index de8a8c432a..4a3b4134a4 100644 --- a/src/helm/clients/audio_language/qwen2_audiolm_client.py +++ b/src/helm/clients/audio_language/qwen2_audiolm_client.py @@ -1,5 +1,4 @@ from threading import Lock -from io import BytesIO import librosa from typing import Any, Dict, List, Optional @@ -7,7 +6,6 @@ from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor from helm.common.cache import CacheConfig -from helm.common.multimodal_request_utils import get_contents_as_bytes from helm.common.gpu_utils import get_torch_device_name from helm.common.hierarchical_logger import hlog, htrack_block from helm.common.media_object import TEXT_TYPE @@ -99,7 +97,9 @@ def make_request(self, request: Request) -> RequestResult: prompt_text += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" for media_num, media_object in enumerate(request.multimodal_prompt.media_objects): if media_object.is_type("audio") and media_object.location: - query.append({"type": "audio", "audio_url": media_object.location}) + assert media_object.is_local_file, "Only local audio files are supported" + query.append({"type": "audio", "audio_loc": media_object.location}) + prompt_text += f"<|im_start|>user\nAudio {media_num+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" elif media_object.is_type(TEXT_TYPE): if media_object.text is None: @@ -131,7 +131,7 @@ def do_it() -> Dict[str, Any]: if element["type"] == "audio": audios.append( librosa.load( - BytesIO(get_contents_as_bytes(element["audio_url"])), + element["audio_loc"], sr=tokenizer.feature_extractor.sampling_rate, )[0] ) diff --git a/src/helm/common/audio_utils.py b/src/helm/common/audio_utils.py index 701e93173a..7b9332135a 100644 --- a/src/helm/common/audio_utils.py +++ 
b/src/helm/common/audio_utils.py
@@ -1,4 +1,7 @@
 import base64
+import os
+from scipy.io.wavfile import write
+import numpy as np
 
 
 def encode_base64(audio_path: str) -> str:
@@ -6,3 +9,9 @@ def encode_base64(audio_path: str) -> str:
     with open(audio_path, "rb") as audio_file:
         audio_data = audio_file.read()
     return base64.b64encode(audio_data).decode("utf-8")
+
+
+def ensure_audio_file_exists(audio_path: str, audio_array: np.ndarray, audio_sampling_rate: int) -> None:
+    """Ensures that the audio file exists locally."""
+    if not os.path.exists(audio_path):
+        write(audio_path, audio_sampling_rate, audio_array)
diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml
index 4e23731da6..c2d929398a 100644
--- a/src/helm/config/model_deployments.yaml
+++ b/src/helm/config/model_deployments.yaml
@@ -2694,3 +2694,11 @@ model_deployments:
     max_sequence_length: 64000
     client_spec:
       class_name: "helm.clients.reka_client.RekaClient"
+
+# LLaMA-Omni
+  - name: ictnlp/llama-3.1-8b-omni
+    model_name: ictnlp/llama-3.1-8b-omni
+    tokenizer_name: ictnlp/llama-3.1-8b-omni
+    max_sequence_length: 8192
+    client_spec:
+      class_name: "helm.clients.audio_language.llama_omni_client.LlamaOmniAudioLMClient"
\ No newline at end of file
diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml
index fe0ceeef7d..3aed3e4646 100644
--- a/src/helm/config/model_metadata.yaml
+++ b/src/helm/config/model_metadata.yaml
@@ -3308,3 +3308,12 @@ models:
     release_date: 2024-04-18
     tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
 
+# LLaMA-Omni
+  - name: ictnlp/llama-3.1-8b-omni
+    display_name: LLaMA-Omni (8B)
+    description: The audio multimodal version of the LLaMA 3.1 model ([paper](https://arxiv.org/abs/2409.06666)).
+    creator_organization_name: ICTNLP
+    access: open
+    num_parameters: 8000000000
+    release_date: 2024-09-10
+    tags: [AUDIO_LANGUAGE_MODEL_TAG]
diff --git a/src/helm/config/tokenizer_configs.yaml b/src/helm/config/tokenizer_configs.yaml
index 0460613a6a..056a05a223 100644
--- a/src/helm/config/tokenizer_configs.yaml
+++ b/src/helm/config/tokenizer_configs.yaml
@@ -621,3 +621,13 @@ tokenizer_configs:
      class_name: "helm.tokenizers.yalm_tokenizer.YaLMTokenizer"
    end_of_text_token: "</s>"
    prefix_token: "</s>"
+
+  # LLaMA-Omni
+  - name: ictnlp/llama-3.1-8b-omni
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: ICTNLP/Llama-3.1-8B-Omni
+        trust_remote_code: false
+    end_of_text_token: "<|eot_id|>"
+    prefix_token: "<|begin_of_text|>"
\ No newline at end of file
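
A minimal illustrative sketch of how the new audio path fits together, assuming a scratch path of /tmp/llama_omni_demo.wav and a synthetic tone in place of real speech: ensure_audio_file_exists (helm/common/audio_utils.py) writes the wav only if it is missing, and the same Whisper calls used by LlamaOmniAudioLMClient._load_local_audio turn it into 128-bin log-mel features. Running it needs the audiolm extras (openai-whisper, scipy, numpy) and ffmpeg on the PATH.

import numpy as np
import whisper

from helm.common.audio_utils import ensure_audio_file_exists

sampling_rate = 16000
# One second of a 440 Hz tone standing in for a real spoken-digit clip.
tone = (0.1 * np.sin(2 * np.pi * 440 * np.arange(sampling_rate) / sampling_rate)).astype(np.float32)
ensure_audio_file_exists("/tmp/llama_omni_demo.wav", tone, sampling_rate)  # no-op if the file already exists

audio = whisper.load_audio("/tmp/llama_omni_demo.wav")  # mono float32, resampled to 16 kHz
audio = whisper.pad_or_trim(audio)                      # pad or clip to 30 seconds
mel = whisper.log_mel_spectrogram(audio, n_mels=128).permute(1, 0)
print(mel.shape)  # torch.Size([3000, 128]): (frames, mel bins), the layout the client batches for generate()

The 128 mel bins match the Whisper large-v3 encoder that LLaMA-Omni uses as its speech encoder, which is why the client pins MEL_NUM to 128.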