From 07e8cf1d9a76f56d2a8550222867e9db3046f9ad Mon Sep 17 00:00:00 2001 From: Duc Pham Date: Fri, 10 Nov 2023 01:26:13 +0700 Subject: [PATCH 1/3] Improved trimming logic and OpenAI token counter --- litellm/tests/test_utils.py | 40 ++++++++++++-- litellm/utils.py | 106 ++++++++++++++++++++++++++++++------ 2 files changed, 122 insertions(+), 24 deletions(-) diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index 0e325b7381ce..ac8ab11343c1 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -1,6 +1,6 @@ import sys, os -import traceback from dotenv import load_dotenv +import copy load_dotenv() import os @@ -38,7 +38,7 @@ def test_multiple_messages_trimming(): {"role": "user", "content": "This is a long message that will exceed the token limit."}, {"role": "user", "content": "This is another long message that will also exceed the limit."} ] - trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20) + trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=20) # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20 # test_multiple_messages_trimming() @@ -48,7 +48,7 @@ def test_multiple_messages_no_trimming(): {"role": "user", "content": "This is a long message that will exceed the token limit."}, {"role": "user", "content": "This is another long message that will also exceed the limit."} ] - trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100) + trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=100) print("Trimmed messages") print(trimmed_messages) assert(messages==trimmed_messages) @@ -56,14 +56,42 @@ def test_multiple_messages_no_trimming(): # test_multiple_messages_no_trimming() -def test_large_trimming(): +def test_large_trimming_multiple_messages(): messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}] - trimmed_messages = trim_messages(messages, max_tokens=20, model="random") + trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613") print("trimmed messages") print(trimmed_messages) - assert(get_token_count(messages=trimmed_messages, model="random")) <= 20 + assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20 # test_large_trimming() +def test_large_trimming_single_message(): + messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}] + trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613") + assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5 + assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0 + + +def test_trimming_with_system_message_within_max_tokens(): + # This message is 33 tokens long + messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] + trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit + assert len(trimmed_messages) 
== 2 + assert trimmed_messages[0]["content"] == "This is a short system message" + + +def test_trimming_with_system_message_exceeding_max_tokens(): + # This message is 33 tokens long. The system message is 13 tokens long. + messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] + trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613") + assert len(trimmed_messages) == 1 + assert '..' in trimmed_messages[0]["content"] + +def test_trimming_should_not_change_original_messages(): + messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] + messages_copy = copy.deepcopy(messages) + trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613") + assert(messages==messages_copy) + def test_get_valid_models(): old_environ = os.environ os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ diff --git a/litellm/utils.py b/litellm/utils.py index 6b83138fe50f..11ab03472c38 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -19,6 +19,7 @@ import aiohttp import logging import asyncio +import copy from tokenizers import Tokenizer from dataclasses import ( dataclass, @@ -1101,6 +1102,50 @@ def decode(model: str, tokens: List[int]): dec = tokenizer_json["tokenizer"].decode(tokens) return dec +def openai_token_counter(messages, model="gpt-3.5-turbo-0613"): + """ + Return the number of tokens used by a list of messages. + + Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb. + """ + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + if model in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model: + print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return openai_token_counter(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + return openai_token_counter(messages, model="gpt-4-0613") + else: + raise NotImplementedError( + f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens + def token_counter(model="", text=None, messages: Optional[List] = None): """ Count the number of tokens in a given text using a specified model. 
@@ -1121,14 +1166,17 @@ def token_counter(model="", text=None, messages: Optional[List] = None): raise ValueError("text and messages cannot both be None") num_tokens = 0 - if model is not None: + if model is not None: tokenizer_json = _select_tokenizer(model=model) if tokenizer_json["type"] == "huggingface_tokenizer": enc = tokenizer_json["tokenizer"].encode(text) num_tokens = len(enc.ids) elif tokenizer_json["type"] == "openai_tokenizer": - enc = tokenizer_json["tokenizer"].encode(text) - num_tokens = len(enc) + if messages is not None: + num_tokens = openai_token_counter(messages, model=model) + else: + enc = tokenizer_json["tokenizer"].encode(text) + num_tokens = len(enc) else: num_tokens = len(encoding.encode(text)) return num_tokens @@ -4429,7 +4477,7 @@ def completion_with_config(config: Union[dict, str], **kwargs): except: continue if prompt_larger_than_model: - messages = trim_messages(messages=messages, model=max_model) + messages = trim_messages(messages_copy=messages, model=max_model) kwargs["messages"] = messages kwargs["model"] = model @@ -4528,13 +4576,13 @@ def completion_with_fallbacks(**kwargs): def process_system_message(system_message, max_tokens, model): system_message_event = {"role": "system", "content": system_message} - system_message_tokens = get_token_count(system_message_event, model) + system_message_tokens = get_token_count([system_message_event], model) if system_message_tokens > max_tokens: print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...") # shorten system message to fit within max_tokens new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model) - system_message_tokens = get_token_count(new_system_message, model) + system_message_tokens = get_token_count([new_system_message], model) return system_message_event, max_tokens - system_message_tokens @@ -4544,11 +4592,15 @@ def process_messages(messages, max_tokens, model): final_messages = [] for message in messages: - final_messages = attempt_message_addition(final_messages, message, max_tokens, model) + used_tokens = get_token_count(final_messages, model) + available_tokens = max_tokens - used_tokens + if available_tokens <= 3: + break + final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model) return final_messages -def attempt_message_addition(final_messages, message, max_tokens, model): +def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model): temp_messages = [message] + final_messages temp_message_tokens = get_token_count(messages=temp_messages, model=model) @@ -4558,7 +4610,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model): # if temp_message_tokens > max_tokens, try shortening temp_messages elif "function_call" not in message: # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens) - updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model) + updated_message = shorten_message_to_fit_limit(message, available_tokens, model) if can_add_message(updated_message, final_messages, max_tokens, model): return [updated_message] + final_messages @@ -4580,6 +4632,13 @@ def shorten_message_to_fit_limit( """ Shorten a message to fit within a token limit by removing characters from the middle. 
""" + + # For OpenAI models, even blank messages cost 7 token, + # and if the buffer is less than 3, the while loop will never end, + # hence the value 10. + if 'gpt' in model and tokens_needed <= 10: + return message + content = message["content"] while True: @@ -4607,7 +4666,7 @@ def shorten_message_to_fit_limit( # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # Credits for this code go to Killian Lucas def trim_messages( - messages, + messages_copy, model: Optional[str] = None, trim_ratio: float = 0.75, return_response_tokens: bool = False, @@ -4628,6 +4687,7 @@ def trim_messages( """ # Initialize max_tokens # if users pass in max tokens, trim to this amount + messages_copy = copy.deepcopy(messages_copy) try: print_verbose(f"trimming messages") if max_tokens == None: @@ -4642,33 +4702,43 @@ def trim_messages( return system_message = "" - for message in messages: + for message in messages_copy: if message["role"] == "system": + system_message += '\n' if system_message else '' system_message += message["content"] - current_tokens = token_counter(model=model, messages=messages) + current_tokens = token_counter(model=model, messages=messages_copy) print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") # Do nothing if current tokens under messages if current_tokens < max_tokens: - return messages + return messages_copy #### Trimming messages if current_tokens > max_tokens - print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}") + print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}") if system_message: system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) - messages = messages + [system_message_event] - final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model) + if max_tokens == 0: # the system messages are too long + return [system_message_event] + + # Since all system messages are combined and trimmed to fit the max_tokens, + # we remove all system messages from the messages list + messages_copy = [message for message in messages_copy if message["role"] != "system"] + + final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model) + + # Add system message to the beginning of the final messages + if system_message: + final_messages = [system_message_event] + final_messages if return_response_tokens: # if user wants token count with new trimmed messages response_tokens = max_tokens - get_token_count(final_messages, model) return final_messages, response_tokens - return final_messages except Exception as e: # [NON-Blocking, if error occurs just return final_messages print_verbose(f"Got exception while token trimming{e}") - return messages + return messages_copy def get_valid_models(): """ From eeac3954d55a96bc00f9538b44d446704e547c5d Mon Sep 17 00:00:00 2001 From: Duc Pham Date: Fri, 10 Nov 2023 01:35:41 +0700 Subject: [PATCH 2/3] Reverted error while refactoring --- litellm/tests/test_utils.py | 4 ++-- litellm/utils.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index ac8ab11343c1..22a4af93c3fa 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -38,7 +38,7 @@ def test_multiple_messages_trimming(): {"role": "user", "content": "This is a long 
message that will exceed the token limit."}, {"role": "user", "content": "This is another long message that will also exceed the limit."} ] - trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=20) + trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20) # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20 # test_multiple_messages_trimming() @@ -48,7 +48,7 @@ def test_multiple_messages_no_trimming(): {"role": "user", "content": "This is a long message that will exceed the token limit."}, {"role": "user", "content": "This is another long message that will also exceed the limit."} ] - trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=100) + trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100) print("Trimmed messages") print(trimmed_messages) assert(messages==trimmed_messages) diff --git a/litellm/utils.py b/litellm/utils.py index 11ab03472c38..2b797ec44c38 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4477,7 +4477,7 @@ def completion_with_config(config: Union[dict, str], **kwargs): except: continue if prompt_larger_than_model: - messages = trim_messages(messages_copy=messages, model=max_model) + messages = trim_messages(messages=messages, model=max_model) kwargs["messages"] = messages kwargs["model"] = model @@ -4666,7 +4666,7 @@ def shorten_message_to_fit_limit( # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # Credits for this code go to Killian Lucas def trim_messages( - messages_copy, + messages, model: Optional[str] = None, trim_ratio: float = 0.75, return_response_tokens: bool = False, @@ -4687,7 +4687,7 @@ def trim_messages( """ # Initialize max_tokens # if users pass in max tokens, trim to this amount - messages_copy = copy.deepcopy(messages_copy) + messages_copy = copy.deepcopy(messages) try: print_verbose(f"trimming messages") if max_tokens == None: From 8e13da198cdd5fee88903f854a611ca9ba9a2085 Mon Sep 17 00:00:00 2001 From: Duc Pham Date: Fri, 10 Nov 2023 01:47:06 +0700 Subject: [PATCH 3/3] Another small refactoring --- litellm/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index 2b797ec44c38..821e26435e28 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4687,7 +4687,7 @@ def trim_messages( """ # Initialize max_tokens # if users pass in max tokens, trim to this amount - messages_copy = copy.deepcopy(messages) + messages = copy.deepcopy(messages) try: print_verbose(f"trimming messages") if max_tokens == None: @@ -4702,20 +4702,20 @@ def trim_messages( return system_message = "" - for message in messages_copy: + for message in messages: if message["role"] == "system": system_message += '\n' if system_message else '' system_message += message["content"] - current_tokens = token_counter(model=model, messages=messages_copy) + current_tokens = token_counter(model=model, messages=messages) print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") # Do nothing if current tokens under messages if current_tokens < max_tokens: - return messages_copy + return messages #### Trimming messages if current_tokens > max_tokens - print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}") + print_verbose(f"Need to 
trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}") if system_message: system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) @@ -4724,9 +4724,9 @@ def trim_messages( # Since all system messages are combined and trimmed to fit the max_tokens, # we remove all system messages from the messages list - messages_copy = [message for message in messages_copy if message["role"] != "system"] + messages = [message for message in messages if message["role"] != "system"] - final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model) + final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model) # Add system message to the beginning of the final messages if system_message: @@ -4738,7 +4738,7 @@ def trim_messages( return final_messages except Exception as e: # [NON-Blocking, if error occurs just return final_messages print_verbose(f"Got exception while token trimming{e}") - return messages_copy + return messages def get_valid_models(): """
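
Note (illustration only, not part of the patches): the new tests above pin down the reworked `trim_messages` behavior — system messages are combined, trimmed if needed, and re-attached at the front of the result, and the caller's list is deep-copied rather than mutated. A minimal usage sketch of that behavior, assuming the same `litellm.utils` import path the test file uses:

    import copy
    from litellm.utils import trim_messages, get_token_count  # import path assumed from the tests

    messages = [
        {"role": "system", "content": "This is a short system message"},
        {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."},
    ]
    original = copy.deepcopy(messages)

    # Budget large enough for the system message: both messages survive,
    # the user message is shortened, and the system message stays first.
    trimmed = trim_messages(messages, max_tokens=30, model="gpt-4-0613")
    assert len(trimmed) == 2
    assert trimmed[0]["content"] == "This is a short system message"
    assert get_token_count(messages=trimmed, model="gpt-4-0613") <= 30

    # trim_messages now deep-copies its input, so the caller's list is unchanged.
    assert messages == original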