[Model] LoRA Support for Ultravox model (#11253)
thedebugger authored Feb 6, 2025
1 parent 9cdea30 commit d88506d
Showing 4 changed files with 160 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -857,7 +857,7 @@ See [this page](#generative-models) for more information on how to use generative
 * Ultravox
 * T + A<sup>E+</sup>
 * `fixie-ai/ultravox-v0_3`
-*
+* ✅︎
 * ✅︎
 * ✅︎
 :::
16 changes: 12 additions & 4 deletions tests/conftest.py
@@ -737,14 +737,16 @@ def generate(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)

         outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
@@ -782,6 +784,7 @@ def generate_w_logprobs(
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         inputs = self.get_inputs(prompts,
@@ -790,7 +793,8 @@ def generate_w_logprobs(
                                  audios=audios)

         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)

         toks_str_logsprobs_prompt_logprobs = (
             self._final_steps_generate_w_logprobs(req_outputs))
@@ -826,13 +830,15 @@ def generate_greedy(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts,
                                 greedy_params,
                                 images=images,
                                 videos=videos,
-                                audios=audios)
+                                audios=audios,
+                                **kwargs)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
@@ -847,6 +853,7 @@ def generate_greedy_logprobs(
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
@@ -861,7 +868,8 @@ def generate_greedy_logprobs(
             greedy_logprobs_params,
             images=images,
             audios=audios,
-            videos=videos)
+            videos=videos,
+            **kwargs)

     def generate_encoder_decoder_greedy_logprobs(
         self,
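With **kwargs threaded through these helpers, per-request options such as lora_request now flow from a test call straight into LLM.generate(). A minimal sketch of the resulting pattern (the model choice, prompt, and adapter path below are illustrative, not part of the diff):

    from vllm.lora.request import LoRARequest

    def test_generate_with_lora(vllm_runner):
        with vllm_runner("fixie-ai/ultravox-v0_3",
                         enable_lora=True,
                         max_loras=1,
                         max_lora_rank=128) as vllm_model:
            outputs = vllm_model.generate_greedy(
                ["<prompt>"],
                64,
                # Forwarded via **kwargs to LLM.generate().
                lora_request=LoRARequest("my-adapter", 1, "/path/to/adapter"))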
121 changes: 121 additions & 0 deletions tests/lora/test_ultravox.py
@@ -0,0 +1,121 @@
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
    return snapshot_download(
        repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


# A Llama LoRA adapter can't be used as-is: its module names must be
# transformed because Ultravox nests the language model.
def transform_module_names_for_ultravox(state_dict):
    transformed_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace("base_model.model",
                              "base_model.model.language_model")
        transformed_state_dict[new_key] = value
    return transformed_state_dict


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
    tensor_file = "adapter_model.safetensors"
    state_dict = load_file(path.join(source_repo, tensor_file))
    transformed_state_dict = transform_module_names_for_ultravox(state_dict)

    save_file(transformed_state_dict, path.join(target_path, tensor_file))

    config_file = "adapter_config.json"
    shutil.copyfile(path.join(source_repo, config_file),
                    path.join(target_path, config_file))
    return target_path


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    placeholder = f"{placeholder}\n" * audio_count

    return tokenizer.apply_chat_template([{
        'role': 'user',
        'content': f"{placeholder}{question}"
    }],
                                         tokenize=False,
                                         add_generation_prompt=True)


def test_ultravox_lora(vllm_runner):
    """
    TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
    """
    # Workaround to prevent device mismatch in Whisper.
    # Can be removed once it is fixed upstream in transformers:
    # https://github.com/huggingface/transformers/pull/35866
    torch.set_default_device("cpu")

    llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
    with TemporaryDirectory() as temp_ultravox_lora_dir:
        llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
            llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
        with vllm_runner(
                ULTRAVOX_MODEL_NAME,
                enforce_eager=True,
                max_num_seqs=2,
                enable_lora=True,
                max_loras=1,
                max_lora_rank=128,
                dtype="bfloat16",
                max_model_len=1024,
        ) as vllm_model:
            ultravox_outputs: List[Tuple[
                List[int], str]] = vllm_model.generate_greedy(
                    [
                        _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
                                    ULTRAVOX_MODEL_NAME)
                    ],
                    256,
                    lora_request=LoRARequest(str(1), 1,
                                             llama3_1_8b_ultravox_chess_lora),
                )

    # Run Llama with the original chess LoRA to compare its outputs
    # with the Ultravox outputs above.
    with vllm_runner(
            LLMA_MODEL_NAME,
            enforce_eager=True,
            max_num_seqs=2,
            enable_lora=True,
            max_loras=1,
            max_lora_rank=128,
            dtype="bfloat16",
            max_model_len=1024,
    ) as vllm_model:
        llama_outputs: List[Tuple[List[int], str]] = (
            vllm_model.generate_greedy(
                [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
                256,
                lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
            ))

    check_outputs_equal(
        outputs_0_lst=ultravox_outputs,
        outputs_1_lst=llama_outputs,
        name_0="ultravox",
        name_1="llama",
    )
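To make the key rewrite concrete, here is what transform_module_names_for_ultravox does to a single tensor name (the example key follows the usual PEFT naming convention; actual keys depend on the adapter):

    # A key as saved by PEFT for a plain Llama model (hypothetical example).
    key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

    # Ultravox wraps the decoder as `language_model`, so the prefix must be
    # extended before vLLM can match the LoRA tensors to the nested modules.
    new_key = key.replace("base_model.model", "base_model.model.language_model")
    assert new_key == ("base_model.model.language_model.model.layers.0."
                       "self_attn.q_proj.lora_A.weight")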
28 changes: 26 additions & 2 deletions vllm/model_executor/models/ultravox.py
@@ -22,6 +22,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -33,7 +34,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings,
@@ -343,7 +344,20 @@ def forward(
 @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
                                         info=UltravoxProcessingInfo,
                                         dummy_inputs=UltravoxDummyInputsBuilder)
-class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
+class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA-specific attributes
+    # TODO: Add LoRA to the audio tower and projector.
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
@@ -391,6 +405,16 @@ def sampler(self):

         return get_sampler()

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models.
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model.",
+            connector="multi_modal_projector.",
+            tower_model="audio_tower.",
+        )
+
     def _audio_features_to_embeddings(
             self, input_features: torch.Tensor) -> torch.Tensor:
         audio_input = input_features.to(self.audio_tower.dtype)
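Taken together: UltravoxModel now declares SupportsLoRA, exposes the packed-module and supported-module metadata that vLLM's LoRA manager expects, and maps its multimodal prefixes via get_mm_mapping(), so adapters targeting the nested language model can be applied at request time. A minimal offline-inference sketch (untested; the adapter path is a placeholder and must already use the transformed module names described above):

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(model="fixie-ai/ultravox-v0_3",
              enable_lora=True,
              max_loras=1,
              max_lora_rank=128,
              dtype="bfloat16")

    outputs = llm.generate(
        ["<|reserved_special_token_0|>\nDescribe the audio."],
        SamplingParams(temperature=0.0, max_tokens=64),
        # Per-request adapter, as in the test above.
        lora_request=LoRARequest("chess", 1, "/path/to/transformed_adapter"))
    print(outputs[0].outputs[0].text)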
