Improve message trimming #787

Merged · 3 commits · Nov 11, 2023
36 changes: 32 additions & 4 deletions litellm/tests/test_utils.py
@@ -1,6 +1,6 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
+import copy

 load_dotenv()
 import os
@@ -56,14 +56,42 @@ def test_multiple_messages_no_trimming():
# test_multiple_messages_no_trimming()


-def test_large_trimming():
+def test_large_trimming_multiple_messages():
     messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
-    trimmed_messages = trim_messages(messages, max_tokens=20, model="random")
+    trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613")
     print("trimmed messages")
     print(trimmed_messages)
-    assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20
 # test_large_trimming()

+def test_large_trimming_single_message():
+    messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
+    trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613")
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0
+
+
+def test_trimming_with_system_message_within_max_tokens():
+    # This message is 33 tokens long
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit
+    assert len(trimmed_messages) == 2
+    assert trimmed_messages[0]["content"] == "This is a short system message"
+
+
+def test_trimming_with_system_message_exceeding_max_tokens():
+    # This message is 33 tokens long. The system message is 13 tokens long.
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert len(trimmed_messages) == 1
+    assert '..' in trimmed_messages[0]["content"]
+
+def test_trimming_should_not_change_original_messages():
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    messages_copy = copy.deepcopy(messages)
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert(messages==messages_copy)
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ
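The new tests cover multi-message trimming, single-message trimming, system-message handling, and input immutability. Following the file's own convention of commented-out direct calls, they can be smoke-run without a test runner, assuming the module's imports (trim_messages, get_token_count, copy) resolve:

    # run the new trimming tests directly, mirroring the file's commented-out calls
    test_large_trimming_single_message()
    test_trimming_with_system_message_within_max_tokens()
    test_trimming_with_system_message_exceeding_max_tokens()
    test_trimming_should_not_change_original_messages()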
90 changes: 80 additions & 10 deletions litellm/utils.py
@@ -19,6 +19,7 @@
 import aiohttp
 import logging
 import asyncio
+import copy
 from tokenizers import Tokenizer
 from dataclasses import (
     dataclass,
@@ -1101,6 +1102,50 @@ def decode(model: str, tokens: List[int]):
     dec = tokenizer_json["tokenizer"].decode(tokens)
     return dec

+def openai_token_counter(messages, model="gpt-3.5-turbo-0613"):
+    """
+    Return the number of tokens used by a list of messages.
+
+    Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return openai_token_counter(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return openai_token_counter(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
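To make the bookkeeping concrete: under these rules, the tests' "This is a short system message" costs 13 tokens on gpt-4-0613. A quick sanity check, not part of this PR; it assumes tiktoken is installed (litellm already uses it for OpenAI models) and that the counts reflect the cl100k_base encoding as of this writing:

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-4-0613")
    content_tokens = len(enc.encode("This is a short system message"))  # 6 under cl100k_base
    role_tokens = len(enc.encode("system"))                             # 1
    # tokens_per_message + role + content + reply priming
    print(3 + role_tokens + content_tokens + 3)  # 13, matching the test comments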
Review comments on openai_token_counter:

Contributor:

can we just have one token counter? Can you add these changes to token_counter?

Contributor:

just requesting this one change ^

unless there's a strong reason to have two token counters? @duc-phamh

Author (@thiel-ph), Nov 10, 2023:

The openai_token_counter recursively calls itself. Of course we can refactor it into one token counter, but leaving the OpenAI code separate like this makes it easier to apply future changes in case OpenAI updates this function. Therefore, I believe it's better to leave it this way.

Please let me know if you think otherwise.


 def token_counter(model="", text=None, messages: Optional[List] = None):
     """
     Count the number of tokens in a given text using a specified model.
@@ -1121,14 +1166,17 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
         raise ValueError("text and messages cannot both be None")
     num_tokens = 0

     if model is not None:
         tokenizer_json = _select_tokenizer(model=model)
         if tokenizer_json["type"] == "huggingface_tokenizer":
             enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc.ids)
         elif tokenizer_json["type"] == "openai_tokenizer":
-            enc = tokenizer_json["tokenizer"].encode(text)
-            num_tokens = len(enc)
+            if messages is not None:
+                num_tokens = openai_token_counter(messages, model=model)
+            else:
+                enc = tokenizer_json["tokenizer"].encode(text)
+                num_tokens = len(enc)
     else:
         num_tokens = len(encoding.encode(text))
     return num_tokens
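With this dispatch, message lists for OpenAI models are counted with per-message overhead via openai_token_counter, while plain text is still encoded directly. A usage sketch; the top-level import from litellm's public API is assumed:

    from litellm import token_counter

    msgs = [{"role": "user", "content": "hello"}]
    print(token_counter(model="gpt-4-0613", messages=msgs))  # chat-format count, includes message framing tokens
    print(token_counter(model="gpt-4-0613", text="hello"))   # raw encoding of the text, no framing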
@@ -4528,13 +4576,13 @@ def completion_with_fallbacks(**kwargs):

 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
-    system_message_tokens = get_token_count(system_message_event, model)
+    system_message_tokens = get_token_count([system_message_event], model)

     if system_message_tokens > max_tokens:
         print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
         # shorten system message to fit within max_tokens
         new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
-        system_message_tokens = get_token_count(new_system_message, model)
+        system_message_tokens = get_token_count([new_system_message], model)

     return system_message_event, max_tokens - system_message_tokens
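The second return value is the token budget left for non-system messages. A hypothetical walkthrough using the numbers from the tests (13-token system message, 30-token limit), not code from the PR:

    event, remaining = process_system_message(
        system_message="This is a short system message",  # 13 tokens per the test comments
        max_tokens=30,
        model="gpt-4-0613",
    )
    # event == {"role": "system", "content": "This is a short system message"}
    # remaining == 30 - 13 == 17 tokens left for the other messages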

@@ -4544,11 +4592,15 @@ def process_messages(messages, max_tokens, model):
     final_messages = []

     for message in messages:
-        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+        used_tokens = get_token_count(final_messages, model)
+        available_tokens = max_tokens - used_tokens
+        if available_tokens <= 3:
+            break
+        final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model)

     return final_messages
-def attempt_message_addition(final_messages, message, max_tokens, model):
+def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model):
     temp_messages = [message] + final_messages
     temp_message_tokens = get_token_count(messages=temp_messages, model=model)

@@ -4558,7 +4610,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model):
     # if temp_message_tokens > max_tokens, try shortening temp_messages
     elif "function_call" not in message:
-        # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greater than max_tokens)
-        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        # fit updated_message within the tokens still available
+        updated_message = shorten_message_to_fit_limit(message, available_tokens, model)
         if can_add_message(updated_message, final_messages, max_tokens, model):
             return [updated_message] + final_messages

@@ -4580,6 +4632,13 @@ def shorten_message_to_fit_limit(
"""
Shorten a message to fit within a token limit by removing characters from the middle.
"""

# For OpenAI models, even blank messages cost 7 token,
# and if the buffer is less than 3, the while loop will never end,
# hence the value 10.
if 'gpt' in model and tokens_needed <= 10:
return message

content = message["content"]

while True:
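The 7-token floor falls straight out of the openai_token_counter arithmetic above: 3 tokens_per_message, 1 for the role, 0 for empty content, plus 3 reply-priming tokens. An illustrative check (counts assume the cl100k_base encoding):

    blank = [{"role": "user", "content": ""}]
    # 3 (tokens_per_message) + 1 ("user") + 0 (content) + 3 (priming) = 7
    print(openai_token_counter(blank, model="gpt-4-0613"))  # 7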
@@ -4628,6 +4687,7 @@ def trim_messages(
"""
# Initialize max_tokens
# if users pass in max tokens, trim to this amount
messages = copy.deepcopy(messages)
try:
print_verbose(f"trimming messages")
if max_tokens == None:
@@ -4644,6 +4704,7 @@
         system_message = ""
         for message in messages:
             if message["role"] == "system":
+                system_message += '\n' if system_message else ''
                 system_message += message["content"]

         current_tokens = token_counter(model=model, messages=messages)
@@ -4657,14 +4718,23 @@
print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
if system_message:
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
messages = messages + [system_message_event]

if max_tokens == 0: # the system messages are too long
return [system_message_event]

# Since all system messages are combined and trimmed to fit the max_tokens,
# we remove all system messages from the messages list
messages = [message for message in messages if message["role"] != "system"]

final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)

# Add system message to the beginning of the final messages
if system_message:
final_messages = [system_message_event] + final_messages

if return_response_tokens: # if user wants token count with new trimmed messages
response_tokens = max_tokens - get_token_count(final_messages, model)
return final_messages, response_tokens

return final_messages
except Exception as e: # [NON-Blocking, if error occurs just return final_messages
print_verbose(f"Got exception while token trimming{e}")