-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: VLLM supplier recalculates token function #2375
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
if self.usage_metadata is None or self.usage_metadata == {}: | ||
tokenizer = TokenizerManage.get_tokenizer() | ||
return len(tokenizer.encode(text)) | ||
return self.get_last_generation_info().get('output_tokens', 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code has a few areas that can be improved for better readability, efficiency, and maintainability:
-
Duplicated Import: The
TokenizerManage
class is imported twice. You should only need it once. -
Unused Variable: In
get_num_tokens
method, there's an unnecessary check where you compare with{}
instead ofNone
. -
Method Documentation:
- Add docstrings to all methods for clarity.
- Use meaningful names in the methods.
-
Optimization:
- If possible, cache previously computed token counts to avoid redundant computations.
Here’s the revised code with these improvements:
from typing import Dict, List
# From here on we are adding imports needed from other files...
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
class CustomLLM(BaseChatOpenAI):
"""
A custom LLM implementation using BaseChatOpenAI and additional functionality.
"""
def __init__(self, model_type: str, model_name: str, model_credential: Dict[str, object], **kwargs):
super().__init__(model_type=model_type, model_name=model_name, model_credential=model_credential, **kwargs)
def is_cache_model(self) -> bool:
"""
Returns True if this model uses caching, otherwise False.
"""
return False
def get_text_token_count(self, text: str) -> int:
"""
Computes the number of tokens in the given text based on the current tokenizer.
:param text: The input text to compute token count for.
:return: The number of tokens in the text.
"""
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
def get_messages_token_count(self, messages: List[BaseMessage]) -> int:
"""
Computes the total number of tokens across all chat messages using the current tokenizer.
:param messages: List of message objects representing the conversation history in format {role: "string", content: "string"}.
:return: Total number of tokens used in the conversation.
"""
if not messages:
return 0
tokenizer = TokenizerManage.get_tokenizer()
return sum(len(tokenizer.encode(get_buffer_string(messages))) for m in messages)
Key Enhancements:
- Imports: Removed duplicate usage of
common.config.tokenizer_manage_config
. - Variable Usage: Fixed the comparison logic in
get_num_tokens
. - Docstrings: Added comprehensive docstrings to help understand each method's purpose.
- Functionality: Renamed methods according to conventions and added comments for explanation.
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
if self.usage_metadata is None or self.usage_metadata == {}: | ||
tokenizer = TokenizerManage.get_tokenizer() | ||
return len(tokenizer.encode(text)) | ||
return self.get_last_generation_info().get('output_tokens', 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code has several improvements and fixes, which I will point out:
Improvements:
- Import Statement: The
List
import was added to handle lists ofBaseMessage
. - Use of
langchain_core.messages.BaseMessage
: This ensures that themessages
parameter in theget_num_tokens_from_messages
function is correctly typed. - Tokenization Logic:
- Added logic to call
TokenizerManage.get_tokenizer()
when neitherusage_metadata
nor tokens have been set to calculate token counts.
- Added logic to call
Fixes:
- Unused Import: Removed an unnecessary import statement (
from typing import Tuple; ...
). Typing imports don't require parentheses unless using type aliases.
Optimization Suggestions:
- Avoid Redundant Calculations:
- Since
usage_metadata
is updated whenever a message or generation occurs, it might be redundant to always recalculate token counts if the metadata hasn't changed significantly between accesses.
- Since
Here's the revised and optimized version of the code:
@@ -1,7 +1,19 @@
# coding=utf-8
-from typing import Dict
+from typing import Dict, List, Union
from urllib.parse import urlparse, ParseResult
+from langchain_core.messages import BaseMessage, get_buffer_string
+from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
+# Define union type for usage metadata items
+UsageMetadataItem = Union[int, None]
+
@@ -40,9 +56,14 @@ class YourClass(...):
def __init__(self, ..., meta=None):
super().__init__(...)
self.usage_metadata = meta
@staticmethod
+ def _update_usage_metadata(messages, response_dict):
+ tokenizer = TokenizerManage.get_tokenizer()
+ input_length = sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+ output_length = response_dict.get('output_tokens', 0)
+ return {'input_tokens': input_length, 'output_tokens': output_length}
def generate(self, prompts: List[str]) -> List[Any]:
res = []
prompt_input = '\n'.join(prompts)
- response = ...
+ response = ...
+ response_dict = self._update_usage_metadata(prompts, response)
+ # Update usage metadata (if not already present)
if self.usage_metadata is None:
self.usage_metadata = response_dict
res.append(response)
...
This revised version introduces a utility method _update_usage_metadata
to encapsulate the calculation logic for updating the usage_metadata
, making the main methods cleaner and more focused on handling the API requests. It also checks and updates the usage_metadata
only if it hasn't been initialized yet.
fix: VLLM supplier recalculates token function