Skip to content

Conversation

@aarnphm
Copy link
Collaborator

@aarnphm aarnphm commented Mar 11, 2025

Previously, we obtained vocab_size for xgrammar from hf_text_config directly.

However, in the recent version of xgrammar, the detected vocab_size now include special_tokens, in which raises the issue found in #14534

By calculating the vocab size, it ensures supporting custom tokenizers with the like of Olmo, etc.

@aarnphm aarnphm requested review from mgoin and russellb as code owners March 11, 2025 17:34
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use len(tokenizer), which is cached: https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer.py#L101

Or it might be better to use tokenizer.max_token_id. I have seen cases where the vocab size is actually larger than the number of tokens in the tokenizer since some were removed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need @Ubospica confirmation here.

@aarnphm
Copy link
Collaborator Author

aarnphm commented Mar 11, 2025

Will add tests once #14625 is merged

@aarnphm aarnphm force-pushed the v1/molmo-aria-support-vocab branch 2 times, most recently from 8fa2aa6 to a53c058 Compare March 12, 2025 00:20
@aarnphm aarnphm changed the title [V1][Core] calculating vocab_size from given tokenizer [V1][Core] using cached vocab_size for Structured Outputs Mar 12, 2025
@mergify
Copy link

mergify bot commented Mar 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @aarnphm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2025
aarnphm added 2 commits March 12, 2025 03:49
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@aarnphm aarnphm force-pushed the v1/molmo-aria-support-vocab branch from 689278a to 8872481 Compare March 12, 2025 07:50
@mergify mergify bot removed the needs-rebase label Mar 12, 2025
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()

tokenizer = tokenizer_group.get_lora_tokenizer(None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to this PR but it looks like structured output currently isn't compatible with custom vocab LoRAs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sg, will add to the list for compatibility

Fwiw i don't think it ever worked with LoRA

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 13, 2025
@aarnphm aarnphm added this to the v0.8.0 milestone Mar 13, 2025
@njhill njhill merged commit 8a4a2ef into vllm-project:main Mar 13, 2025
42 checks passed
@aarnphm aarnphm deleted the v1/molmo-aria-support-vocab branch March 13, 2025 22:35
russellb added a commit to russellb/vllm that referenced this pull request Mar 14, 2025
When testing with V1 structured output + Llama-3.1-8B-Instruct, the
changes made in vllm-project#14630 broke for me. I get the error:

```
ERROR 03-14 15:44:42 [core.py:337]   File "/home/rbryant/vllm/vllm/v1/structured_output/__init__.py", line 77, in _delayed_init
ERROR 03-14 15:44:42 [core.py:337]     tokenizer_info = xgr.TokenizerInfo.from_huggingface(
ERROR 03-14 15:44:42 [core.py:337]   File "/home/rbryant/vllm/venv/lib/python3.10/site-packages/xgrammar/tokenizer_info.py", line 184, in from_huggingface
ERROR 03-14 15:44:42 [core.py:337]     raise ValueError(msg)
ERROR 03-14 15:44:42 [core.py:337] ValueError: Input vocab_size less than minimum viable vocab size for tokenizer <class 'vllm.transformers_utils.tokenizer.get_cached_tokenizer.<locals>.CachedTokenizer'>.
ERROR 03-14 15:44:42 [core.py:337]
```

The vocab size was off by one. The max token ID is not == the vocab
size, since 0 is also a token ID. It's the max token ID + 1.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
…ct#14630)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Richard Liu <ricliu@google.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…ct#14630)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ct#14630)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed structured-output v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants