Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incrementally decode output tokens #121

Merged
merged 9 commits into from
May 24, 2023
Merged

Incrementally decode output tokens #121

merged 9 commits into from
May 24, 2023

Conversation

WoosukKwon
Copy link
Collaborator

@WoosukKwon WoosukKwon commented May 23, 2023

Fixes #119

This PR reduces the overhead of the tokenizer by incrementally decoding the output tokens at every step. Specifically, the detokenization process is split into two stages:

  1. Convert token ids to token strings (i.e., List[int] -> List[str])
  2. Concatenate the token strings to the output text (i.e., List[str] -> str)

In the new method introduced by this PR, the first stage is performed incrementally: the new token id is decoded into a string token and is appended to the list of previous string output tokens. Meanwhile, this PR does not change the second stage: the tokenizer will transform the whole output tokens into the output text at every step. IIUC, this stage cannot be performed incrementally.

# opt-13b inference latency (bs 8, input 32, output 128)
# Main
Avg latency: 3.57 seconds
Tokenizer (fast): 0.14 seconds
# This PR
Avg latency: 3.47 seconds
Tokenizer (fast): 0.03 seconds

# llama-13b inference latency (bs 8, input 32, output 128)
# Main
Avg latency: 5.28 seconds
Tokenizer (slow): 1.97 seconds
# This PR
Avg latency: 3.77 seconds
Tokenizer: 0.48 seconds

Note that there are still some overheads from the tokenizer in case of LLaMA.

@WoosukKwon WoosukKwon requested a review from zhuohan123 May 23, 2023 02:02
@WoosukKwon
Copy link
Collaborator Author

@zhuohan123 This PR is ready for review.

Comment on lines 57 to 58
if not (hasattr(tokenizer, "added_tokens_encoder") and
tokenizer.added_tokens_encoder):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
if not (hasattr(tokenizer, "added_tokens_encoder") and
tokenizer.added_tokens_encoder):
if not getattr(tokenizer, "added_tokens_encoder", None):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition, this optimization seems non-necessary? If the tokenizer has an empty added_tokens_encoder, the code below also only calls convert_tokens_to_string once on the final current_sub_text.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the simplification. Changed the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition, this optimization seems non-necessary? If the tokenizer has an empty added_tokens_encoder, the code below also only calls convert_tokens_to_string once on the final current_sub_text.

@zhuohan123 Surprisingly, the optimization does reduce the latency from 4.4 s to 3.8 s. The reason is the python loop over output_tokens takes some time even if the loop body is very simple. I've added more comments on this.

cacheflow/server/tokenizer_utils.py Outdated Show resolved Hide resolved
@WoosukKwon WoosukKwon requested a review from zhuohan123 May 24, 2023 01:19
Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@WoosukKwon WoosukKwon merged commit e867178 into main May 24, 2023
@WoosukKwon WoosukKwon deleted the fix-tokenizer branch May 24, 2023 03:46
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
yukavio pushed a commit to yukavio/vllm that referenced this pull request Jul 3, 2024
SUMMARY:
Turns back on the marlin tests. Issue was that vllm was not properly
tearing itself down. Calling the gc explicitly seems to have resolved
this in the short term.

In general, we should get to the bottom of why vllm does not shut down
cleanly.

