Skip to content

Commit

Permalink
fix: patch error on memory overflow (#1669)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Aug 20, 2024
2 parents fbed1e5 + 5e8c2fc commit d3d0fe5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
6 changes: 3 additions & 3 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text
from memgpt.utils import get_human_text, get_persona_text


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -259,7 +259,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down Expand Up @@ -729,7 +729,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down
22 changes: 16 additions & 6 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

Expand All @@ -19,6 +20,7 @@
)


# always run validation
class MemoryModule(BaseModel):
"""Base class for memory modules"""

Expand All @@ -31,10 +33,11 @@ def __setattr__(self, name, value):
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True))
self.__class__.validate(self.dict(exclude_unset=True)) # TODO: not sure what this does

@validator("value", always=True)
@validator("value", always=True, check_fields=False)
def check_value_length(cls, v, values):
# TODO: this doesn't run all the time, should fix
if v is not None:
# Fetching the limit from the values dictionary
limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
Expand All @@ -48,10 +51,9 @@ def check_value_length(cls, v, values):
raise ValueError("Value must be either a string or a list of strings.")

if length > limit:
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# TODO: add archival memory error?
raise ValueError(error_msg)
return v
raise ValueError(f"Value exceeds {limit} character limit (requested {length}).")

return v

def __len__(self):
return len(str(self))
Expand Down Expand Up @@ -93,6 +95,14 @@ def to_dict(self):
class ChatMemory(BaseMemory):

def __init__(self, persona: str, human: str, limit: int = 2000):
# TODO: clip if needed
if persona and len(persona) > limit:
warnings.warn(f"Persona exceeds {limit} character limit (requested {len(persona)}).")
persona = persona[:limit]

if human and len(human) > limit:
warnings.warn(f"Human exceeds {limit} character limit (requested {len(human)}).")
human = human[:limit]
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
Expand Down

0 comments on commit d3d0fe5

Please sign in to comment.