Add Llama-Omni-8B #3119

Merged 5 commits on Oct 31, 2024
9 changes: 2 additions & 7 deletions setup.cfg
@@ -278,13 +278,8 @@ audiolm =
     soundfile~=0.12
     librosa~=0.10

-    # For OpenFlamingo
-    einops~=0.7.0
-    einops-exts~=0.0.4
-    open-clip-torch~=2.24
-
-    # For IDEFICS
-    torch~=2.1
+    # For LLaMA-Omni
+    openai-whisper==20240930

     # For Qwen2-Audio
     transformers~=4.45.1
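The only package this hunk adds is `openai-whisper`, which LLaMA-Omni uses for its speech encoder. A quick sanity check that the pinned package imports and can load an encoder checkpoint (the model name below is illustrative; the actual encoder is configured via `speech_encoder` in `ModelArguments`, not here):

```python
import whisper  # provided by the openai-whisper==20240930 pin

# Downloads (or reuses) the checkpoint and reports its dimensions.
# "large-v3" is only an example; LLaMA-Omni picks its encoder elsewhere.
model = whisper.load_model("large-v3")
print(model.dims)
```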
34 changes: 16 additions & 18 deletions src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py
@@ -1,6 +1,7 @@
"""Scenarios for audio models"""

from typing import List
import os

from helm.benchmark.scenarios.scenario import (
Scenario,
@@ -11,7 +12,11 @@
     Input,
     Output,
 )
+from tqdm import tqdm
+from datasets import load_dataset
 from helm.common.media_object import MediaObject, MultimediaObject
+from helm.common.general import ensure_directory_exists
+from helm.common.audio_utils import ensure_audio_file_exists


 class AudioMNISTScenario(Scenario):
@@ -37,28 +42,21 @@ class AudioMNISTScenario(Scenario):
     }
     """  # noqa: E501

-    NUM_SPEAKERS = 60
-    NUM_TRIALS = 50
-    WAV_URL_TEMPLATE = r"https://github.com/soerenab/AudioMNIST/raw/544b0f4bc65227e54332e665d5e02c24be6732c2/data/{speaker_id}/{digit}_{speaker_id}_{trial_index}.wav"  # noqa: E501
-
     name = "audio_mnist"
     description = "Classify an audio sample of a spoken digit ([Becker et al, 2023](https://arxiv.org/abs/1807.03418))."
     tags = ["audio", "classification"]

     def get_instances(self, output_path: str) -> List[Instance]:
         instances: List[Instance] = []
-        for digit in range(10):
-            for speaker_index in range(AudioMNISTScenario.NUM_SPEAKERS):
-                speaker_id = str(speaker_index).zfill(2)
-                for trial_index in range(AudioMNISTScenario.NUM_TRIALS):
-                    wav_url = AudioMNISTScenario.WAV_URL_TEMPLATE.format(
-                        digit=digit, speaker_id=speaker_id, trial_index=trial_index
-                    )
-                    input = Input(
-                        multimedia_content=MultimediaObject([MediaObject(content_type="audio/wav", location=wav_url)])
-                    )
-                    references = [Reference(Output(text=str(digit)), tags=[CORRECT_TAG])]
-                    # Don't need train split because we're using zero-shot
-                    instance = Instance(input=input, references=references, split=TEST_SPLIT)
-                    instances.append(instance)
+        wav_save_dir: str = os.path.join(output_path, "wav_files")
+        ensure_directory_exists(wav_save_dir)
+        for row in tqdm(load_dataset("flexthink/audiomnist", cache_dir=output_path, split=TEST_SPLIT)):
+            local_audio_path = os.path.join(wav_save_dir, row["audio"]["path"])
+            audio_array = row["audio"]["array"]
+            ensure_audio_file_exists(local_audio_path, audio_array, row["audio"]["sampling_rate"])
+            input = Input(
+                multimedia_content=MultimediaObject([MediaObject(content_type="audio/mpeg", location=local_audio_path)])
+            )
+            references = [Reference(Output(text=str(row["digit"])), tags=[CORRECT_TAG])]
+            instances.append(Instance(input=input, references=references, split=TEST_SPLIT))
         return instances
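The rewritten `get_instances` pulls AudioMNIST from Hugging Face instead of fetching individual WAVs from GitHub, so it assumes each row of `flexthink/audiomnist` carries a decoded `audio` column (file name, waveform array, sampling rate) plus a `digit` label. A minimal sketch for inspecting that layout, using only the fields the code above references (the printouts are illustrative):

```python
from datasets import load_dataset

# Load the same split the scenario uses (TEST_SPLIT is "test" in HELM).
dataset = load_dataset("flexthink/audiomnist", split="test")

row = dataset[0]
print(row["digit"])                    # spoken digit, used as the reference text
print(row["audio"]["path"])            # original file name, reused for the local copy
print(row["audio"]["sampling_rate"])   # forwarded to ensure_audio_file_exists
print(len(row["audio"]["array"]))      # raw waveform samples written to disk
```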
61 changes: 61 additions & 0 deletions src/helm/clients/audio_language/llama_omni/arguments.py
@@ -0,0 +1,61 @@
import transformers

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_speech_projector: bool = field(default=False)
tune_speech_encoder: bool = field(default=False)
tune_speech_generator_only: bool = field(default=False)
speech_encoder_type: Optional[str] = field(default=None)
speech_encoder: Optional[str] = field(default=None)
pretrain_speech_projector: Optional[str] = field(default=None)
speech_projector_type: Optional[str] = field(default="linear")
speech_generator_type: Optional[str] = field(default="ctc")
ctc_decoder_config: str = "(2,4096,32,11008)"
ctc_upsample_factor: int = 1
ctc_loss_weight: float = 1.0
unit_vocab_size: int = 1000
speech_encoder_ds_rate: int = 5
speech_encoder_hidden_size: int = 1280


@dataclass
class DataArguments:
data_path: str = field(default="", metadata={"help": "Path to the training data."})
is_multimodal: bool = False
input_type: str = field(default="mel")
speech_normalize: bool = False
mel_size: int = 128
has_tgt_units: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
freeze_speech_projector: bool = field(default=False)
model_max_length: int = field(
default=512,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
double_quant: bool = field(
default=True, metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
speech_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
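These dataclasses mirror the LLaVA-style argument layout, so the natural way to populate them from a command line is `transformers.HfArgumentParser`; the sketch below assumes that usage (HELM's client may instead construct `ModelArguments` directly with keyword arguments):

```python
from transformers import HfArgumentParser

from helm.clients.audio_language.llama_omni.arguments import (
    DataArguments,
    ModelArguments,
    TrainingArguments,
)

# Splits CLI flags such as --model_name_or_path, --speech_encoder and --output_dir
# (required by transformers.TrainingArguments) across the three dataclasses;
# any field not supplied keeps the default defined above.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.speech_encoder_type, data_args.input_type, training_args.bits)
```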
9 changes: 9 additions & 0 deletions src/helm/clients/audio_language/llama_omni/constants.py
@@ -0,0 +1,9 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
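These constants follow the LLaVA convention: `DEFAULT_SPEECH_TOKEN` marks where audio features belong in a prompt, `SPEECH_TOKEN_INDEX` is the out-of-vocabulary sentinel that stands in for it after tokenization, and `IGNORE_INDEX` masks positions out of the loss. A rough, hypothetical sketch of that placeholder substitution (the helper below is not part of this PR):

```python
from typing import List

from helm.clients.audio_language.llama_omni.constants import (
    DEFAULT_SPEECH_TOKEN,
    SPEECH_TOKEN_INDEX,
)


def tokenize_with_speech_placeholder(prompt: str, tokenizer) -> List[int]:
    """Hypothetical helper: tokenize the text around <speech> and splice in the sentinel id."""
    input_ids: List[int] = []
    for i, chunk in enumerate(prompt.split(DEFAULT_SPEECH_TOKEN)):
        if i > 0:
            # The model later swaps this sentinel for the encoded speech features.
            input_ids.append(SPEECH_TOKEN_INDEX)
        input_ids.extend(tokenizer(chunk, add_special_tokens=False)["input_ids"])
    return input_ids
```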
213 changes: 213 additions & 0 deletions src/helm/clients/audio_language/llama_omni/conversation.py
@@ -0,0 +1,213 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import auto, Enum
from typing import List, Any, Union, Optional


class SeparatorStyle(Enum):
"""Different separator style."""

TWO = auto()
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()


@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""

system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
sep: str = "###"
sep2: str = ""
version: str = "Unknown"

tokenizer_id: str = ""
tokenizer: Any = None
# Stop criteria (the default one is EOS token)
stop_str: Optional[Union[str, List[str]]] = None
# Stops generation if meeting any token in this list
stop_token_ids: Optional[List[int]] = None

skip_next: bool = False

def get_prompt(self):
messages = self.messages

if self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_3:
wrap_sys = lambda msg: (
f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
)
ret = "<|begin_of_text|>" + wrap_sys(self.system)
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
ret += message.strip() + self.sep2
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
return ret
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message = message[0]
if i == 0:
message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += message + seps[i % 2]
else:
ret += ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")

return ret

def append_message(self, role, message):
self.messages.append([role, message])

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg = msg[0]
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret

def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version,
)

def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}


conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=["USER", "ASSISTANT"],
version="v1",
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)

conv_llama_2 = Conversation(
system="You are a helpful language and speech assistant. "
"You are able to understand the speech content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=["USER", "ASSISTANT"],
version="llama_v2",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)

conv_llama_3 = Conversation(
system="You are a helpful language and speech assistant. "
"You are able to understand the speech content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=["user", "assistant"],
version="llama_v3",
messages=[],
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="",
sep2="<|eot_id|>",
)

conv_plain = Conversation(
system="",
roles=["", ""],
messages=[],
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="</s>",
)


default_conversation = conv_llama_3
conv_templates = {
"v1": conv_vicuna_v1,
"plain": conv_plain,
"llama_2": conv_llama_2,
"llama_3": conv_llama_3,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())
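For reference, these templates are used by copying one, appending alternating role/message turns, and rendering the prompt string; a short sketch against the Llama 3 template (the user text is illustrative):

```python
from helm.clients.audio_language.llama_omni.conversation import conv_templates

# Copy so the module-level template keeps an empty message list.
conv = conv_templates["llama_3"].copy()
conv.append_message(conv.roles[0], "<speech>\nPlease transcribe the audio.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open for generation

# Renders "<|begin_of_text|>", the system header, the user turn, and an open assistant header.
print(conv.get_prompt())
```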