TEST PLAN:
Automation
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Aug 15, 2024
mht-sharma added a commit to mht-sharma/vllm that referenced this pull request Aug 15, 2024
* Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters (vllm-project#114)

* Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters

* Adding HTTP headers

* Add distributed executor backend to benchmark scripts (vllm-project#118)

* Add weight padding for moe (vllm-project#119)

* add weight padding for moe

* enable padding by default

* fix linter

* fix linter

* fix linter

* using envs.py

* fix linter

* [BugFix] Fix navi build after many custom for MI kernels added (vllm-project#116)

* fix navi build

* Created dummy kernels of unsupported on Navi to avoid function not found crashes at runtime

* replacing ifdefs on host code with those on kernels

* refactoring code to avoid unsupported call on Navi

* syntactic change

* import statements fix

* moving env variables to envs.py

* style fixes

* cosmetic changes for isort

* remved extra include

* moving use_skinny to be member

---------

Co-authored-by: lcskrishna <lollachaitanya@gmail.com>
Co-authored-by: maleksan85 <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* add emtpy_cache() after each padding (vllm-project#120)

* [FIX] Gradlib OOM on Navi and sometimes on MI (vllm-project#124)

* add memory clean up after every shape and parameter to reduce cache invalidation buffers

* small typo

* syntax change

---------

Co-authored-by: maleksan85 <maleksan@amd.com>

* save shape when fp8 solution not found (vllm-project#123)

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* Fix unit test for moe by adding padding (vllm-project#128)

* fix test_moe

* fix linter

* Llama3.1 (vllm-project#129)

* Add support for a rope extension method (vllm-project#6553)

* [BugFix] Fix RoPE error in Llama 3.1 (vllm-project#6693)

---------

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* chat/completions endpoint (vllm-project#121)

* Initial implementation of chat/completions endpoint and its streaming variant

* Reusing datatypes from the openai entrypoints

* Response role from arg

* Added models endpoint and model validation from the request

* Optimize custom all reduce (vllm-project#130)

* First version

* Revert error.

While there, add missing finalize.

* Use the correct defaults for ROCm.

Increase sampling area to capture crossover.

* Scope end_sync as well.

* Guard only volatile keyword for ifndef USE_ROCM

* Document crossover

* Add BF16 support to custom PA (vllm-project#133)

* tightened atol for custom PA; enable supported head size, block sizes in testing

* update num_blocks and num_iters in benchmark PA to realistic settings

* move to generic b16 type

* bf16 first port

* enabled all bf16 tests, set atol for bf16

* enable custom PA for bf16 as well as block size 32 and head size 64

* fix cast to zero in custom PA reduce

* py linter fixes

* clang format fixes

* div round up clang-format

---------

Co-authored-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* Making check for output match in original types. It saves some memory. (vllm-project#135)

Co-authored-by: maleksan85 <maleksan@amd.com>

* Make CAR ROCm 6.1 compatible. (vllm-project#137)

* remove scoping
* while there fix a typo
* while there remove unused variable

* Car revert (vllm-project#140)

* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Make CAR ROCm 6.1 compatible. (vllm-project#137)"

This reverts commit 4d2dda6.

* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Optimize custom all reduce (vllm-project#130)"

This reverts commit 636ff01.

---------

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Matt Wong <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: lcskrishna <lollachaitanya@gmail.com>
Co-authored-by: maleksan85 <maleksan@amd.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: iotamudelta <dieterich@ogolem.org>
Co-authored-by: sanyalington <shomy.sanyal@amd.com>
mht-sharma added a commit to mht-sharma/vllm that referenced this pull request Aug 21, 2024
* Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters (vllm-project#114)

* Fixed single GPU issue without setting up mp. Added toggles for server request batching parameters

* Adding HTTP headers

* Add distributed executor backend to benchmark scripts (vllm-project#118)

* Add weight padding for moe (vllm-project#119)

* add weight padding for moe

* enable padding by default

* fix linter

* fix linter

* fix linter

* using envs.py

* fix linter

* [BugFix] Fix navi build after many custom for MI kernels added (vllm-project#116)

* fix navi build

* Created dummy kernels of unsupported on Navi to avoid function not found crashes at runtime

* replacing ifdefs on host code with those on kernels

* refactoring code to avoid unsupported call on Navi

* syntactic change

* import statements fix

* moving env variables to envs.py

* style fixes

* cosmetic changes for isort

* remved extra include

* moving use_skinny to be member

---------

Co-authored-by: lcskrishna <lollachaitanya@gmail.com>
Co-authored-by: maleksan85 <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* add emtpy_cache() after each padding (vllm-project#120)

* [FIX] Gradlib OOM on Navi and sometimes on MI (vllm-project#124)

* add memory clean up after every shape and parameter to reduce cache invalidation buffers

* small typo

* syntax change

---------

Co-authored-by: maleksan85 <maleksan@amd.com>

* save shape when fp8 solution not found (vllm-project#123)

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* Fix unit test for moe by adding padding (vllm-project#128)

* fix test_moe

* fix linter

* Llama3.1 (vllm-project#129)

* Add support for a rope extension method (vllm-project#6553)

* [BugFix] Fix RoPE error in Llama 3.1 (vllm-project#6693)

---------

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* chat/completions endpoint (vllm-project#121)

* Initial implementation of chat/completions endpoint and its streaming variant

* Reusing datatypes from the openai entrypoints

* Response role from arg

* Added models endpoint and model validation from the request

* Optimize custom all reduce (vllm-project#130)

* First version

* Revert error.

While there, add missing finalize.

* Use the correct defaults for ROCm.

Increase sampling area to capture crossover.

* Scope end_sync as well.

* Guard only volatile keyword for ifndef USE_ROCM

* Document crossover

* Add BF16 support to custom PA (vllm-project#133)

* tightened atol for custom PA; enable supported head size, block sizes in testing

* update num_blocks and num_iters in benchmark PA to realistic settings

* move to generic b16 type

* bf16 first port

* enabled all bf16 tests, set atol for bf16

* enable custom PA for bf16 as well as block size 32 and head size 64

* fix cast to zero in custom PA reduce

* py linter fixes

* clang format fixes

* div round up clang-format

---------

Co-authored-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* Making check for output match in original types. It saves some memory. (vllm-project#135)

Co-authored-by: maleksan85 <maleksan@amd.com>

* Make CAR ROCm 6.1 compatible. (vllm-project#137)

* remove scoping
* while there fix a typo
* while there remove unused variable

* Car revert (vllm-project#140)

* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Make CAR ROCm 6.1 compatible. (vllm-project#137)"

This reverts commit 4d2dda6.

* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Optimize custom all reduce (vllm-project#130)"

This reverts commit 636ff01.

* Using the correct datatypes for streaming non-chat completions (vllm-project#134)

* Adding UNREACHABLE_CODE macro for non MI300 and MI250 cards (vllm-project#138)

* Adding UNREACHABLE_CODE macro

* clang format fixes

* clang formatting fix

* minor updates in syntax

* clang format update

* clang format fix one more try

* clang format one more try

* clang format fix one more try

---------

Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

* gfx90a typo fix (vllm-project#142)

Co-authored-by: maleksan85 <maleksan@amd.com>

* wvsplitk templatized and better tuned for MI300 (vllm-project#132)

* improvements to wvSpltK

* wvsplt gemm; better handle MI300 and large A[] sizes

* lint fix

* Adjustments to better handle small weights in TP8.

* early-out bug fix

* better wave load balancing in wvSplt

* add missing skip for wvsplt_big

* Bug fix for wvSplt_big in load balancing at M4, lint fix.

* [Bugfix] Dockerfile.rocm (vllm-project#141)

* Dockerfile.rocm bug fix

* naming preference

---------

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

* Update test-template.j2 (vllm-project#145)

* Adding Triton implementations awq_dequantize and awq_gemm to ROCm (vllm-project#136)

* basic support for AWQ added
* awq_dequantize implementation in Triton
* awq_gemm implementation in Triton
* unit tests in tests/kernels/test_awq_triton.py

---------

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Matt Wong <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: lcskrishna <lollachaitanya@gmail.com>
Co-authored-by: maleksan85 <maleksan@amd.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: iotamudelta <dieterich@ogolem.org>
Co-authored-by: sanyalington <shomy.sanyal@amd.com>
Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Co-authored-by: Zachary Streeter <90640993+zstreet87@users.noreply.github.com>
Co-authored-by: omkar kakarparthi <75638701+okakarpa@users.noreply.github.com>
Co-authored-by: rasmith <Randall.Smith@amd.com>
mht-sharma pushed a commit to mht-sharma/vllm that referenced this pull request Oct 30, 2024
* Initial implementation of chat/completions endpoint and its streaming variant

* Reusing datatypes from the openai entrypoints

* Response role from arg

* Added models endpoint and model validation from the request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Tokenizer overhead is significant when use_fast=False
2 participants