
[Model] Support SigLIP encoder and alternative decoders for LLaVA models #7153

Merged
DarkLight1337 merged 19 commits into vllm-project:main from llava-siglip on Aug 6, 2024

Conversation

@DarkLight1337 (Member) commented Aug 5, 2024

Based on #7067, this PR partially addresses #7143 by enabling SigLIP to be used as the vision encoder in LLaVA models. The same code also enables these models to load alternative LLM decoders.

This PR also fixes two hidden bugs in the calculation for the number of placeholder tokens:

  • For CLIP encoder, the feature size should be increased by 1 because the CLS token is prepended to the output.
  • For LLaVA models, the feature size should be reduced by 1 when vision_feature_select_strategy="default".

These two miscalculations cancelled each other out prior to this PR. They were discovered while adding SigLIP support (#7149): unlike the CLIP encoder, the SigLIP encoder does not output a CLS token, which resulted in an off-by-one error when filling in placeholders.
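
To make the accounting concrete, here is a minimal sketch of the placeholder count described above; the function names and arguments are illustrative, not the actual vLLM helpers:

def encoder_output_tokens(image_size: int, patch_size: int, has_cls_token: bool) -> int:
    """Number of tokens the vision encoder emits per image."""
    num_patches = (image_size // patch_size) ** 2
    # CLIP prepends a CLS token to the patch embeddings; SigLIP does not.
    return num_patches + (1 if has_cls_token else 0)

def llava_placeholder_tokens(encoder_tokens: int, has_cls_token: bool,
                             vision_feature_select_strategy: str) -> int:
    """Number of <image> placeholder tokens LLaVA needs for one image."""
    if vision_feature_select_strategy == "default" and has_cls_token:
        # The "default" strategy drops the CLS token from the encoder output.
        return encoder_tokens - 1
    return encoder_tokens

# CLIP ViT-L/14 @ 336px: 576 patches + 1 CLS = 577 tokens, minus 1 under "default" -> 576.
assert llava_placeholder_tokens(encoder_output_tokens(336, 14, True), True, "default") == 576
# SigLIP-SO400M/14 @ 384px: 729 patches and no CLS token -> 729 placeholders.
# Subtracting 1 regardless of the CLS token is exactly the 728-vs-729 mismatch reported below.
assert llava_placeholder_tokens(encoder_output_tokens(384, 14, False), False, "default") == 729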

FIX #7149

@BrenchCC commented Aug 5, 2024

I tried this method 3 hours ago, but I got the following error:
[rank0]: File "/data/vayu/train/workspace/vllm/vllm/model_executor/models/utils.py", line 29, in merge_vision_embeddings
[rank0]: raise ValueError(
[rank0]: ValueError: Attempted to assign 11 x 728 = 8008 image tokens to 8019 placeholders
It seems that the model needs more image tokens.

@DarkLight1337 (Member, Author) commented Aug 5, 2024

> I tried this method 3 hours ago, but I got the following error:
> [rank0]: File "/data/vayu/train/workspace/vllm/vllm/model_executor/models/utils.py", line 29, in merge_vision_embeddings
> [rank0]: raise ValueError(
> [rank0]: ValueError: Attempted to assign 11 x 728 = 8008 image tokens to 8019 placeholders
> It seems that the model needs more image tokens.

Can you show the input to your model? Also note that we don't support multiple image inputs per prompt for now, so you can only test with a single image input.

@DarkLight1337 (Member, Author) commented Aug 5, 2024

OK, I have found an issue with the image placeholder calculation; it is related to vision_feature_select_strategy. Fixing...

@DarkLight1337 (Member, Author) commented

Fixed, can you try again?

@BrenchCC commented Aug 5, 2024 via email

@DarkLight1337 changed the title from "[Model] Support alternative LLM backbone and SigLIP encoder for LLaVA model" to "[Model] Support alternative LLM backbone and SigLIP encoder for LLaVA models" Aug 5, 2024
@DarkLight1337 changed the title from "[Model] Support alternative LLM backbone and SigLIP encoder for LLaVA models" to "[Model] Support alternative LLM and SigLIP encoder for LLaVA models" Aug 5, 2024
@BrenchCC commented Aug 5, 2024

from io import BytesIO

import requests
from PIL import Image

from vllm import LLM, SamplingParams


def run_llava_next():
    llm = LLM(model="TIGER-Lab/Mantis-8B-siglip-llama3", max_model_len=4096)

    prompt = " <image>\nWhat is shown in this image? "
    url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
    image = Image.open(BytesIO(requests.get(url).content))
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=100)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            }
        },
        sampling_params=sampling_params)

    generated_text = ""
    for o in outputs:
        generated_text += o.outputs[0].text

    print(f"LLM output:{generated_text}")


if __name__ == "__main__":
    run_llava_next()

Can you try this? My conda env has a problem:
[rank0]: AttributeError: '_OpNamespace' '_C' object has no attribute 'rms_norm'

@DarkLight1337 (Member, Author) commented Aug 5, 2024

You should delete any compiled binaries from your local vLLM repo, and reinstall vLLM. Alternatively, delete the repo and clone a fresh one.

@DarkLight1337 (Member, Author) commented Aug 5, 2024

I have tested your example and the model can run in vLLM, but the output is gibberish (probably something wrong with the input processor). Since this PR isn't aimed at fully implementing Mantis (but rather to enable it), we can create another PR after this one to resolve those issues.

LLM output:

1

Answer: A

A light trail is lit up with a light bulb that represents a different point in the image. 2

Answer: A

A light trail is lit up with a different point in the image. 3<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Answer: B<|eot_id|><|start_header_id|>assistant<|end_header_id|>

B<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Answer: A<|eot_id|><|start_header_id|>assistant<|end_header_id|>

A light trail is lit up with a different point in the image. 4<|eot_id|><|start_header_id|>assistant<|end_header_id|>

@BrenchCC commented Aug 5, 2024

I know the mistake. I need to build on CUDA; your branch is based on AMD. LOL

@BrenchCC commented Aug 5, 2024

Thanks for your help!!!

@DarkLight1337 changed the title from "[Model] Support alternative LLM and SigLIP encoder for LLaVA models" to "[Model] Support alternative LLMs and SigLIP encoder for LLaVA models" Aug 5, 2024
@DarkLight1337 changed the title from "[Model] Support alternative LLMs and SigLIP encoder for LLaVA models" to "[Model] Support SigLIP encoder and alternative decoders for LLaVA models" Aug 5, 2024
@ywang96 self-assigned this Aug 5, 2024
@BrenchCC commented Aug 6, 2024

Yep, because there are some modifications to 'LlavaForConditionalGeneration' in Mantis regarding the tokens, so Mantis depends on its own class.
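
For example, the gibberish output above is full of Llama-3 chat-control tags, which suggests the plain prompt in the script bypassed Mantis's chat formatting. Below is a rough, hedged sketch of a Llama-3-style prompt for the same question; whether this exactly matches the template that Mantis's own processor builds is an assumption (its conversation code is linked in a later comment in this thread):

# Hypothetical Llama-3-style prompt for the earlier vLLM example; the exact
# placement of <image> relative to the chat tags is an assumption, not
# Mantis's verified template.
question = "<image>\nWhat is shown in this image?"
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{question}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)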

@ywang96 (Member) left a comment

Overall LGTM but I left a few comments - PTAL!

Review threads (resolved):
  • requirements-test.txt (outdated)
  • tests/models/test_llava.py
  • tests/models/test_llava.py
  • vllm/model_executor/models/utils.py (outdated)
@DarkLight1337 added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Aug 6, 2024
@DarkLight1337 merged commit 1f26efb into vllm-project:main Aug 6, 2024
37 checks passed
@DarkLight1337 deleted the llava-siglip branch August 6, 2024 08:55
@BrenchCC commented Aug 8, 2024

Maybe we should use some function to remove the unexpected tags; see the short sketch after these references.
Reference: https://github.com/TIGER-AI-Lab/Mantis/blob/main/mantis/models/mllava/utils.py

import PIL
import torch
from .modeling_llava import LlavaForConditionalGeneration
from .processing_llava import MLlavaProcessor
# from ..conversation import conv_mllava_v1_mmtag as default_conv
from ..conversation import conv_mllava_v1 as default_conv, conv_templates

from typing import List, Tuple, Union

def chat_mllava(
    text:str, 
    images: List[Union[PIL.Image.Image, str]], 
    model:LlavaForConditionalGeneration, 
    processor:MLlavaProcessor, 
    max_input_length:int=None, 
    history:List[dict]=None, 
    **kwargs) -> Tuple[str, List[dict]]:
    """
    Chat with the Mllava model
    Args:
        text: str, the text to be sent to the model, where <image> will be the placeholder for the image
        images: List[PIL.Image.Image], the images to be sent to the model, or None  
        model: LlavaForConditionalGeneration, the model to be used
        processor: MLlavaProcessor, the processor to be used
        max_input_length: int, the maximum input length
        history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None, the conversation will start from scratch
        kwargs: dict, the generation kwargs
    Returns:
        Tuple[str, List[dict]], the generated text and the history of the conversation
        

    """
    if "llama-3" in model.language_model.name_or_path.lower():
        conv = conv_templates['llama_3']
        terminators = [
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
    else:
        conv = default_conv
        terminators = None
    kwargs["eos_token_id"] = terminators
    conv = conv.copy()
    conv.messages = []
    if history is not None:
        for message in history:
            assert message["role"] in conv.roles
            conv.append_message(message["role"], message["text"])
        if text:
            assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
            conv.append_message(conv.roles[0], text)
            conv.append_message(conv.roles[1], "")
            history.append({"role": conv.roles[0], "text": text})
            history.append({"role": conv.roles[1], "text": ""})
        else:
            if conv.messages[-1][0] == conv.roles[1]:
                assert conv.messages[-1][1] == "", "No user message should be provided"
            else:
                assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
                conv.append_message(conv.roles[0], "")
                history.append({"role": conv.roles[0], "text": ""})
    else:
        history = []
        history.append({"role": conv.roles[0], "text": text})
        history.append({"role": conv.roles[1], "text": ""})
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], "")
    assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
    assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
    
    prompt = conv.get_prompt()
    if images:
        for i in range(len(images)):
            if isinstance(images[i], str):
                images[i] = PIL.Image.open(images[i]).convert("RGB")
    
    inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
    for k, v in inputs.items():
        if v is not None:
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(model.device)
            elif isinstance(v, list):
                inputs[k] = [x.to(model.device) for x in v]
            else:
                raise ValueError(f"Invalid input type: {type(v)}")
    

    output_ids = model.generate(**inputs, **kwargs)
    output_ids = output_ids[0]
    
    # remove the input tokens
    generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True)

    history[-1]["text"] = generated_text
    
    return generated_text, history

Reference: https://github.com/open-compass/VLMEvalKit/blob/main/vlmeval/vlm/mantis.py

import torch
from PIL import Image
from abc import abstractproperty
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import warnings


class Mantis(BaseModel):
    """
    Mantis Model
    This implementation is adapted from the Llava model from llava.py and the Idefics model from idefics.py
    """
    INSTALL_REQ = True
    INTERLEAVE = True

    DEFAULT_IMAGE_TOKEN = '<image>'
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, model_path='TIGER-Lab/Mantis-8B-siglip-llama3', **kwargs):
        assert model_path is not None
        try:
            from mantis.models.mllava import LlavaForConditionalGeneration, MLlavaProcessor
            from mantis.models.mfuyu import MFuyuForCausalLM, MFuyuProcessor
            from mantis.models.conversation import conv_mllava_v1 as default_conv, conv_templates
        except ImportError:
            warnings.warn(
                "Mantis is not installed. Please install Mantis to use this model. Please use 'pip install "
                "git+https://github.com/TIGER-AI-Lab/Mantis.git' to install."
            )

        try:
            from transformers import AutoModelForVision2Seq, AutoProcessor
        except Exception as e:
            warnings.warn("Upgrade transformers to use Mantis's idefics model.\nError: %s" % e)
        except:
            warnings.warn('Please `pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git`')

        # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2".
        # Seems FA2 is not effective during inference:
        # https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
        # if is_flash_attn_2_available:
        #     best_fit_attn_implementation = "flash_attention_2"
        # flash_attn has a bug that says: ERROR Error query and key must have the same dtype in generating

        try:
            import flash_attn
            best_fit_attn_implementation = 'flash_attention_2'
        except ImportError:
            best_fit_attn_implementation = 'eager'
        self.model_path = model_path
        attn_implementation = best_fit_attn_implementation
        self._is_idefics = 'idefics' in model_path.lower()
        # Here load the "non-idefics" Mantis model.
        if not self._is_idefics:
            if 'fuyu' in model_path.lower():
                self.processor = MFuyuProcessor.from_pretrained(self.model_path)
                model = MFuyuForCausalLM.from_pretrained(
                    self.model_path,
                    device_map='cuda',
                    attn_implementation=attn_implementation,
                    torch_dtype=torch.float16
                )
            else:
                self.processor = MLlavaProcessor.from_pretrained(self.model_path)
                model = LlavaForConditionalGeneration.from_pretrained(
                    self.model_path,
                    device_map='cuda',
                    attn_implementation=attn_implementation,
                    torch_dtype=torch.float16
                )
        else:
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            model = AutoModelForVision2Seq.from_pretrained(
                self.model_path,
                device_map='cuda',
                torch_dtype=torch.float16
            )

        model = model.eval()
        self.model = model.cuda()
        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

        self.tokenizer = self.processor.tokenizer
        self.default_conv = default_conv
        self.conv_templates = conv_templates

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == 'MCQ':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                '\n请直接回答选项字母。' if cn_string(prompt) else
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message

    def output_process(self, answer):
        if '<s>' in answer:
            answer = answer.replace('<s>', '').strip()
        if '[/INST]' in answer:
            answer = answer.split('[/INST]')[1].strip()
        elif 'ASSISTANT:' in answer:
            answer = answer.split('ASSISTANT:')[1].strip()
        elif 'assistant\n' in answer:
            answer = answer.split('assistant\n')[1].strip()
        elif '<|end_header_id|>\n\n' in answer:
            answer = answer.split('<|end_header_id|>\n\n')[2].strip()

        if '</s>' in answer:
            answer = answer.split('</s>')[0].strip()
        elif '<|im_end|>' in answer:
            answer = answer.split('<|im_end|>')[0].strip()
        elif '<|eot_id|>' in answer:
            answer = answer.split('<|eot_id|>')[0].strip()
        elif '<end_of_utterance>' in answer:
            answer = answer.split('<end_of_utterance>')[0].strip()
        return answer

    def generate_inner(self, message, dataset=None):
        content, images = '', []
        ide_content, question = [], ''
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
                question += msg['value']
            else:
                images.append(Image.open(msg['value']).convert('RGB'))
                content += (self.DEFAULT_IMAGE_TOKEN + '\n')
                ide_content.append({'type': 'image'})
        if self._is_idefics:
            # Follow the idefics implementation:
            ide_content.append({'type': 'text', 'text': question})
            prompt = [{'role': 'user', 'content': ide_content}]
            prompt = self.processor.apply_chat_template(prompt, add_generation_prompt=True)
        else:
            # Follow the Mantis code base to make sure they are consistent:
            # https://github.com/TIGER-AI-Lab/Mantis/blob/main/mantis/models/mllava/utils.py#L33
            # Users don't need to define chat template as it is done here
            if 'llama-3' in self.model.language_model.name_or_path.lower():
                conv = self.conv_templates['llama_3']
                terminators = [
                    self.processor.tokenizer.eos_token_id,
                    self.processor.tokenizer.convert_tokens_to_ids('<|eot_id|>')
                ]
            else:
                conv = self.default_conv
                terminators = None

            # Using EOT because end of *text* is more accurate for what we're doing than end of *sentence*
            if 'eos_token_id' not in self.kwargs:
                self.kwargs['eos_token_id'] = terminators

            conv = conv.copy()
            conv.append_message(conv.roles[0], content)
            conv.append_message(conv.roles[1], '')
            assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == '', 'Format check'
            prompt = conv.get_prompt()

        inputs = self.processor(prompt, images, return_tensors='pt', truncation=True)
        # FIXME: Fuyu model would return a list instead of a pytorch tensor. This weird behavior needs fixing.
        if 'image_patches' in inputs.keys():
            inputs['image_patches'] = inputs['image_patches'][0]
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        output = self.model.generate(**inputs, **self.kwargs)
        output = output[0]
        generated_ids = output[inputs['input_ids'].shape[-1]:]
        answer = self.processor.decode(generated_ids, skip_special_tokens=True)
        answer = self.output_process(answer)
        return answer
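
For the vLLM script earlier in this thread, a minimal sketch of that kind of tag clean-up might look like the following; the tag list is an assumption based on the output shown above, not something vLLM or Mantis provides:

def strip_chat_tags(text: str) -> str:
    # Truncate at the first end-of-turn marker, then drop any leftover header tags.
    text = text.split("<|eot_id|>")[0]
    for tag in ("<|start_header_id|>", "<|end_header_id|>"):
        text = text.replace(tag, "")
    return text.strip()

# Example: strip_chat_tags("Answer: A<|eot_id|><|start_header_id|>assistant<|end_header_id|>") -> "Answer: A"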

Commits referencing this pull request:
  • sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm on Aug 12, 2024
  • kylesayrs pushed a commit to neuralmagic/vllm on Aug 17, 2024
  • fialhocoelho pushed a commit to opendatahub-io/vllm on Aug 22, 2024
  • Alvant pushed a commit to compressa-ai/vllm on Oct 26, 2024