Patch summarize when running with local llms #213

Merged · 14 commits · Nov 3, 2023
104 changes: 83 additions & 21 deletions memgpt/agent.py
@@ -23,9 +23,12 @@
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
MESSAGE_SUMMARY_WARNING_TOKENS,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
)
from .errors import LLMError


def initialize_memory(ai_notes, human_notes):
@@ -680,16 +683,43 @@ def step(self, user_message, first_message=False, first_message_retry_limit=FIRS
printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
raise e

def summarize_messages_inplace(self, cutoff=None):
if cutoff is None:
tokens_so_far = 0 # Smart cutoff -- just below the max.
cutoff = len(self.messages) - 1
for m in reversed(self.messages):
tokens_so_far += count_tokens(str(m), self.model)
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2:
break
cutoff -= 1
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"

# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
token_counts = [count_tokens(str(msg)) for msg in self.messages]
message_buffer_token_count = sum(token_counts[1:]) # no system message
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
token_counts = token_counts[1:]
if preserve_last_N_messages:
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
printd(f"message_buffer_token_count={message_buffer_token_count}")
printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")

# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
)

# Walk down the message buffer (front-to-back) until we hit the target token count
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(candidate_messages_to_summarize):
cutoff = i
tokens_so_far += token_counts[i]
if tokens_so_far > desired_token_count_to_summarize:
break
# Account for system message
cutoff += 1

# Try to make an assistant message come after the cutoff
try:
Expand Down Expand Up @@ -1083,16 +1113,43 @@ async def step(self, user_message, first_message=False, first_message_retry_limi
print(e)
raise e

async def summarize_messages_inplace(self, cutoff=None):
if cutoff is None:
tokens_so_far = 0 # Smart cutoff -- just below the max.
cutoff = len(self.messages) - 1
for m in reversed(self.messages):
tokens_so_far += count_tokens(str(m), self.model)
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS * 0.2:
break
cutoff -= 1
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
async def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"

# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
token_counts = [count_tokens(str(msg)) for msg in self.messages]
message_buffer_token_count = sum(token_counts[1:]) # no system message
token_counts = token_counts[1:]
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
if preserve_last_N_messages:
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
printd(f"message_buffer_token_count={message_buffer_token_count}")
printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")

# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
)

# Walk down the message buffer (front-to-back) until we hit the target token count
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(candidate_messages_to_summarize):
cutoff = i
tokens_so_far += token_counts[i]
if tokens_so_far > desired_token_count_to_summarize:
break
# Account for system message
cutoff += 1

# Try to make an assistant message come after the cutoff
try:
@@ -1106,8 +1163,13 @@ async def summarize_messages_inplace(self, cutoff=None):
pass

message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
if len(message_sequence_to_summarize) == 0:
printd(f"message_sequence_to_summarize is len 0, skipping summarize")
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, cutoff={cutoff}]"
)

printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
summary = await a_summarize_messages(self.model, message_sequence_to_summarize)
printd(f"Got summary: {summary}")

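For readers skimming the diff, here is a minimal standalone sketch of the new cutoff selection (the same front-to-back walk used by both the sync and async `summarize_messages_inplace`). It is illustrative only, not the PR's code: the constants are local stand-ins for `MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC` and `MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST`, and the token counts are made up.

```python
# Minimal, self-contained sketch of the new cutoff selection (illustrative only).
TRUNC_TOKEN_FRAC = 0.75  # summarize until ~75% of the non-system buffer is covered
KEEP_N_LAST = 3          # never summarize away the most recent 3 messages

def pick_cutoff(token_counts, preserve_last_n=True):
    """token_counts[0] is the system message; returns cutoff such that
    messages[1:cutoff] would be summarized."""
    buffer_counts = token_counts[1:]                     # skip the system message
    target = int(sum(buffer_counts) * TRUNC_TOKEN_FRAC)  # token budget to compress
    if preserve_last_n:
        buffer_counts = buffer_counts[:-KEEP_N_LAST]     # protect the newest messages
    if not buffer_counts:
        raise RuntimeError("not enough messages to summarize")  # stand-in for LLMError
    tokens_so_far, cutoff = 0, 0
    for i, count in enumerate(buffer_counts):            # walk front-to-back
        cutoff = i
        tokens_so_far += count
        if tokens_so_far > target:
            break
    return cutoff + 1                                    # +1 accounts for the system message

# system + 8 messages; the last 3 are protected, so messages[1:5] get summarized here
print(pick_cutoff([50, 120, 90, 200, 80, 150, 60, 40, 30]))  # -> 5
```

Because the walk only breaks once the running total exceeds the target, the last message included can push slightly past the 75% mark, which matches the behavior in the diff.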
12 changes: 11 additions & 1 deletion memgpt/constants.py
@@ -16,8 +16,18 @@
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2]

# Constants to do with summarization / conversation length window
MESSAGE_SUMMARY_WARNING_TOKENS = 7000 # the number of tokens consumed in a call before a system warning goes to the agent
# The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
LLM_MAX_TOKENS = 8000 # change this depending on your model
[Review comment from the PR author] @vivi you can test that this works by intentionally lowering LLM_MAX_TOKENS + lowering it on the web UI backend

# The number of tokens before a system warning about upcoming truncation is sent to MemGPT
MESSAGE_SUMMARY_WARNING_TOKENS = int(0.75 * LLM_MAX_TOKENS)
# The warning message that MemGPT will receive
MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
# The fraction of the message buffer (in tokens) that we target for summarization when truncating
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75

# Even when summarizing, we want to keep a handful of recent messages
# These serve as in-context examples of how to use functions / what user messages look like
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3

# Default memory limits
CORE_MEMORY_PERSONA_CHAR_LIMIT = 2000
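A quick, illustrative arithmetic check of how the new constants interact, using the values from this diff; the 6200-token buffer is made up.

```python
# Quick arithmetic check of the new summarization constants (illustrative only).
LLM_MAX_TOKENS = 8000
MESSAGE_SUMMARY_WARNING_TOKENS = int(0.75 * LLM_MAX_TOKENS)  # 6000: warn the agent here
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75                      # summarize ~75% of the buffer
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3                        # but always keep the 3 newest messages

buffer_tokens = 6200  # hypothetical non-system message buffer when the warning fires
tokens_to_summarize = int(buffer_tokens * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
print(MESSAGE_SUMMARY_WARNING_TOKENS, tokens_to_summarize)   # 6000 4650
```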
28 changes: 28 additions & 0 deletions memgpt/errors.py
@@ -0,0 +1,28 @@
class LLMError(Exception):
"""Base class for all LLM-related errors."""

pass


class LLMJSONParsingError(LLMError):
"""Exception raised for errors in the JSON parsing process."""

def __init__(self, message="Error parsing JSON generated by LLM"):
self.message = message
super().__init__(self.message)


class LocalLLMError(LLMError):
"""Generic catch-all error for local LLM problems"""

def __init__(self, message="Encountered an error while running local LLM"):
self.message = message
super().__init__(self.message)


class LocalLLMConnectionError(LLMError):
"""Error for when local LLM cannot be reached with provided IP/port"""

def __init__(self, message="Could not connect to local LLM"):
self.message = message
super().__init__(self.message)
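A hypothetical caller-side sketch of why the shared LLMError base class is convenient: one except clause covers parsing failures, generic local-LLM failures, and connection failures alike. The `call_model_or_fallback` helper is invented for illustration; only the import path comes from the new file above.

```python
# Hypothetical usage sketch of the new exception hierarchy (not part of the PR).
from memgpt.errors import LLMError, LLMJSONParsingError

def call_model_or_fallback(run_model):
    """run_model is any callable that may raise one of the errors defined above."""
    try:
        return run_model()
    except LLMJSONParsingError as e:
        print(f"Bad JSON from the model, consider retrying: {e}")
    except LLMError as e:  # also catches LocalLLMError and LocalLLMConnectionError
        print(f"LLM backend failed: {e}")
    return None
```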
45 changes: 33 additions & 12 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -6,21 +6,26 @@

from .webui.api import get_webui_completion
from .lmstudio.api import get_lmstudio_completion
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
from .utils import DotDict
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from ..errors import LocalLLMConnectionError, LocalLLMError

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
DEBUG = False
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper()
has_shown_warning = False


def get_chat_completion(
model, # no model, since the model is fixed to whatever you set in your own backend
messages,
functions,
functions=None,
function_call="auto",
):
global has_shown_warning

if HOST is None:
raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
if HOST_TYPE is None:
@@ -29,21 +34,34 @@ def get_chat_completion(
if function_call != "auto":
raise ValueError(f"function_call == {function_call} not supported (auto only)")

if model == "airoboros-l2-70b-2.1":
if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip():
# Special case for if the call we're making is coming from the summarizer
llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
elif model == "airoboros-l2-70b-2.1":
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
elif model == "dolphin-2.1-mistral-7b":
llm_wrapper = dolphin.Dolphin21MistralWrapper()
elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
else:
# Warn the user that we're using the fallback
print(f"Warning: no wrapper specified for local LLM, using the default wrapper")
if not has_shown_warning:
print(
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
)
has_shown_warning = True
llm_wrapper = DEFAULT_WRAPPER

# First step: turn the message sequence into a prompt that the model expects
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
if DEBUG:
print(prompt)

try:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
if DEBUG:
print(prompt)
except Exception as e:
raise LocalLLMError(
f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
)

try:
if HOST_TYPE == "webui":
@@ -54,14 +72,17 @@ def get_chat_completion(
print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
result = get_webui_completion(prompt)
except requests.exceptions.ConnectionError as e:
raise ValueError(f"Was unable to connect to host {HOST}")
raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")

if result is None or result == "":
raise Exception(f"Got back an empty response string from {HOST}")
raise LocalLLMError(f"Got back an empty response string from {HOST}")

chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
if DEBUG:
print(json.dumps(chat_completion_result, indent=2))
try:
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
if DEBUG:
print(json.dumps(chat_completion_result, indent=2))
except Exception as e:
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")

# unpack with response.choices[0].message.content
response = DotDict(
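To make the summarizer special case above concrete, here is a small illustrative sketch of the detection step. The message list is fabricated; the import path mirrors the relative import used in the diff.

```python
# Illustrative sketch: if the first (system) message is the gpt_summarize prompt, the
# proxy swaps in SimpleSummaryWrapper instead of the per-model wrapper.
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE

def is_summarize_call(messages):
    return (
        len(messages) > 0
        and messages[0]["role"] == "system"
        and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip()
    )

messages = [
    {"role": "system", "content": SUMMARIZE_SYSTEM_MESSAGE},
    {"role": "user", "content": "...conversation to summarize..."},
]
print(is_summarize_call(messages))  # True -> use simple_summary_wrapper.SimpleSummaryWrapper()
```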
15 changes: 11 additions & 4 deletions memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -1,6 +1,7 @@
import json

from .wrapper_base import LLMChatCompletionWrapper
from ...errors import LLMJSONParsingError


class Airoboros21Wrapper(LLMChatCompletionWrapper):
@@ -186,8 +187,11 @@ def output_to_chat_completion_response(self, raw_llm_output):
function_json_output = json.loads(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
except KeyError as e:
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")

if self.clean_func_args:
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
@@ -395,8 +399,11 @@ def output_to_chat_completion_response(self, raw_llm_output):
function_json_output = json.loads(raw_llm_output + "\n}")
except:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
except KeyError as e:
[Review comment from the PR author] These changes should be put in the other wrappers too

raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")

if self.clean_func_args:
(
8 changes: 6 additions & 2 deletions memgpt/local_llm/llm_chat_completion_wrappers/dolphin.py
@@ -1,6 +1,7 @@
import json

from .wrapper_base import LLMChatCompletionWrapper
from ...errors import LLMJSONParsingError


class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
@@ -221,8 +222,11 @@ def output_to_chat_completion_response(self, raw_llm_output):
function_json_output = json.loads(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
except KeyError as e:
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")

if self.clean_func_args:
function_name, function_parameters = self.clean_function_args(function_name, function_parameters)
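The author's comment above notes that the defensive KeyError handling should be replicated in the other wrappers, and the dolphin wrapper in this PR already mirrors it. A small shared helper along these lines, hypothetical and not part of the PR, would capture the duplicated pattern:

```python
# Hypothetical shared helper capturing the pattern now duplicated in the airoboros and
# dolphin wrappers: decode the JSON, then surface missing keys as LLMJSONParsingError
# instead of letting a bare KeyError escape.
import json

from memgpt.errors import LLMJSONParsingError

def parse_function_call(raw_llm_output):
    try:
        function_json_output = json.loads(raw_llm_output)
    except Exception:
        raise LLMJSONParsingError(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
    try:
        return function_json_output["function"], function_json_output["params"]
    except KeyError as e:
        raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")

# e.g. parse_function_call('{"function": "send_message", "params": {"message": "hi"}}')
```

Each wrapper could then feed the returned (name, params) tuple into its own clean_function_args step.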