Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Aug 19, 2024
1 parent b4dfba9 commit 5e8c2fc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
6 changes: 3 additions & 3 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text
from memgpt.utils import get_human_text, get_persona_text


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -259,7 +259,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down Expand Up @@ -729,7 +729,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down
27 changes: 14 additions & 13 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -29,17 +30,12 @@ class MemoryModule(BaseModel):

def __setattr__(self, name, value):
"""Run validation if self.value is updated"""
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True)) # TODO: not sure what this does
if len(value) > self.limit:
# TODO: come up with smarter eviction algorithm
print(f"Edit failed: Exceeds {self.limit} character limit (requested {len(value)}). Clipping to {self.limit} characters.")
value = value[: self.limit]

super().__setattr__(name, value)

@validator("value", always=True)
@validator("value", always=True, check_fields=False)
def check_value_length(cls, v, values):
# TODO: this doesn't run all the time, should fix
if v is not None:
Expand All @@ -55,12 +51,9 @@ def check_value_length(cls, v, values):
raise ValueError("Value must be either a string or a list of strings.")

if length > limit:
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length}). Clipping to {limit} characters."
print(error_msg)
values["value"] = v[:limit]
return v[:limit]
else:
return v
raise ValueError(f"Value exceeds {limit} character limit (requested {length}).")

return v

def __len__(self):
return len(str(self))
Expand Down Expand Up @@ -102,6 +95,14 @@ def to_dict(self):
class ChatMemory(BaseMemory):

def __init__(self, persona: str, human: str, limit: int = 2000):
# TODO: clip if needed
if persona and len(persona) > limit:
warnings.warn(f"Persona exceeds {limit} character limit (requested {len(persona)}).")
persona = persona[:limit]

if human and len(human) > limit:
warnings.warn(f"Human exceeds {limit} character limit (requested {len(human)}).")
human = human[:limit]
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
Expand Down

0 comments on commit 5e8c2fc

Please sign in to comment.