I added commands to shape the conversation:
`/rethink <text>` will change the internal dialog of the last assistant message.
`/rewrite <text>` will change the last answer of the assistant.

Both commands can be used to change how the conversation continues in
some pretty drastic and powerful ways.
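
For reference, here is a condensed sketch of what the two new handlers do to the agent's message history (see the full diff below). It assumes `messages` is a list of OpenAI-style chat dicts in which the assistant's visible reply is stored as a JSON `message` field inside the message's `function_call` arguments; the helper names `rethink_last` and `rewrite_last` are only for illustration.

import json

def rethink_last(messages, text):
    # /rethink <text>: overwrite the inner monologue ("content") of the
    # most recent assistant message.
    for msg in reversed(messages):
        if msg.get("role") == "assistant":
            msg["content"] = text
            break

def rewrite_last(messages, text):
    # /rewrite <text>: overwrite the visible reply, which is stored as a JSON
    # "message" field inside the assistant's function_call arguments.
    for msg in reversed(messages):
        if msg.get("role") == "assistant" and msg.get("function_call"):
            args = json.loads(msg["function_call"]["arguments"])
            args["message"] = text
            msg["function_call"]["arguments"] = json.dumps(args)
            break

Because only the last assistant message is edited, the rest of the history stays intact and the next completion is generated as if the agent had actually thought or said the substituted text.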
oderwat committed Oct 31, 2023
1 parent 9b7153a commit 6805927
Showing 1 changed file with 96 additions and 19 deletions: memgpt/main.py
@@ -4,6 +4,7 @@
import os
import sys
import pickle
import json

import questionary
import typer
@@ -84,8 +85,12 @@ def load(memgpt_agent, filename):
print(f"Loading {filename} failed with: {e}")
else:
# Load the latest file
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
print(
f"/load warning: no checkpoint specified, loading most recent checkpoint instead"
)
json_files = glob.glob(
"saved_state/*.json"
) # This will list all .json files in the current directory.

# Check if there are any json files.
if not json_files:
@@ -107,17 +112,27 @@ def load(memgpt_agent, filename):
) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from {filename}")
except Exception as e:
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
print(
f"/load warning: loading persistence manager from {filename} failed with: {e}"
)


@app.callback(invoke_without_command=True) # make default command
def run(
persona: str = typer.Option(None, help="Specify persona"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"),
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
model: str = typer.Option(
constants.DEFAULT_MEMGPT_MODEL, help="Specify the LLM model"
),
first: bool = typer.Option(
False, "--first", help="Use --first to send the first message in the sequence"
),
debug: bool = typer.Option(
False, "--debug", help="Use --debug to enable debugging output"
),
no_verify: bool = typer.Option(
False, "--no_verify", help="Bypass message verification"
),
archival_storage_faiss_path: str = typer.Option(
"",
"--archival_storage_faiss_path",
@@ -187,7 +202,9 @@ async def main(
else:
azure_vars = get_set_azure_env_vars()
if len(azure_vars) > 0:
print(f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False")
print(
f"Error: Environment variables {', '.join([x[0] for x in azure_vars])} should not be set if --use_azure_openai is False"
)
return

if any(
@@ -280,17 +297,23 @@ async def main(
else:
cfg = await Config.config_init()

memgpt.interface.important_message("Running... [exit by typing '/exit', list available commands with '/help']")
memgpt.interface.important_message(
"Running... [exit by typing '/exit', list available commands with '/help']"
)
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
memgpt.interface.warning_message(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)

if cfg.index:
persistence_manager = InMemoryStateManagerWithFaiss(cfg.index, cfg.archival_database)
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
)
elif cfg.archival_storage_files:
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(cfg.archival_database)
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
cfg.archival_database
)
else:
persistence_manager = InMemoryStateManager()

@@ -334,7 +357,9 @@ async def main(
print(f"Database loaded into archival memory.")

if cfg.agent_save_file:
load_save_file = await questionary.confirm(f"Load in saved agent '{cfg.agent_save_file}'?").ask_async()
load_save_file = await questionary.confirm(
f"Load in saved agent '{cfg.agent_save_file}'?"
).ask_async()
if load_save_file:
load(memgpt_agent, cfg.agent_save_file)

@@ -343,7 +368,9 @@ async def main(
return

if not USER_GOES_FIRST:
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
console.input(
"[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]"
)
clear_line()
print()

@@ -379,7 +406,9 @@ async def main(
break

elif user_input.lower() == "/savechat":
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
filename = (
utils.get_local_time().replace(" ", "_").replace(":", "_")
)
filename = f"{filename}.pkl"
directory = os.path.join(MEMGPT_DIR, "saved_chats")
try:
@@ -396,7 +425,9 @@ async def main(
save(memgpt_agent=memgpt_agent, cfg=cfg)
continue

elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
elif user_input.lower() == "/load" or user_input.lower().startswith(
"/load "
):
command = user_input.strip().split()
filename = command[1] if len(command) > 1 else None
load(memgpt_agent=memgpt_agent, filename=filename)
@@ -429,15 +460,55 @@ async def main(
print(f"Updated model to:\n{str(memgpt_agent.model)}")
continue

elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
elif user_input.lower() == "/pop" or user_input.lower().startswith(
"/pop "
):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2
amount = (
int(command[1])
if len(command) > 1 and command[1].isdigit()
else 2
)
print(f"Popping last {amount} messages from stack")
for _ in range(min(amount, len(memgpt_agent.messages))):
memgpt_agent.messages.pop()
continue

elif user_input.lower() == "/rethink" or user_input.lower().startswith(
"/rethink "
):
if len(user_input) < 9:
print("Missing text after the command")
continue
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = user_input[9:].strip()
memgpt_agent.messages[x].update({"content": text})
break
continue

elif user_input.lower() == "/rewrite" or user_input.lower().startswith(
"/rewrite "
):
if len(user_input) < 9:
print("Missing text after the command")
continue
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = user_input[9:].strip()
args = json.loads(
memgpt_agent.messages[x]
.get("function_call")
.get("arguments")
)
args["message"] = text
memgpt_agent.messages[x].get("function_call").update(
{"arguments": json.dumps(args)}
)
break
continue

# No skip options
elif user_input.lower() == "/wipe":
memgpt_agent = agent.AgentAsync(memgpt.interface)
@@ -477,14 +548,18 @@ async def main(
heartbeat_request,
function_failed,
token_warning,
) = await memgpt_agent.step(user_message, first_message=False, skip_verify=no_verify)
) = await memgpt_agent.step(
user_message, first_message=False, skip_verify=no_verify
)

# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
skip_next_user_input = True
elif function_failed:
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
user_message = system.get_heartbeat(
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
)
skip_next_user_input = True
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
@@ -504,6 +579,8 @@ async def main(
("/memory", "print the current contents of agent memory"),
("/pop", "undo the last message in the conversation"),
("/heartbeat", "send a heartbeat system message to the agent"),
("/rewrite <text>", "changes the reply of the last agent message"),
("/rethink <text>", "changes the inner thoughts of the last agent message"),
("/memorywarning", "send a memory warning system message to the agent"),
]
# if __name__ == "__main__":
