Commit

Refactoring CLI to use config file, connect to Llama Index data sources, and allow for multiple agents (#154)

* Migrate to `memgpt run` and `memgpt configure` 
* Add Llama index data sources via `memgpt load` 
* Save config files for defaults and agents
sarahwooders authored Oct 30, 2023
1 parent a78ba2f commit b7f9560
Showing 17 changed files with 968 additions and 94 deletions.
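
Before the per-file diffs, a minimal sketch of the config-first startup flow the commit message describes (class and function names are taken from memgpt/cli/cli.py below; everything else is illustrative):

```python
# Sketch: how `memgpt run` now resolves configuration. If no config file
# exists yet, `configure` runs interactive setup and persists the defaults;
# otherwise the saved config is loaded and command-line flags override it.
from memgpt.config import MemGPTConfig
from memgpt.cli.cli_config import configure

if not MemGPTConfig.exists():
    configure()
config = MemGPTConfig.load()  # saved defaults: model, preset, persona, human, ...
```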
2 changes: 1 addition & 1 deletion .github/workflows/poetry-publish.yml
@@ -3,7 +3,7 @@ on:
release:
types: [published]
workflow_dispatch:

jobs:
build-and-publish:
name: Build and Publish to PyPI
75 changes: 71 additions & 4 deletions memgpt/agent.py
@@ -1,19 +1,22 @@
import asyncio
import inspect
import datetime
import glob
import pickle
import math
import os
import json
import threading

import openai

from memgpt.persistence_manager import LocalStateManager
from memgpt.config import AgentConfig
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages
from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
from .constants import (
MEMGPT_DIR,
FIRST_MESSAGE_ATTEMPTS,
MAX_PAUSE_HEARTBEATS,
MESSAGE_CHATGPT_FUNCTION_MODEL,
@@ -167,6 +170,7 @@ async def call_function(function_to_call, **function_args):
class Agent(object):
def __init__(
self,
config,
model,
system,
functions,
@@ -178,6 +182,8 @@ def __init__(
persistence_manager_init=True,
first_message_verify_mono=True,
):
# agent config
self.config = config
# gpt-4, gpt-3.5-turbo
self.model = model
# Store the system instructions (used to rebuild memory)
@@ -194,7 +200,8 @@ def __init__(
)
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
self.messages_total_init = self.messages_total
# self.messages_total_init = self.messages_total
self.messages_total_init = len(self._messages) - 1
printd(f"AgentAsync initialized, self.messages_total={self.messages_total}")

# Interface must implement:
@@ -331,6 +338,61 @@ def save_to_json_file(self, filename):
with open(filename, "w") as file:
json.dump(self.to_dict(), file)

def save(self):
"""Save agent state locally"""

timestamp = get_local_time().replace(" ", "_").replace(":", "_")
agent_name = self.config.name # TODO: fix

# save agent state
filename = f"{timestamp}.json"
os.makedirs(self.config.save_state_dir(), exist_ok=True)
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))

# save the persistence manager too
filename = f"{timestamp}.persistence.pickle"
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))
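# For illustration (derived from save() above; the exact timestamp format
# depends on get_local_time()): a single save() writes a checkpoint pair like
#   <save_state_dir>/2023-10-30_12_00_00.json
#   <save_persistence_manager_dir>/2023-10-30_12_00_00.persistence.pickle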

@classmethod
def load_agent(cls, interface, agent_config: AgentConfig):
"""Load saved agent state"""
# TODO: support loading from specific file
agent_name = agent_config.name

# load state
directory = agent_config.save_state_dir()
json_files = glob.glob(f"{directory}/*.json")  # list all .json checkpoint files in the agent's state directory
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}")

# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
state = json.load(open(filename, "r"))

# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)

messages = state["messages"]
agent = cls(
config=agent_config,
model=state["model"],
system=state["system"],
functions=state["functions"],
interface=interface,
persistence_manager=persistence_manager,
persistence_manager_init=False,
persona_notes=state["memory"]["persona"],
human_notes=state["memory"]["human"],
messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1,
)
agent._messages = messages
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
return agent
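
# Hypothetical usage sketch (AgentConfig.load, load_agent, and save come from
# this diff; memgpt.interface is the terminal interface used by the CLI):
#
#   agent_config = AgentConfig.load("myagent")
#   agent = Agent.load_agent(memgpt.interface, agent_config)  # latest checkpoint
#   ... run some steps ...
#   agent.save()  # writes a new timestamped state + persistence checkpoint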

@classmethod
def load(cls, state, interface, persistence_manager):
model = state["model"]
@@ -875,6 +937,9 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
from pprint import pprint

pprint(input_message_sequence[-1])

# Step 1: send the conversation and available functions to GPT
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
@@ -901,9 +966,9 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
assert "api_response" not in all_response_messages[0]
assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}"
all_response_messages[0]["api_response"] = response_message_copy
assert "api_args" not in all_response_messages[0]
assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}"
all_response_messages[0]["api_args"] = {
"model": self.model,
"messages": input_message_sequence,
@@ -933,6 +998,7 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
print(f"step() failed\nuser_message = {user_message}\nerror = {e}")

# If we got a context alert, try trimming the messages length, then try again
if "maximum context length" in str(e):
@@ -943,6 +1009,7 @@ async def step(self, user_message, first_message=False, first_message_retry_limi
return await self.step(user_message, first_message=first_message)
else:
printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
print(e)
raise e

async def summarize_messages_inplace(self, cutoff=None):
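The step() hunks above add louder failure logging and keep the context-overflow recovery path: on a "maximum context length" error, the agent summarizes its message history in place and retries. A simplified sketch of that control flow (method names from this diff; the wrapper function is illustrative):

```python
# Sketch: retry a failed step once after trimming context by summarization.
async def step_with_recovery(agent, user_message):
    try:
        return await agent.step(user_message)
    except Exception as e:
        if "maximum context length" in str(e):
            await agent.summarize_messages_inplace()  # compress older messages
            return await agent.step(user_message)     # then retry once
        raise
```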
8 changes: 4 additions & 4 deletions memgpt/autogen/memgpt_agent.py
@@ -40,7 +40,7 @@ def create_memgpt_autogen_agent_from_config(

autogen_memgpt_agent = create_autogen_memgpt_agent(
name,
preset=presets.DEFAULT,
preset=presets.DEFAULT_PRESET,
model=model,
persona_description=persona_desc,
user_description=user_desc,
@@ -50,7 +50,7 @@ def create_memgpt_autogen_agent_from_config(
if human_input_mode != "ALWAYS":
coop_agent1 = create_autogen_memgpt_agent(
name,
preset=presets.DEFAULT,
preset=presets.DEFAULT_PRESET,
model=model,
persona_description=persona_desc,
user_description=user_desc,
@@ -65,7 +65,7 @@ def create_memgpt_autogen_agent_from_config(
else:
coop_agent2 = create_autogen_memgpt_agent(
name,
preset=presets.DEFAULT,
preset=presets.DEFAULT_PRESET,
model=model,
persona_description=persona_desc,
user_description=user_desc,
@@ -86,7 +86,7 @@ def create_memgpt_autogen_agent_from_config(

def create_autogen_memgpt_agent(
autogen_name,
preset=presets.DEFAULT,
preset=presets.DEFAULT_PRESET,
model=constants.DEFAULT_MEMGPT_MODEL,
persona_description=personas.DEFAULT,
user_description=humans.DEFAULT,
148 changes: 148 additions & 0 deletions memgpt/cli/cli.py
@@ -0,0 +1,148 @@
import typer
import sys
import io
import logging
import asyncio
import os
from prettytable import PrettyTable
import questionary
import openai

from llama_index import set_global_service_context
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext

import memgpt.interface # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.agent as agent
import memgpt.system as system
import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
import memgpt.utils as utils
from memgpt.utils import printd
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.agent import AgentAsync
from memgpt.embeddings import embedding_model


def run(
persona: str = typer.Option(None, help="Specify persona"),
agent: str = typer.Option(None, help="Specify agent save file"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
preset: str = typer.Option(None, help="Specify preset"),
data_source: str = typer.Option(None, help="Specify data source to attach to agent"),
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
):
"""Start chatting with an MemGPT agent
Example usage: `memgpt run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo`
:param persona: Specify persona
:param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name)
:param human: Specify human
:param model: Specify the LLM model
:param data_source: Specify data source to attach to agent (if new agent is being created)
"""

# setup logger
utils.DEBUG = debug
logging.getLogger().setLevel(logging.CRITICAL)
if debug:
logging.getLogger().setLevel(logging.DEBUG)

if not MemGPTConfig.exists(): # if no config, run configure
if yes:
# use defaults
config = MemGPTConfig()
else:
# use input
configure()
config = MemGPTConfig.load()
else: # load config
config = MemGPTConfig.load()

# override with command line arguments
if debug:
config.debug = debug
if no_verify:
config.no_verify = no_verify

# determine agent to use, if not provided
if not yes and not agent:
agent_files = utils.list_agent_config_files()
agents = [AgentConfig.load(f).name for f in agent_files]

if len(agents) > 0:
select_agent = questionary.confirm("Would you like to select an existing agent?").ask()
if select_agent:
agent = questionary.select("Select agent:", choices=agents).ask()

# configure llama index
config = MemGPTConfig.load()
original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index
sys.stdout = io.StringIO()
embed_model = embedding_model(config)
service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size)
set_global_service_context(service_context)
sys.stdout = original_stdout
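# Editorial aside (not part of this diff): an exception-safe form of the
# suppression above would scope it with contextlib.redirect_stdout, e.g.
#   with contextlib.redirect_stdout(io.StringIO()):
#       embed_model = embedding_model(config)
#       ... build and set the service context ...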

# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
agent_config = AgentConfig.load(agent)
printd("State path:", agent_config.save_state_dir())
printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
printd("Index path:", agent_config.save_agent_index_dir())
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
# TODO: load prior agent state
assert not any(
[persona, human, model]
), f"Cannot override existing agent state with command line arguments: {persona}, {human}, {model}"

# load existing agent
memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config)
else: # create new agent
# create new agent config: override defaults with args if provided
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
agent_config = AgentConfig(
name=agent if agent else None,
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
preset=preset if preset else config.preset,
)

# attach data source to agent
agent_config.attach_data_source(data_source)

# TODO: allow configurable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill

# save new agent config
agent_config.save()
typer.secho(f"Created new agent {agent_config.name}.", fg=typer.colors.GREEN)

# create agent
memgpt_agent = presets.use_preset(
agent_config.preset,
agent_config,
agent_config.model,
agent_config.persona,
agent_config.human,
memgpt.interface,
persistence_manager,
)

# start event loop
from memgpt.main import run_agent_loop

loop = asyncio.get_event_loop()
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify
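
The new `run` command is defined here as a plain function; the Typer app that registers it is not part of this file. A plausible wiring sketch (an assumption; the actual entry-point setup lives elsewhere in the package):

```python
# Sketch: exposing run/configure as `memgpt run` and `memgpt configure`.
import typer

from memgpt.cli.cli import run
from memgpt.cli.cli_config import configure

app = typer.Typer()
app.command()(run)
app.command()(configure)

if __name__ == "__main__":
    app()
```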
