Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Initial Support for Chameleon #5770

Merged
merged 17 commits into from
Jul 22, 2024
Merged

Conversation

ywang96
Copy link
Member

@ywang96 ywang96 commented Jun 23, 2024

This PR kicks off the effort to add support for Chameleon - Mixed-Modal Early-Fusion Foundation Models from Meta AI. Currently its goal is to match the transformers capability to generate text only from text + images.

This PR itself adds ChameleonForConditionalGeneration for text-to-text inference. Fully functional vision language inference support with VQVAE will be added in the next PR.

Notable differences between ChameleonForConditionalGeneration and LlamaForCausalLM

  • qk_layernorm
  • swin norm

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@ywang96
Copy link
Member Author

ywang96 commented Jul 8, 2024

This PR is ready for review - it's based on the implementation in huggingface/transformers#31534.

Note: the transformers PR hasn't been merged yet - if there's any major update for the language model itself in that PR, I will address it in the next PR if this one gets merged before that.

@ywang96 ywang96 marked this pull request as ready for review July 8, 2024 16:17
@jacobkahn
Copy link

jacobkahn commented Jul 8, 2024

@ywang96 — note that the diff in huggingface/transformers#31534 (comment) will likely be applied before merge, but that no other changes are likely to be made.

@ywang96
Copy link
Member Author

ywang96 commented Jul 8, 2024

@ywang96 — note the the diff in huggingface/transformers#31534 (comment) will likely be applied before merge, but that no other changes are likely to be made.

@jacobkahn Thanks for the heads up!

@xwjiang2010
Copy link
Contributor

Can we add an example?
Can we add a correctness test to compare with hf result (once it's ready)? We can comment out the compare part for now and just make sure that the code has at least some CI run through.

@xwjiang2010
Copy link
Contributor

Should we add the model to the supported vlm model list?

@ywang96
Copy link
Member Author

ywang96 commented Jul 9, 2024

@xwjiang2010 Thanks for the quick review!

Can we add a correctness test to compare with hf result (once it's ready)?

The issue is that we will need to wait for the transformers release to be able to test the hf result, and so far Ive been testing it locally with that branch. I also plan to add a correctness test in the next PR where we add the full text + image support.

Should we add the model to the supported vlm model list?

Since this PR is only functional in text-to-text, I don't think we should add it to the supported vlm model list yet.

@ywang96
Copy link
Member Author

ywang96 commented Jul 10, 2024

Given there's still ongoing discussion on the original transformers PR, I'm going to mark this back to draft until that PR is merged to modify accordingly.

@ywang96 ywang96 marked this pull request as draft July 10, 2024 17:47
@AGIGOAT
Copy link

AGIGOAT commented Jul 17, 2024

@ywang96 the huggingface PR has been merged

huggingface/transformers#31534 (comment)

@ywang96
Copy link
Member Author

ywang96 commented Jul 19, 2024

This PR has passed model test locally (I haven't added the test file since transformers hasn't made a new released that officially supports ChameleonForConditionalGeneration) and is ready to be reviewed. cc @xwjiang2010.

@ywang96 ywang96 marked this pull request as ready for review July 19, 2024 23:05
@xwjiang2010
Copy link
Contributor

Thank you @ywang96.
I did some manual testing. One weird thing that shows up is that 30B in this code doesn't generate anything meaningful in my case (7B works).

Example code:

import os
import subprocess

from vllm import LLM

def run():
    # llm = LLM(model="/home/ray/.cache/huggingface/hub/models--facebook--chameleon-7b/snapshots/0f474b71631bcaa20d20d56909066e56cb8a4372/")
    llm = LLM(model="/home/ray/.cache/huggingface/hub/models--facebook--chameleon-30b/snapshots/1c95500025a2a51183c7bb8f410ecd3d84e4cb2c/")

    prompt = "Tell me about tangram."


    outputs = llm.generate({
        "prompt": prompt,
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

run()

7B gives me: "A tangram is a flat, square puzzle containing seven pieces, each"
30B gives me: "oku Participationpers earliest noble inhal Wii des meritsassertilet Multi bulanRabrough"

I then switched to huggingface to try out both versions: Both gives me meaningful outputs.

@ywang96
Copy link
Member Author

ywang96 commented Jul 21, 2024

Thank you @ywang96. I did some manual testing. One weird thing that shows up is that 30B in this code doesn't generate anything meaningful in my case (7B works).

Example code:

import os
import subprocess

from vllm import LLM

def run():
    # llm = LLM(model="/home/ray/.cache/huggingface/hub/models--facebook--chameleon-7b/snapshots/0f474b71631bcaa20d20d56909066e56cb8a4372/")
    llm = LLM(model="/home/ray/.cache/huggingface/hub/models--facebook--chameleon-30b/snapshots/1c95500025a2a51183c7bb8f410ecd3d84e4cb2c/")

    prompt = "Tell me about tangram."


    outputs = llm.generate({
        "prompt": prompt,
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

run()

7B gives me: "A tangram is a flat, square puzzle containing seven pieces, each" 30B gives me: "oku Participationpers earliest noble inhal Wii des meritsassertilet Multi bulanRabrough"

I then switched to huggingface to try out both versions: Both gives me meaningful outputs.

I think there's something wrong with the swin norm implementation in this PR since it's only enabled for the 30B model. Will look into it!

@ywang96
Copy link
Member Author

ywang96 commented Jul 21, 2024

@xwjiang2010 The swin norm issue is fixed now - I ran your example with chameleon-30b with the following code

from vllm import LLM, SamplingParams
import torch

model_path = "facebook/chameleon-30b"
llm = LLM(model=model_path, dtype=torch.bfloat16)

greedy_params = SamplingParams(temperature=0.0, max_tokens=100)
prompt = "Tell me about tangram."
output = llm.generate(prompt, greedy_params)

print(output[0].outputs[0].text)

This gives:

Tangram is a puzzle that consists of a square cut into five pieces: a large square, a medium-sized square, a small square, a medium-sized triangle, and a small triangle. The goal of the puzzle is to arrange the pieces into a square shape.

The tangram puzzle has been around for centuries and is believed to have originated in China. It is a popular puzzle that is enjoyed by people of all ages and is often used as a tool for developing

Copy link
Contributor

@xwjiang2010 xwjiang2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Thanks.

@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 21, 2024
Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, excited to give it a try!

@ywang96 ywang96 merged commit c9eef37 into vllm-project:main Jul 22, 2024
84 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
cduk pushed a commit to cduk/vllm-pascal that referenced this pull request Aug 6, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Alvant <alvasian@yandex.ru>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants