Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] LoRA Support for Ultravox model #11253

Merged
merged 21 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1c55938
WIP: early draft of lora support in Ultravox
thedebugger Nov 18, 2024
5a6b79f
format fixes
thedebugger Nov 19, 2024
3f5996c
Fix lora modules and formatting
thedebugger Dec 17, 2024
d1b65eb
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 6, 2025
7367bc2
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 6, 2025
2abf2ab
Done
jeejeelee Jan 6, 2025
be87788
Address code review feedback
thedebugger Jan 10, 2025
317fc38
Merge branch 'main' into svij-ultravox-lora-dec-16
thedebugger Jan 11, 2025
4a633d3
Fix formatting and test case
thedebugger Jan 11, 2025
224a65e
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 16, 2025
769f7bd
Done
jeejeelee Jan 16, 2025
907b3c7
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 16, 2025
208e662
Add doc
jeejeelee Jan 16, 2025
1248d5f
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 18, 2025
575b5dc
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 19, 2025
7cb7eba
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 20, 2025
f483d9a
Optimize unit test
jeejeelee Jan 20, 2025
1976ee0
Test setting cpu as a default device
thedebugger Jan 22, 2025
80fb1b8
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 27, 2025
c5cdde7
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 28, 2025
0b18650
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ See [this page](#generative-models) for more information on how to use generativ
- Ultravox
- T + A<sup>E+</sup>
- `fixie-ai/ultravox-v0_3`
-
- ✅︎
- ✅︎
- ✅︎
```
Expand Down
16 changes: 12 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,16 @@ def generate(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
Expand Down Expand Up @@ -779,6 +781,7 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
Expand All @@ -787,7 +790,8 @@ def generate_w_logprobs(
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
Expand Down Expand Up @@ -823,13 +827,15 @@ def generate_greedy(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

Expand All @@ -844,6 +850,7 @@ def generate_greedy_logprobs(
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
Expand All @@ -858,7 +865,8 @@ def generate_greedy_logprobs(
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
videos=videos,
**kwargs)

def generate_encoder_decoder_greedy_logprobs(
self,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to this file are not related to this PR, please revert.

Expand Down
133 changes: 133 additions & 0 deletions tests/lora/test_ultravox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
    """Fetch the Llama 3.1 8B chess LoRA adapter snapshot.

    Returns:
        Whatever ``snapshot_download`` returns for the
        ``mkopecki/chess-lora-adapter-llama-3.1-8b`` repository.
    """
    adapter_repo_id = "mkopecki/chess-lora-adapter-llama-3.1-8b"
    return snapshot_download(repo_id=adapter_repo_id)


# A Llama LoRA adapter can't be used as-is: Ultravox nests the language
# model under `language_model`, so the module names must be transformed.
def transform_module_names_for_ultravox(state_dict):
    """Return a new dict whose keys are re-rooted under ``language_model``.

    Every occurrence of ``base_model.model`` in a key is rewritten to
    ``base_model.model.language_model``. Values are carried over untouched,
    insertion order is kept, and the input mapping is not modified.

    Args:
        state_dict: Mapping from tensor name to value.

    Returns:
        A new ``dict`` with the renamed keys.
    """
    old_root = "base_model.model"
    new_root = "base_model.model.language_model"
    return {
        name.replace(old_root, new_root): tensor
        for name, tensor in state_dict.items()
    }


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
    """Write an Ultravox-compatible copy of a LoRA adapter directory.

    The adapter tensors are loaded from ``source_repo``, their names are
    passed through ``transform_module_names_for_ultravox``, and the result
    is saved under ``target_path`` with the same file name. The adapter
    config file is copied across unchanged.

    Args:
        source_repo: Directory holding ``adapter_model.safetensors`` and
            ``adapter_config.json``.
        target_path: Existing directory to write the converted adapter to.

    Returns:
        ``target_path``, unchanged.
    """
    weights_name = "adapter_model.safetensors"
    config_name = "adapter_config.json"

    # Load, rename, then persist the tensors under the same file name.
    original_tensors = load_file(path.join(source_repo, weights_name))
    renamed_tensors = transform_module_names_for_ultravox(original_tensors)
    save_file(renamed_tensors, path.join(target_path, weights_name))

    # The adapter config needs no changes, so a plain copy is enough.
    shutil.copyfile(
        path.join(source_repo, config_name),
        path.join(target_path, config_name),
    )
    return target_path


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    """Render a single user turn through the model's chat template.

    Args:
        audio_count: How many times the placeholder line is repeated.
        question: Text of the user message.
        placeholder: Placeholder string; each repetition is followed by a
            newline and the whole run is prepended to ``question``.
        model_name: Name passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The result of ``apply_chat_template`` called with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    chat_tokenizer = AutoTokenizer.from_pretrained(model_name)
    audio_prefix = f"{placeholder}\n" * audio_count
    user_turn = {
        'role': 'user',
        'content': f"{audio_prefix}{question}",
    }
    return chat_tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )


def test_ultravox_lora(vllm_runner):
    """Compare Ultravox + renamed Llama LoRA against Llama + original LoRA.

    The chess adapter is rewritten for Ultravox into a temporary directory
    and a greedy completion is generated with Ultravox. The same question
    is then run on Llama both without and with the original adapter. The
    two LoRA outputs must be equal, and the Ultravox text must differ from
    the Llama text generated without LoRA.
    """
    # Engine settings shared by the Ultravox and Llama runners.
    runner_kwargs = {
        "enforce_eager": True,
        "max_num_seqs": 128,
        "enable_lora": True,
        "max_loras": 4,
        "max_lora_rank": 128,
        "dtype": "bfloat16",
        "max_model_len": 4096,
    }
    max_tokens = 256
    lora_id = 1

    llama_lora_dir = llama3_1_8b_chess_lora_path()

    with TemporaryDirectory() as ultravox_lora_tmp_dir:
        ultravox_lora_dir = mk_llama3_1_8b_ultravox_chess_lora(
            llama_lora_dir, ultravox_lora_tmp_dir)
        with vllm_runner(ULTRAVOX_MODEL_NAME, **runner_kwargs) as vllm_model:
            ultravox_prompt = _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
                                          ULTRAVOX_MODEL_NAME)
            ultravox_request = LoRARequest(str(lora_id), lora_id,
                                           ultravox_lora_dir)
            ultravox_outputs: List[Tuple[List[int], str]] = (
                vllm_model.generate_greedy(
                    [ultravox_prompt],
                    max_tokens,
                    lora_request=ultravox_request,
                ))

    # Run Llama without and with the adapter to compare against Ultravox.
    # Only the original adapter directory is needed from here on.
    with vllm_runner(LLMA_MODEL_NAME, **runner_kwargs) as vllm_model:
        baseline_prompt = _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
                                      LLMA_MODEL_NAME)
        llama_outputs_no_lora: List[Tuple[List[int], str]] = (
            vllm_model.generate_greedy([baseline_prompt], max_tokens))

        lora_prompt = _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
                                  LLMA_MODEL_NAME)
        llama_request = LoRARequest(str(lora_id), lora_id, llama_lora_dir)
        llama_outputs: List[Tuple[List[int], str]] = (
            vllm_model.generate_greedy(
                [lora_prompt],
                max_tokens,
                lora_request=llama_request,
            ))

    # With the adapter applied, both models must produce the same output.
    check_outputs_equal(
        outputs_0_lst=ultravox_outputs,
        outputs_1_lst=llama_outputs,
        name_0="ultravox",
        name_1="llama",
    )

    # Without LoRA, the Llama text must differ from the Ultravox + LoRA text.
    llama_no_lora_str = llama_outputs_no_lora[0][1]
    ultravox_str = ultravox_outputs[0][1]
    assert llama_no_lora_str != ultravox_str
28 changes: 26 additions & 2 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
Expand All @@ -31,7 +32,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings,
Expand Down Expand Up @@ -340,7 +341,20 @@ def forward(
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

# LoRA specific attributes
# TODO : Add LoRA to the audio tower and projector.
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
]
embedding_modules = {}
embedding_padding_modules = []

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
Expand Down Expand Up @@ -388,6 +402,16 @@ def sampler(self):

return get_sampler()

def get_mm_mapping(self) -> MultiModelKeys:
    """
    Return the module-name prefixes of this multimodal model's parts.

    The language model, connector and tower model are registered under
    the prefixes "language_model.", "multi_modal_projector." and
    "audio_tower." respectively.
    """
    module_prefixes = MultiModelKeys.from_string_field(
        language_model="language_model.",
        connector="multi_modal_projector.",
        tower_model="audio_tower.",
    )
    return module_prefixes

def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
Expand Down
Loading