forked from vllm-project/vllm
Commit
[Model] Initialize Florence-2 language backbone support (vllm-project#9555)

Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 4da2860, commit ebc8ecc

Showing 5 changed files with 428 additions and 8 deletions.
@@ -0,0 +1,44 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
# TODO(Isotr0py):
# Move to offline_inference_vision_language.py after porting vision backbone
from vllm import LLM, SamplingParams

dtype = "float"

# Create a Florence-2 encoder/decoder model instance
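# (Florence-2's own BartFastTokenizer can't be loaded via AutoTokenizer, so the
# facebook/bart-base tokenizer is used instead; trust_remote_code is needed for
# Florence-2's custom modeling code.)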
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    dtype=dtype,
    trust_remote_code=True,
)

prompts = [
    "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
    "<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
tests/models/encoder_decoder/vision_language/test_florence2.py
102 additions & 0 deletions

@@ -0,0 +1,102 @@
from functools import partial
from typing import List, Optional, Tuple, Type

import pytest
from PIL import Image

from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, VllmRunner
from ...utils import check_logprobs_close

Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
                          decoder_prompt=None,
                          mm_processor_kwargs=None)

MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
    Florence2Prompt(encoder_prompt="<CAPTION>"),
    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
    Florence2Prompt(encoder_prompt="<OCR>"),
    Florence2Prompt(encoder_prompt="<OD>"),
]
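# (Florence2Prompt only pre-fills the optional keys, so each entry above is an
# ordinary ExplicitEncoderDecoderPrompt dict, e.g. PROMPTS[0] is equivalent to
# ExplicitEncoderDecoderPrompt(encoder_prompt="<CAPTION>", decoder_prompt=None,
# mm_processor_kwargs=None).)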


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]], ):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

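    # vLLM decodes its output without special tokens, while the HF reference
    # string keeps them (BART's "</s><s>" start sequence and the trailing
    # "</s>"), so they are re-added here before comparison.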
    hf_output_str = "</s><s>" + output_str + "</s>"

    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts: List[ExplicitEncoderDecoderPrompt],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    with vllm_runner(model,
                     tokenizer_name=TOKENIZER,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    # Florence-2 processors require image inputs
    dummy_image = Image.new(mode="RGB", size=(2, 2))
    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
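        # Florence-2's HF wrapper keeps its LM head on the language_model
        # submodule; expose it through get_output_embeddings so the HF runner
        # can turn hidden states into logprobs.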
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.lm_head
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            prompts,
            max_tokens,
            num_logprobs,
            images=[dummy_image] * len(prompts),
        ))

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
                num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        PROMPTS,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
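For context, run_test already accepts tensor_parallel_size and distributed_executor_backend even though the test above fixes them at 1 and None, so a distributed variant could reuse it unchanged. A hypothetical sketch (not part of this commit; the test name and backend choice are assumptions):

@pytest.mark.parametrize("model", MODELS)
def test_models_distributed(hf_runner, vllm_runner, model) -> None:
    # Hypothetical multi-GPU variant reusing the helper above.
    run_test(
        hf_runner,
        vllm_runner,
        PROMPTS,
        model,
        dtype="float",
        max_tokens=64,
        num_logprobs=5,
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
    )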