Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ~/.memgpt/config to set questionary defaults in memgpt configure #389

Merged
merged 50 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
89cf976
mark deprecated API section
sarahwooders Oct 30, 2023
be6212c
add readme
sarahwooders Oct 31, 2023
b011380
add readme
sarahwooders Oct 31, 2023
59f7b71
add readme
sarahwooders Oct 31, 2023
176538b
add readme
sarahwooders Oct 31, 2023
9905266
add readme
sarahwooders Oct 31, 2023
3606959
add readme
sarahwooders Oct 31, 2023
c48803c
add readme
sarahwooders Oct 31, 2023
40cdb23
add readme
sarahwooders Oct 31, 2023
ff43c98
add readme
sarahwooders Oct 31, 2023
01db319
CLI bug fixes for azure
sarahwooders Oct 31, 2023
a11cef9
check azure before running
sarahwooders Oct 31, 2023
a47d49e
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
fbe2482
Update README.md
sarahwooders Oct 31, 2023
446a1a1
Update README.md
sarahwooders Oct 31, 2023
1541482
bug fix with persona loading
sarahwooders Oct 31, 2023
5776e30
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d48cf23
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
7a8eb80
remove print
sarahwooders Oct 31, 2023
9a5ece0
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d3370b3
merge
sarahwooders Nov 3, 2023
c19c2ce
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
aa6ee71
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
36bb04d
make errors for cli flags more clear
sarahwooders Nov 3, 2023
6f50db1
format
sarahwooders Nov 3, 2023
4c91a41
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
dbaf4a0
Merge branch 'cpacker:main' into main
sarahwooders Nov 5, 2023
c86e1c9
fix imports
sarahwooders Nov 5, 2023
e54e762
Merge branch 'cpacker:main' into main
sarahwooders Nov 5, 2023
524a974
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 5, 2023
7baf3e7
fix imports
sarahwooders Nov 5, 2023
2fd8795
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 5, 2023
4ab4f2d
add prints
sarahwooders Nov 5, 2023
cc94b4e
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 6, 2023
9d1707d
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
1782bb9
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
caaf476
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Nov 7, 2023
6692bca
update lock
sarahwooders Nov 7, 2023
7cc1b9f
Merge branch 'cpacker:main' into main
sarahwooders Nov 7, 2023
728531e
Merge branch 'cpacker:main' into main
sarahwooders Nov 8, 2023
06e971c
Merge branch 'cpacker:main' into main
sarahwooders Nov 8, 2023
2dce7b2
use config file for defaults
sarahwooders Nov 9, 2023
49872b3
set questionary defaults with config
sarahwooders Nov 9, 2023
2b4ace1
Merge branch 'cpacker:main' into main
sarahwooders Nov 9, 2023
140f8d1
add config
sarahwooders Nov 9, 2023
e867e9d
fix bugs
sarahwooders Nov 9, 2023
b731990
Merge branch 'cpacker:main' into main
sarahwooders Nov 9, 2023
6cb15f2
Merge branch 'main' into default-configure
sarahwooders Nov 9, 2023
1cd3ac6
fix tests
sarahwooders Nov 9, 2023
10c3c06
fix tests
sarahwooders Nov 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ def configure():

MemGPTConfig.create_config_dir()

# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()

# openai credentials
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?").ask()
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask()
if use_openai:
# search for key in environment
openai_key = os.getenv("OPENAI_API_KEY")
Expand All @@ -37,7 +40,7 @@ def configure():
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()

# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
use_azure_deployment_ids = False
if use_azure:
# search for key in environment
Expand Down Expand Up @@ -69,49 +72,58 @@ def configure():
model_endpoint_options = []
if os.getenv("OPENAI_API_BASE") is not None:
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
if use_azure:
model_endpoint_options += ["azure"]
if use_openai:
model_endpoint_options += ["openai"]

if use_azure:
model_endpoint_options += ["azure"]
assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
default_endpoint = questionary.select("Select default inference endpoint:", model_endpoint_options).ask()
valid_default_model = config.model_endpoint in model_endpoint_options
default_endpoint = questionary.select(
"Select default inference endpoint:",
model_endpoint_options,
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
).ask()

# configure embedding provider
embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing)
if use_azure:
model_endpoint_options += ["azure"]
embedding_endpoint_options += ["azure"]
if use_openai:
model_endpoint_options += ["openai"]
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask()
embedding_endpoint_options += ["openai"]
valid_default_embedding = config.embedding_model in embedding_endpoint_options
default_embedding_endpoint = questionary.select(
"Select default embedding endpoint:",
embedding_endpoint_options,
default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1],
).ask()

# configure embedding dimensions
default_embedding_dim = 1536
default_embedding_dim = config.embedding_dim
if default_embedding_endpoint == "local":
# HF model uses lower dimensionality
default_embedding_dim = 384

# configure preset
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()

# default model
if use_openai or use_azure:
model_options = []
if use_openai:
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
default_model = questionary.select(
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default="gpt-4"
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default=config.model
).ask()
else:
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint

# defaults
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
print(personas)
default_persona = questionary.select("Select default persona:", personas, default="sam_pov").ask()
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
print(humans)
default_human = questionary.select("Select default human:", humans, default="cs_phd").ask()
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()

# TODO: figure out if we should set a default agent or not
default_agent = None
Expand All @@ -126,11 +138,14 @@ def configure():

# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_type = questionary.select("Select storage backend for archival data:", archival_storage_options, default="local").ask()
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
archival_storage_uri = None
if archival_storage_type == "postgres":
archival_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):"
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

# TODO: allow configuring embedding model
Expand Down
2 changes: 1 addition & 1 deletion memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfi
# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config")
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
self.db_model = get_db_model(self.table_name)
self.engine = create_engine(self.uri)
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
Expand Down
19 changes: 12 additions & 7 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def test_postgres_openai():
if os.getenv("OPENAI_API_KEY") is None:
return # soft pass

os.environ["MEMGPT_CONFIG_FILE"] = "./config"
config = MemGPTConfig()
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"
Expand Down Expand Up @@ -56,10 +56,15 @@ def test_postgres_openai():

def test_postgres_local():
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
os.environ["MEMGPT_CONFIG_FILE"] = "./config"

config = MemGPTConfig(embedding_model="local", embedding_dim=384) # use HF model
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"

config = MemGPTConfig(
archival_storage_type="postgres",
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
embedding_model="local",
embedding_dim=384, # use HF model
)
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"
Expand Down