Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring CLI to use config file, connect to Llama Index data sources, and allow for multiple agents #154

Merged
merged 31 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7b94237
add pytest
sarahwooders Oct 27, 2023
544f2da
partially refactored config and updated run function
sarahwooders Oct 27, 2023
76d297a
fix configs
sarahwooders Oct 27, 2023
796330d
merge
sarahwooders Oct 27, 2023
74f8324
cleanup loading
sarahwooders Oct 27, 2023
8060203
implement adding humans/persons via CLI
sarahwooders Oct 27, 2023
d869bbe
update agent checkpointing
sarahwooders Oct 27, 2023
dfe1b7e
fix agent state loading
sarahwooders Oct 27, 2023
5660a6b
bugging agent reloading
sarahwooders Oct 28, 2023
c93faaa
implement memgpt configure
sarahwooders Oct 29, 2023
028205d
autorun configure with first run
sarahwooders Oct 29, 2023
be71ab8
graceful exit
sarahwooders Oct 29, 2023
ac695be
use config params for embeddings
sarahwooders Oct 29, 2023
520ddd6
refactor so cli code contained in one folder
sarahwooders Oct 29, 2023
5ca0326
add back a few commands
sarahwooders Oct 29, 2023
4b50611
fixed reloading issue thanks to charles
sarahwooders Oct 29, 2023
d1f4560
fix bug with listing humans/personas
sarahwooders Oct 29, 2023
75fcbac
Merge branch 'main' into refactor-cli
sarahwooders Oct 30, 2023
8453c62
add agent selection and -y
sarahwooders Oct 30, 2023
d75ebe4
Merge branch 'refactor-cli' of github.com:sarahwooders/MemGPT into re…
sarahwooders Oct 30, 2023
b5220cb
load bugfixes
sarahwooders Oct 30, 2023
d1d3a8d
Merge branch 'main' into refactor-cli
sarahwooders Oct 30, 2023
5566ca7
modify warning for legacy
sarahwooders Oct 30, 2023
e689f49
fix some bugs for initial configure
sarahwooders Oct 30, 2023
9e2982e
Merge branch 'refactor-cli' of github.com:sarahwooders/MemGPT into re…
sarahwooders Oct 30, 2023
76ead3f
black pass on stray files
cpacker Oct 30, 2023
a3b8a5f
another stray black file
cpacker Oct 30, 2023
d251543
add embeddings
sarahwooders Oct 30, 2023
529a75c
fix typo
sarahwooders Oct 30, 2023
18b3109
add preset to config
sarahwooders Oct 30, 2023
6fef44b
add temporary requirements-local.txt file for locally running embeddi…
sarahwooders Oct 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import asyncio
import inspect
import datetime
import glob
import pickle
import math
import os
import json
import threading

import openai

from memgpt.persistence_manager import LocalStateManager
from memgpt.config import AgentConfig
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages, a_summarize_messages
from .openai_tools import acompletions_with_backoff as acreate, completions_with_backoff as create
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
from .constants import (
MEMGPT_DIR,
FIRST_MESSAGE_ATTEMPTS,
MAX_PAUSE_HEARTBEATS,
MESSAGE_CHATGPT_FUNCTION_MODEL,
Expand Down Expand Up @@ -167,6 +170,7 @@ async def call_function(function_to_call, **function_args):
class Agent(object):
def __init__(
self,
config,
model,
system,
functions,
Expand All @@ -178,6 +182,8 @@ def __init__(
persistence_manager_init=True,
first_message_verify_mono=True,
):
# agent config
self.config = config
# gpt-4, gpt-3.5-turbo
self.model = model
# Store the system instructions (used to rebuild memory)
Expand All @@ -194,7 +200,8 @@ def __init__(
)
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
self.messages_total_init = self.messages_total
# self.messages_total_init = self.messages_total
self.messages_total_init = len(self._messages) - 1
printd(f"AgentAsync initialized, self.messages_total={self.messages_total}")

# Interface must implement:
Expand Down Expand Up @@ -331,6 +338,61 @@ def save_to_json_file(self, filename):
with open(filename, "w") as file:
json.dump(self.to_dict(), file)

def save(self):
"""Save agent state locally"""

timestamp = get_local_time().replace(" ", "_").replace(":", "_")
agent_name = self.config.name # TODO: fix

# save agent state
filename = f"{timestamp}.json"
os.makedirs(self.config.save_state_dir(), exist_ok=True)
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))

# save the persistence manager too
filename = f"{timestamp}.persistence.pickle"
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))

@classmethod
def load_agent(cls, interface, agent_config: AgentConfig):
"""Load saved agent state"""
# TODO: support loading from specific file
agent_name = agent_config.name

# load state
directory = agent_config.save_state_dir()
json_files = glob.glob(f"{directory}/*.json") # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}")

# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
state = json.load(open(filename, "r"))

# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)

messages = state["messages"]
agent = cls(
config=agent_config,
model=state["model"],
system=state["system"],
functions=state["functions"],
interface=interface,
persistence_manager=persistence_manager,
persistence_manager_init=False,
persona_notes=state["memory"]["persona"],
human_notes=state["memory"]["human"],
messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1,
)
agent._messages = messages
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
return agent

@classmethod
def load(cls, state, interface, persistence_manager):
model = state["model"]
Expand Down Expand Up @@ -875,6 +937,9 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
printd(f"WARNING: attempting to run ChatCompletion without user as the last message in the queue")
from pprint import pprint

pprint(input_message_sequence[-1])

# Step 1: send the conversation and available functions to GPT
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
Expand All @@ -901,9 +966,9 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
assert "api_response" not in all_response_messages[0]
assert "api_response" not in all_response_messages[0], f"api_response already in {all_response_messages[0]}"
all_response_messages[0]["api_response"] = response_message_copy
assert "api_args" not in all_response_messages[0]
assert "api_args" not in all_response_messages[0], f"api_args already in {all_response_messages[0]}"
all_response_messages[0]["api_args"] = {
"model": self.model,
"messages": input_message_sequence,
Expand Down Expand Up @@ -933,6 +998,7 @@ async def step(self, user_message, first_message=False, first_message_retry_limi

except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
print(f"step() failed\nuser_message = {user_message}\nerror = {e}")

# If we got a context alert, try trimming the messages length, then try again
if "maximum context length" in str(e):
Expand All @@ -943,6 +1009,7 @@ async def step(self, user_message, first_message=False, first_message_retry_limi
return await self.step(user_message, first_message=first_message)
else:
printd(f"step() failed with openai.InvalidRequestError, but didn't recognize the error message: '{str(e)}'")
print(e)
raise e

async def summarize_messages_inplace(self, cutoff=None):
Expand Down
132 changes: 132 additions & 0 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import typer
import logging
import asyncio
import os
from prettytable import PrettyTable
import questionary
import openai


import memgpt.interface # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.agent as agent
import memgpt.system as system
import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
import memgpt.utils as utils
from memgpt.utils import printd
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.agent import AgentAsync


def run(
persona: str = typer.Option(None, help="Specify persona"),
agent: str = typer.Option(None, help="Specify agent save file"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
data_source: str = typer.Option(None, help="Specify data source to attach to agent"),
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
):
"""Start chatting with an MemGPT agent

Example usage: `memgpt run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo`

:param persona: Specify persona
:param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name)
:param human: Specify human
:param model: Specify the LLM model
:param data_source: Specify data source to attach to agent (if new agent is being created)

"""

# setup logger
utils.DEBUG = debug
logging.getLogger().setLevel(logging.CRITICAL)
if debug:
logging.getLogger().setLevel(logging.DEBUG)

if not MemGPTConfig.exists(): # if no config, run configure
if yes:
# use defaults
config = MemGPTConfig()
else:
# use input
configure()
config = MemGPTConfig.load()
else: # load config
config = MemGPTConfig.load()

# override with command line arguments
if debug:
config.debug = debug
if no_verify:
config.no_verify = no_verify

# determine agent to use, if not provided
if not yes and not agent:
agent_files = utils.list_agent_config_files()
agents = [AgentConfig.load(f).name for f in agent_files]

if len(agents) > 0:
select_agent = questionary.confirm("Would you like to select an existing agent?").ask()
if select_agent:
agent = questionary.select("Select agent:", choices=agents).ask()

# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
agent_config = AgentConfig.load(agent)
printd("State path:", agent_config.save_state_dir())
printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
printd("Index path:", agent_config.save_agent_index_dir())
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
# TODO: load prior agent state
assert not any(
[persona, human, model]
), f"Cannot override existing agent state with command line arguments: {persona}, {human}, {model}"

# load existing agent
memgpt_agent = AgentAsync.load_agent(memgpt.interface, agent_config)
else: # create new agent
# create new agent config: override defaults with args if provided
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
agent_config = AgentConfig(
name=agent if agent else None,
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
)

# attach data source to agent
agent_config.attach_data_source(data_source)

# TODO: allow configrable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill

# save new agent config
agent_config.save()
typer.secho(f"Created new agent {agent_config.name}.", fg=typer.colors.GREEN)

# create agent
memgpt_agent = presets.use_preset(
presets.DEFAULT,
agent_config,
agent_config.model,
agent_config.persona,
agent_config.human,
memgpt.interface,
persistence_manager,
)

# start event loop
from memgpt.main import run_agent_loop

loop = asyncio.get_event_loop()
loop.run_until_complete(run_agent_loop(memgpt_agent, first, no_verify, config)) # TODO: add back no_verify
Loading