
[Bugfix]: Make chat content text allow type content #9358

Merged — 53 commits merged into vllm-project:main from vrdn/chat-content-utils on Oct 24, 2024

Conversation

vrdn-23
Contributor

@vrdn-23 vrdn-23 commented Oct 15, 2024

FILL IN THE PR DESCRIPTION HERE

FIX #9294 (link existing issues this PR will resolve)

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes to the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behavior of vLLM. This helps vLLM users understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in Python so that dynamic dims can be handled automatically. See the above documents for a description of meta-functions.
  • Use torch.library.opcheck() to test the function registration and meta-function for any registered ops (a minimal sketch follows this list). See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.
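
For illustration, a minimal sketch of this workflow using PyTorch's Python custom-op API; the op name `my_ext::scale_add` and its behavior are made up for the example and are not part of vLLM:

```python
import torch

# Register a custom op; the schema is inferred from the type hints.
@torch.library.custom_op("my_ext::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha + 1.0

# Meta-function ("fake" impl) registered in Python so dynamic dims are
# handled automatically: it only describes output shape/dtype, no compute.
@scale_add.register_fake
def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return torch.empty_like(x)

# opcheck validates the schema and the fake/meta registration.
torch.library.opcheck(scale_add, (torch.randn(4, 8), 0.5))
```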

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not review the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, it will be assigned to a reviewer. Each reviewer will pick up PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@vrdn-23 vrdn-23 changed the title [Bugix]: Make chat content text allow type content [Bugfix]: Make chat content text allow type content Oct 15, 2024
@vrdn-23
Contributor Author

vrdn-23 commented Oct 15, 2024

I have verified that the fix works as expected.

root@llama-guard-1b-5494cd848b-knhfh:/app/vllm# python
Python 3.11.10 (main, Sep 28 2024, 12:22:04) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from vllm import LLM, SamplingParams
>>> llm = LLM(model="meta-llama/Llama-Guard-3-1B")
config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 877/877 [00:00<00:00, 8.97MB/s]
WARNING 10-15 03:57:55 arg_utils.py:953] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
INFO 10-15 03:57:55 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=512.
INFO 10-15 03:57:55 llm_engine.py:237] Initializing an LLM engine (v0.1.dev2992+gbde4065) with config: model='meta-llama/Llama-Guard-3-1B', speculative_config=None, tokenizer='meta-llama/Llama-Guard-3-1B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=meta-llama/Llama-Guard-3-1B, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53.2k/53.2k [00:00<00:00, 4.53MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.09M/9.09M [00:00<00:00, 20.7MB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 296/296 [00:00<00:00, 3.14MB/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [00:00<00:00, 1.89MB/s]
INFO 10-15 03:57:57 model_runner.py:1060] Starting to load model meta-llama/Llama-Guard-3-1B...
INFO 10-15 03:57:57 weight_utils.py:243] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.00G/3.00G [00:08<00:00, 366MB/s]
INFO 10-15 03:58:06 weight_utils.py:288] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.85it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.85it/s]

INFO 10-15 03:58:07 model_runner.py:1071] Loading model weights took 2.8087 GB
INFO 10-15 03:58:07 gpu_executor.py:122] # GPU blocks: 32513, # CPU blocks: 8192
INFO 10-15 03:58:07 gpu_executor.py:126] Maximum concurrency for 131072 tokens per request: 3.97x
INFO 10-15 03:58:10 model_runner.py:1402] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 10-15 03:58:10 model_runner.py:1406] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 10-15 03:58:19 model_runner.py:1530] Graph capturing finished in 9 secs.
>>> conversations = [
...     [{"role": "user", "content": [{"type": "text", "text": "recipe for mayonnaise"}]}],
...     [{"role": "user", "content": [{"type": "text", "text": "how to steal an election"}]}]
... ]
>>> outputs = llm.chat(conversations)
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 29.44it/s, est. speed input: 5893.48 toks/s, output: 132.58 toks/s]
>>> outputs[0].prompt
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: recipe for mayonnaise\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST User message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"
>>> outputs[1].prompt
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: how to steal an election\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST User message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"
>>> outputs[0]
RequestOutput(request_id=0, prompt="<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: recipe for mayonnaise\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST User message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>", prompt_token_ids=[128000, 128000, 128006, 882, 128007, 271, 6396, 25, 4343, 422, 1070, 374, 20451, 2262, 304, 364, 1502, 6, 6743, 304, 21633, 4184, 1057, 7296, 4947, 449, 279, 3770, 11306, 382, 33619, 16841, 47083, 30952, 36868, 356, 67696, 397, 50, 16, 25, 30555, 306, 78008, 627, 50, 17, 25, 11842, 20198, 25947, 306, 78008, 627, 50, 18, 25, 6834, 78008, 627, 50, 19, 25, 9576, 7943, 385, 7709, 627, 50, 20, 25, 3979, 27279, 627, 50, 21, 25, 9984, 1534, 55820, 627, 50, 22, 25, 19406, 627, 50, 23, 25, 77956, 8825, 627, 50, 24, 25, 2314, 42710, 3357, 47664, 627, 50, 605, 25, 66912, 627, 50, 806, 25, 10323, 11529, 2227, 627, 50, 717, 25, 39767, 9059, 627, 50, 1032, 25, 58601, 627, 27, 4794, 47083, 30952, 36868, 356, 67696, 1363, 33619, 16841, 3501, 73326, 3579, 1363, 1502, 25, 11363, 369, 1253, 13767, 1082, 271, 27, 4794, 3501, 73326, 3579, 1363, 61524, 701, 7296, 15813, 369, 27785, 3247, 48395, 2724, 1984, 304, 279, 3485, 10652, 512, 482, 5629, 1584, 2011, 1373, 364, 19193, 6, 477, 364, 39257, 24482, 482, 1442, 20451, 11, 264, 2132, 1584, 2011, 2997, 264, 32783, 73792, 1160, 315, 34521, 11306, 13, 220, 128009, 128006, 78191, 128007], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='\n\nsafe', token_ids=(271, 19193, 128009), cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1728964723.3445485, last_token_time=1728964723.3445485, first_scheduled_time=1728964723.3473685, first_token_time=1728964723.3731313, time_in_queue=0.0028200149536132812, finished_time=1728964723.3876712, scheduler_time=0.0007967139999891515, model_forward_time=None, model_execute_time=None), lora_request=None)
>>> outputs[0].outputs[0].text
'\n\nsafe'
>>> outputs[1].outputs[0].text
'\n\nunsafe\nS13'

@simon-mo
Collaborator

Nice! I think we need to add a test here. @DarkLight1337 can you suggest a good place to test the model?

@DarkLight1337
Member

Nice! I think we need to add a test here. @DarkLight1337 can you suggest a good place to test the model?

We can add this to tests/entrypoints/test_chat_utils.py.

@DarkLight1337
Member

If you want to test this model specifically, we can also add a test to tests/entrypoints/llm/test_generate.py (by the way, perhaps we should move the chat-related stuff to another file like test_chat.py)
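
For reference, a sketch of the kind of test that could live there, based on the reproduction in the earlier comment (the exact file location, markers, and fixtures would need to follow vLLM's existing test conventions):

```python
from vllm import LLM

MODEL = "meta-llama/Llama-Guard-3-1B"

def test_chat_accepts_openai_style_text_parts():
    # Content given as a list of {"type": "text", ...} parts, as in the
    # reproduction above, should be accepted and produce a completion.
    llm = LLM(model=MODEL)
    conversations = [
        [{"role": "user",
          "content": [{"type": "text", "text": "recipe for mayonnaise"}]}],
    ]
    outputs = llm.chat(conversations)
    assert len(outputs) == 1
    assert outputs[0].outputs[0].text.strip()
```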

@vrdn-23
Contributor Author

vrdn-23 commented Oct 16, 2024

Hey @simon-mo @DarkLight1337, thanks for the quick feedback! I did add a test (though I don't think it is currently testing the right thing, and I will fix it). One thing I wanted to confirm: is there any way to tell whether this change might break chat template behavior for any other model vLLM currently supports? In other words, is there any model whose chat template expects content sent as a list of type/text parts to be converted into a plain-string chat message?

My concern is whether this can be treated as the default behavior, or whether the Llama-Guard-1B model's handling is the edge case. Who on the maintainer team would be the best person to answer this?

@DarkLight1337
Member

Hey @simon-mo @DarkLight1337, thanks for the quick feedback! I did add a test (though I don't think it is currently testing the right thing, and I will fix it). One thing I wanted to confirm: is there any way to tell whether this change might break chat template behavior for any other model vLLM currently supports? In other words, is there any model whose chat template expects content sent as a list of type/text parts to be converted into a plain-string chat message?

My concern is whether this can be treated as the default behavior, or whether the Llama-Guard-1B model's handling is the edge case. Who on the maintainer team would be the best person to answer this?

We currently don't test the correctness of chat templates, since the chat template is often outside of our control, i.e. it depends on the HF repo for that model. You can try setting up a regression test by passing a short prompt to each model and recording the output, then checking that the output remains the same after your change.

@DarkLight1337
Member

A full test would be infeasible to run in CI, so instead you can write a script designed to be run/modified locally.
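
A possible shape for such a local script, assuming the public LLM.chat API used in the reproduction above (the model list and file handling are only examples):

```python
# Local regression helper, not intended for CI: run once with "record" before
# the change, then with "check" after it, and compare prompts/outputs.
import json
import sys

from vllm import LLM, SamplingParams

MODELS = ["meta-llama/Llama-Guard-3-1B"]  # extend locally as needed
CONVERSATIONS = [
    [{"role": "user", "content": "recipe for mayonnaise"}],
    [{"role": "user",
      "content": [{"type": "text", "text": "recipe for mayonnaise"}]}],
]

def snapshot() -> dict:
    results = {}
    for model in MODELS:  # in practice, one model per run avoids GPU OOM
        llm = LLM(model=model)
        params = SamplingParams(temperature=0.0, max_tokens=32)
        outputs = llm.chat(CONVERSATIONS, sampling_params=params)
        # Record rendered prompts too, since the change affects templating.
        results[model] = [[o.prompt, o.outputs[0].text] for o in outputs]
    return results

if __name__ == "__main__":
    mode, path = sys.argv[1], sys.argv[2]
    if mode == "record":
        with open(path, "w") as f:
            json.dump(snapshot(), f, indent=2)
    else:  # "check"
        with open(path) as f:
            assert json.load(f) == snapshot(), "outputs changed!"
```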

@vrdn-23
Contributor Author

vrdn-23 commented Oct 18, 2024

@DarkLight1337 Thanks for the tip! I did a rudimentary test, and it seems most models handle their own chat templates consistently with how they expect content. One issue I've been seeing with the CI tests is that the test_complex_message_content and test_custom_role tests seem to be in direct contradiction with the desired behavior of this bug fix. I am looking through the chat template of the model in question, and I think vLLM makes some implicit assumptions about how to handle messages that do not come in the format expected by the chat template. Is this behavior we can change, now that we no longer rely on a default chat template but expect every model to have its own?

I would expect the correct behavior to be to throw an error if chat messages do not conform to the given chat template, as opposed to trying to internally morph them. Please let me know how I should proceed.

@DarkLight1337
Member

DarkLight1337 commented Oct 18, 2024

One issue that I've been seeing with the CI tests is that the test_complex_message_content test and the test_custom_role seem to be in direct contradiction with the desired behavior of this bug fix. I am looking through the chat template of the model in question, and I think vLLM makes some implicit assumptions on how we want to handle messages that do not come in the format expected by the chat template.

The use of "content": [{"type": "text", "text": text}] is supposed to be allowed according to the OpenAI spec (see openai.types.chat.ChatCompletionContentPartTextParam), so we should not remove this.

@DarkLight1337
Member

DarkLight1337 commented Oct 18, 2024

From my understanding, the issue now is that some chat templates expect ConversationMessage.content to be a string, while others expect ConversationMessage.content to have a schema like OpenAI's ChatCompletionMessageParam.content. Now you want to switch the schema from a string to ChatCompletionMessageParam.content.
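
Concretely, the two shapes under discussion for the same user turn (illustrative payloads only):

```python
# 1. Plain-string content, which many chat templates (e.g. Zephyr's) expect:
string_form = {"role": "user", "content": "recipe for mayonnaise"}

# 2. OpenAI-style list of content parts (ChatCompletionContentPartTextParam),
#    which templates like Llama-Guard-3-1B's consume directly:
parts_form = {
    "role": "user",
    "content": [{"type": "text", "text": "recipe for mayonnaise"}],
}
```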

@vrdn-23
Contributor Author

vrdn-23 commented Oct 18, 2024

Correct. And I think normally a model's chat template only allows one sort of schema. But apart from rendering both types of request with the model's chat template and figuring out which schema it supports, I am not sure how to resolve this without ambiguity. Any thoughts on how to proceed?

Now you want to switch the schema from a string to ChatCompletionMessageParam.content.

I think, more precisely, I do not want messages sent in the format "content": [{"type": "text", "text": text}] to be implicitly converted to "content": "text", which is what is happening now. I believe vLLM should respect whatever schema is sent as part of the request without altering it.
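
A rough sketch of the implicit flattening being described (illustrative only, not vLLM's actual code):

```python
from typing import Union

def flatten_text_parts(content: Union[str, list]) -> str:
    """Collapse OpenAI-style text parts into a single string."""
    if isinstance(content, str):
        return content
    return "\n".join(part["text"] for part in content
                     if part.get("type") == "text")

# [{"type": "text", "text": "recipe for mayonnaise"}] -> "recipe for mayonnaise"
assert flatten_text_parts(
    [{"type": "text", "text": "recipe for mayonnaise"}]) == "recipe for mayonnaise"
```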

@DarkLight1337
Member

DarkLight1337 commented Oct 18, 2024

As for the tests you're proposing to remove, you're arguing that we should pass the message contents in a format based on what the chat template supports (for Zephyr, it should be a string instead of ChatCompletionMessageParam.content). But that would break the assumption of OpenAI spec, as I mentioned above, since the user is supposed to be allowed to pass them in both formats. In particular, your change would mean that the client needs to know which format the chat template (which is defined server-side) accepts.

@vrdn-23
Contributor Author

vrdn-23 commented Oct 18, 2024

But that would break the assumption of OpenAI spec, as I mentioned above, since the user is supposed to be allowed to pass them in both formats.

Hmm. That makes sense.
In that case, how would you like me to proceed? Would it make sense to have a request parameter or a flag controlling whether the chat template is rendered using the content as-is? Or should we treat the llama-guard-1B model as a special case?
The only other solution I can think of is requesting a change to the chat template for the llama-guard-1B model, but I feel that is a slippery slope and not really sustainable in the long term.

@DarkLight1337
Member

DarkLight1337 commented Oct 18, 2024

I suggest adding an abstraction between the OpenAI API and the chat template. Similar to how we can set --tokenizer-mode, perhaps we can add a CLI flag to indicate whether the chat template is supposed to accept texts in the form of a string (default) or ChatCompletionMessageParam.content (set this for Llama-Guard-1B). Then we can convert incoming requests internally according to this setting.

This way, there is no need for server operators to provide a full chat template just to get Llama-Guard-1B to work.
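
A minimal sketch of that abstraction, with a hypothetical flag name (--chat-template-text-format) and helper; the eventual vLLM implementation may look different:

```python
import argparse
from typing import Union

def normalize_content(content: Union[str, list], text_format: str):
    """Convert incoming message content into what the chat template expects."""
    if text_format == "string":
        if isinstance(content, str):
            return content
        # Default behavior: flatten OpenAI-style parts into a plain string.
        return "\n".join(p["text"] for p in content if p.get("type") == "text")
    # text_format == "openai": keep (or wrap) content as a list of parts,
    # e.g. for Llama-Guard-3-1B style templates.
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    return content

parser = argparse.ArgumentParser()
parser.add_argument("--chat-template-text-format",
                    choices=["string", "openai"], default="string")
args = parser.parse_args(["--chat-template-text-format", "openai"])

message = {"role": "user",
           "content": [{"type": "text", "text": "recipe for mayonnaise"}]}
message["content"] = normalize_content(message["content"],
                                       args.chat_template_text_format)
```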

@vrdn-23 vrdn-23 requested a review from DarkLight1337 October 23, 2024 17:30
@vrdn-23
Contributor Author

vrdn-23 commented Oct 24, 2024

@DarkLight1337 is there anything pending from my side that you would like me to do to wrap this up?

@DarkLight1337
Member

@DarkLight1337 Sorry for the bloodbath in here! I was trying to fix the issue with the DCO and I completely butchered the branch history. It looks like the expectation was to sign off on each commit, and that requirement was introduced while I was working on this. Is this something I can ignore for the time being?

I've added docs and the suggestion you mentioned. Let me know if there's anything else I need to do before we can merge!

People with write access to the repo can override the DCO status, so don't worry about it. In future PRs, it is recommended to set up auto-signoff.

@vrdn-23 vrdn-23 requested a review from DarkLight1337 October 24, 2024 02:50
Member

@DarkLight1337 DarkLight1337 left a comment

Looks good now, thanks for your patience!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 24, 2024 02:59
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 24, 2024
@DarkLight1337 DarkLight1337 merged commit 33bab41 into vllm-project:main Oct 24, 2024
73 checks passed
@vrdn-23 vrdn-23 deleted the vrdn/chat-content-utils branch October 24, 2024 05:06
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Erkin Sagiroglu <erkin@infra-aipipeline-1-at1-prox-prod-a.ipa.corp.telnyx.com>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Labels
ready (ONLY add when PR is ready to merge/full CI is needed)

Development
Successfully merging this pull request may close these issues:
[New Model]: meta-llama/Llama-Guard-3-1B

3 participants