Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Move model check to splash message (#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakethekoenig authored Feb 23, 2024
1 parent 6219d49 commit 9178325
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 15 deletions.
14 changes: 0 additions & 14 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,6 @@ async def display_token_count(self):
config = session_context.config
code_context = session_context.code_context

if "gpt-4" not in config.model:
stream.send(
"Warning: Mentat has only been tested on GPT-4. You may experience"
" issues with quality. This model may not be able to respond in"
" mentat's edit format.",
style="warning",
)
if "gpt-3.5" not in config.model:
stream.send(
"Warning: Mentat does not know how to calculate costs or context"
" size for this model.",
style="warning",
)

messages = self.get_messages()
code_message = await code_context.get_code_message(
prompt_tokens(
Expand Down
4 changes: 3 additions & 1 deletion mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from mentat.session_context import SESSION_CONTEXT, SessionContext
from mentat.session_input import collect_input_with_commands
from mentat.session_stream import SessionStream
from mentat.splash_messages import check_version
from mentat.splash_messages import check_model, check_version
from mentat.utils import mentat_dir_path
from mentat.vision.vision_manager import VisionManager

Expand Down Expand Up @@ -144,6 +144,8 @@ async def _main(self):
ensure_ctags_installed()

session_context.llm_api_handler.initialize_client()

check_model()
await conversation.display_token_count()

stream.send("Type 'q' or use Ctrl-C to quit at any time.")
Expand Down
18 changes: 18 additions & 0 deletions mentat/splash_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ def check_version():
f.write(__version__)
except Exception as err:
ctx.stream.send(f"Error checking for most recent version: {err}", style="error")


def check_model():
ctx = SESSION_CONTEXT.get()
model = ctx.config.model
if "gpt-4" not in model:
ctx.stream.send(
"Warning: Mentat has only been tested on GPT-4. You may experience"
" issues with quality. This model may not be able to respond in"
" mentat's edit format.",
style="warning",
)
if "gpt-3.5" not in model:
ctx.stream.send(
"Warning: Mentat does not know how to calculate costs or context"
" size for this model.",
style="warning",
)
44 changes: 44 additions & 0 deletions tests/splash_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from mentat import Mentat
from mentat.config import Config


@pytest.mark.asyncio
Expand Down Expand Up @@ -65,3 +66,46 @@ async def test_not_display_new_version_splash_message(mock_get):
not in mentat._accumulated_message
)
await mentat.shutdown()


@pytest.mark.asyncio
async def test_check_model():
mentat = Mentat(config=Config(model="test"))

await mentat.startup()
await asyncio.sleep(0.01)
assert (
"Warning: Mentat has only been tested on GPT-4" in mentat._accumulated_message
)
assert (
"Warning: Mentat does not know how to calculate costs or context"
in mentat._accumulated_message
)
await mentat.shutdown()

mentat = Mentat(config=Config(model="gpt-3.5"))

await mentat.startup()
await asyncio.sleep(0.01)
assert (
"Warning: Mentat has only been tested on GPT-4" in mentat._accumulated_message
)
assert (
"Warning: Mentat does not know how to calculate costs or context"
not in mentat._accumulated_message
)
await mentat.shutdown()

mentat = Mentat(config=Config(model="gpt-4"))

await mentat.startup()
await asyncio.sleep(0.01)
assert (
"Warning: Mentat has only been tested on GPT-4"
not in mentat._accumulated_message
)
assert (
"Warning: Mentat does not know how to calculate costs or context"
not in mentat._accumulated_message
)
await mentat.shutdown()

0 comments on commit 9178325

Please sign in to comment.