[Model] Initialize Florence-2 language backbone support (vllm-project#9555)

Signed-off-by: NickLucche <nlucches@redhat.com>
Isotr0py authored and NickLucche committed Oct 31, 2024
1 parent 4da2860 commit ebc8ecc
Showing 5 changed files with 428 additions and 8 deletions.
44 changes: 44 additions & 0 deletions examples/florence2_inference.py
@@ -0,0 +1,44 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
# TODO(Isotr0py):
# Move to offline_inference_vision_language.py after porting vision backbone
from vllm import LLM, SamplingParams

dtype = "float"

# Create a Florence-2 encoder/decoder model instance
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    dtype=dtype,
    trust_remote_code=True,
)

prompts = [
"<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
"<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
"<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
28 changes: 20 additions & 8 deletions tests/conftest.py
@@ -253,7 +253,9 @@ def __init__(
         dtype: str = "half",
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
+        is_sentence_transformer: bool = False,
+        skip_tokenizer_init: bool = False,
         auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
         postprocess_inputs: Callable[[BatchEncoding],
                                      BatchEncoding] = identity,
@@ -281,11 +283,12 @@ def __init__(
                 **model_kwargs,
             ))

-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            trust_remote_code=True,
-        )
+        if not skip_tokenizer_init:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )

         # don't put this import at the top level
         # it will call torch.cuda.device_count()
@@ -295,6 +298,8 @@ def __init__(
             torch_dtype=torch_dtype,
             trust_remote_code=True,
         )
+        if skip_tokenizer_init:
+            self.tokenizer = self.processor.tokenizer

         self.postprocess_inputs = postprocess_inputs

@@ -535,6 +540,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[TokensTextLogprobs]:
         '''
@@ -545,11 +551,17 @@ def generate_encoder_decoder_greedy_logprobs_limit(
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []

-        for (encoder_prompt,
-             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
+        for i, (encoder_prompt, decoder_prompt) in enumerate(
+                to_enc_dec_tuple_list(encoder_decoder_prompts)):
+            processor_kwargs: Dict[str, Any] = {
+                "text": encoder_prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]

             encoder_input_ids = self.wrap_device(
-                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
+                self.processor(**processor_kwargs).input_ids,
                 device=self.model.device.type,
             )

102 changes: 102 additions & 0 deletions tests/models/encoder_decoder/vision_language/test_florence2.py
@@ -0,0 +1,102 @@
from functools import partial
from typing import List, Optional, Tuple, Type

import pytest
from PIL import Image

from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, VllmRunner
from ...utils import check_logprobs_close

Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
                          decoder_prompt=None,
                          mm_processor_kwargs=None)

MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
    Florence2Prompt(encoder_prompt="<CAPTION>"),
    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
    Florence2Prompt(encoder_prompt="<OCR>"),
    Florence2Prompt(encoder_prompt="<OD>"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]], ):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = "</s><s>" + output_str + "</s>"

    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts: List[ExplicitEncoderDecoderPrompt],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    with vllm_runner(model,
                     tokenizer_name=TOKENIZER,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    # Florence-2 processors require image inputs
    dummy_image = Image.new(mode="RGB", size=(2, 2))
    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.lm_head
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            prompts,
            max_tokens,
            num_logprobs,
            images=[dummy_image] * len(prompts),
        ))

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
                num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        PROMPTS,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
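
To exercise just this new test file locally, a plain pytest run over the added path should be enough (a usage sketch; the CI wiring for this suite is not part of the shown diff):

pytest tests/models/encoder_decoder/vision_language/test_florence2.py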
