refactor: group all provider/model related logic into llm directory (#254)

* Move provider files into a provider directory

* Move all provider-specific code into the abstraction

* Address first round of feedback

* Address second round of feedback

* fix(llm): improve type hints and handle missing files

- Add 'unknown' as valid provider type in ModelMeta
- Handle missing 'files' key in Anthropic message dict
- Organize imports and improve type hints

* fix: remove files field from anthropic messages + refactor

- Fix: Remove 'files' field from message dict in anthropic provider
- Refactor: Extract file processing logic in both providers
  - Split handle_files into smaller functions
  - Make internal functions private with underscore prefix
  - Improve code organization and readability

---------

Co-authored-by: Erik Bjäreholt <erik@bjareho.lt>
jrmi and ErikBjare authored Nov 22, 2024
1 parent 0fbafd0 commit ac3c173
Showing 12 changed files with 263 additions and 153 deletions.
10 changes: 7 additions & 3 deletions gptme/chat.py
@@ -15,7 +15,7 @@
 from .llm import reply
 from .logmanager import Log, LogManager, prepare_messages
 from .message import Message
-from .models import get_model
+from .llm.models import get_model
 from .prompts import get_workspace_prompt
 from .readline import add_history
 from .tools import ToolUse, execute_msg, has_tool
@@ -56,8 +56,12 @@ def chat(
     # init
     init(model, interactive, tool_allowlist)
 
-    if model and model.startswith("openai/o1") and stream:
-        logger.info("Disabled streaming for OpenAI's O1 (not supported)")
+    if not get_model().supports_streaming and stream:
+        logger.info(
+            "Disabled streaming for '%s/%s' model (not supported)",
+            get_model().provider,
+            get_model().model,
+        )
         stream = False
 
     console.log(f"Using logdir {path_with_tilde(logdir)}")
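Note on the new check: instead of special-casing `openai/o1` by name, `chat()` now asks the model metadata whether it supports streaming (`get_model().supports_streaming`). A minimal, self-contained sketch of the pattern — the `ModelMeta` dataclass below is an illustrative stand-in, not the real one from gptme/llm/models.py:

from dataclasses import dataclass


@dataclass(frozen=True)
class ModelMeta:
    # illustrative subset of the fields the new chat() code reads
    provider: str
    model: str
    supports_streaming: bool = True


def effective_stream(meta: ModelMeta, stream_requested: bool) -> bool:
    # models that can't stream silently fall back to non-streaming replies
    return stream_requested and meta.supports_streaming


assert effective_stream(ModelMeta("anthropic", "claude-3-5-sonnet-20241022"), True)
assert not effective_stream(ModelMeta("openai", "o1-preview", supports_streaming=False), True)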
2 changes: 1 addition & 1 deletion gptme/commands.py
@@ -16,7 +16,7 @@
     print_msg,
     toml_to_msgs,
 )
-from .models import get_model
+from .llm.models import get_model
 from .tools import ToolUse, execute_msg, loaded_tools
 from .tools.base import ConfirmFunc
 from .useredit import edit_text_with_editor
29 changes: 9 additions & 20 deletions gptme/init.py
@@ -5,8 +5,8 @@
 from rich.logging import RichHandler
 
 from .config import config_path, get_config, set_config_value
-from .llm import init_llm
-from .models import (
+from .llm import get_model_from_api_key, guess_model_from_config, init_llm
+from .llm.models import (
     PROVIDERS,
     Provider,
     get_recommended_model,
@@ -39,18 +39,11 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)
 
     if not model:  # pragma: no cover
         # auto-detect depending on if OPENAI_API_KEY or ANTHROPIC_API_KEY is set
-        if config.get_env("OPENAI_API_KEY"):
-            console.log("Found OpenAI API key, using OpenAI provider")
-            model = "openai"
-        elif config.get_env("ANTHROPIC_API_KEY"):
-            console.log("Found Anthropic API key, using Anthropic provider")
-            model = "anthropic"
-        elif config.get_env("OPENROUTER_API_KEY"):
-            console.log("Found OpenRouter API key, using OpenRouter provider")
-            model = "openrouter"
-        # ask user for API key
-        elif interactive:
-            model, _ = ask_for_api_key()
+        model = guess_model_from_config()
+
+    # ask user for API key
+    if not model and interactive:
+        model, _ = ask_for_api_key()
 
     # fail
     if not model:
@@ -90,12 +83,8 @@ def init_logging(verbose):
 
 def _prompt_api_key() -> tuple[str, str, str]:  # pragma: no cover
     api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip()
-    if api_key.startswith("sk-ant-"):
-        return api_key, "anthropic", "ANTHROPIC_API_KEY"
-    elif api_key.startswith("sk-or-"):
-        return api_key, "openrouter", "OPENROUTER_API_KEY"
-    elif api_key.startswith("sk-"):
-        return api_key, "openai", "OPENAI_API_KEY"
+    if (found_model_tuple := get_model_from_api_key(api_key)) is not None:
+        return found_model_tuple
     else:
         console.print("Invalid API key format. Please try again.")
         return _prompt_api_key()
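The net effect in `init()` is a single provider-resolution chain: explicit model argument → `guess_model_from_config()` → interactive key prompt → failure. A condensed, stand-alone sketch of that chain (an approximation: the real code uses gptme's config object and `ask_for_api_key()`, not raw `os.environ` and `input()`):

import os


def _guess_provider() -> str | None:
    # first configured API key wins, mirroring guess_model_from_config()
    for provider, env_key in [
        ("openai", "OPENAI_API_KEY"),
        ("anthropic", "ANTHROPIC_API_KEY"),
        ("openrouter", "OPENROUTER_API_KEY"),
    ]:
        if os.environ.get(env_key):
            return provider
    return None


def resolve_provider(model: str | None, interactive: bool) -> str:
    model = model or _guess_provider()
    if not model and interactive:
        # the real code calls ask_for_api_key() here
        model = input("Provider (openai/anthropic/openrouter): ").strip() or None
    if not model:
        raise SystemExit("No API key found and no model specified")
    return model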
45 changes: 40 additions & 5 deletions gptme/llm.py → gptme/llm/__init__.py
@@ -7,8 +7,8 @@
 
 from rich import print
 
-from .config import get_config
-from .constants import PROMPT_ASSISTANT
+from ..config import get_config
+from ..constants import PROMPT_ASSISTANT
 from .llm_anthropic import chat as chat_anthropic
 from .llm_anthropic import get_client as get_anthropic_client
 from .llm_anthropic import init as init_anthropic
@@ -17,15 +17,15 @@
 from .llm_openai import get_client as get_openai_client
 from .llm_openai import init as init_openai
 from .llm_openai import stream as stream_openai
-from .message import Message, format_msgs, len_tokens
+from ..message import Message, format_msgs, len_tokens
 from .models import (
     MODELS,
     PROVIDERS_OPENAI,
     Provider,
     get_summary_model,
 )
-from .tools import ToolUse
-from .util import console
+from ..tools import ToolUse
+from ..util import console
 
 logger = logging.getLogger(__name__)
 
@@ -234,3 +234,38 @@ def _summarize_helper(s: str, tok_max_start=400, tok_max_end=400) -> str:
     else:
         summary = _summarize_str(s)
     return summary
+
+
+def guess_model_from_config() -> Provider | None:
+    """
+    Guess the model to use from the configuration.
+    """
+    config = get_config()
+
+    if config.get_env("OPENAI_API_KEY"):
+        console.log("Found OpenAI API key, using OpenAI provider")
+        return "openai"
+    elif config.get_env("ANTHROPIC_API_KEY"):
+        console.log("Found Anthropic API key, using Anthropic provider")
+        return "anthropic"
+    elif config.get_env("OPENROUTER_API_KEY"):
+        console.log("Found OpenRouter API key, using OpenRouter provider")
+        return "openrouter"
+
+    return None
+
+
+def get_model_from_api_key(api_key: str) -> tuple[str, Provider, str] | None:
+    """
+    Guess the model from the API key prefix.
+    """
+    if api_key.startswith("sk-ant-"):
+        return api_key, "anthropic", "ANTHROPIC_API_KEY"
+    elif api_key.startswith("sk-or-"):
+        return api_key, "openrouter", "OPENROUTER_API_KEY"
+    elif api_key.startswith("sk-"):
+        return api_key, "openai", "OPENAI_API_KEY"
+
+    return None
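Since both helpers now live in gptme/llm/__init__.py, they are importable as `from gptme.llm import ...`. A quick illustration of the prefix-based detection (expected results shown in comments):

from gptme.llm import get_model_from_api_key

for key in ("sk-ant-xxxx", "sk-or-xxxx", "sk-xxxx", "not-a-key"):
    result = get_model_from_api_key(key)
    if result is None:
        print(f"{key}: unrecognized prefix")  # only "not-a-key" lands here
    else:
        _api_key, provider, env_var = result
        print(f"{key}: {provider} ({env_var})")  # e.g. sk-ant-xxxx: anthropic (ANTHROPIC_API_KEY)

Note that the order of the checks matters: the `sk-ant-` and `sk-or-` prefixes must be tested before the generic `sk-`, or Anthropic and OpenRouter keys would be misclassified as OpenAI keys.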
93 changes: 86 additions & 7 deletions gptme/llm_anthropic.py → gptme/llm/llm_anthropic.py
@@ -1,17 +1,29 @@
 import base64
 import logging
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Literal, TypedDict
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    TypedDict,
+)
 
 from typing_extensions import Required
 
-from .constants import TEMPERATURE, TOP_P
-from .message import Message, len_tokens, msgs2dicts
+from ..constants import TEMPERATURE, TOP_P
+from ..message import Message, len_tokens, msgs2dicts
 
-if TYPE_CHECKING:
-    from anthropic import Anthropic
 logger = logging.getLogger(__name__)
 
+
+if TYPE_CHECKING:
+    from anthropic import Anthropic  # fmt: skip
+
 anthropic: "Anthropic | None" = None
 
+ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]
+
 
 def init(config):
     global anthropic
@@ -38,7 +50,9 @@ class MessagePart(TypedDict, total=False):
 def chat(messages: list[Message], model: str) -> str:
     assert anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
-    messages_dicts = msgs2dicts(messages, provider="anthropic")
+
+    messages_dicts = _handle_files(msgs2dicts(messages))
+
     response = anthropic.beta.prompt_caching.messages.create(
         model=model,
         messages=messages_dicts,  # type: ignore
@@ -56,7 +70,9 @@ def chat(messages: list[Message], model: str) -> str:
 def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
     assert anthropic, "LLM not initialized"
     messages, system_messages = _transform_system_messages(messages)
-    messages_dicts = msgs2dicts(messages, provider="anthropic")
+
+    messages_dicts = _handle_files(msgs2dicts(messages))
+
     with anthropic.beta.prompt_caching.messages.stream(
         model=model,
         messages=messages_dicts,  # type: ignore
@@ -68,6 +84,69 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
         yield from stream.text_stream
 
 
+def _handle_files(message_dicts: list[dict]) -> list[dict]:
+    return [_process_file(message_dict) for message_dict in message_dicts]
+
+
+def _process_file(message_dict: dict) -> dict:
+    message_content = message_dict["content"]
+
+    # combines a content message with a list of files
+    content: list[dict[str, Any]] = (
+        message_content
+        if isinstance(message_content, list)
+        else [{"type": "text", "text": message_content}]
+    )
+
+    for f in message_dict.pop("files", []):
+        f = Path(f)
+        ext = f.suffix[1:]
+        if ext not in ALLOWED_FILE_EXTS:
+            logger.warning("Unsupported file type: %s", ext)
+            continue
+        if ext == "jpg":
+            ext = "jpeg"
+        media_type = f"image/{ext}"
+
+        content.append(
+            {
+                "type": "text",
+                "text": f"![{f.name}]({f.name}):",
+            }
+        )
+
+        # read file
+        data_bytes = f.read_bytes()
+        data = base64.b64encode(data_bytes).decode("utf-8")
+
+        # check that the file is not too large
+        # anthropic limit is 5MB, seems to measure the base64-encoded size instead of raw bytes
+        # TODO: use compression to reduce file size
+        # print(f"{len(data)=}")
+        if len(data) > 5 * 1_024 * 1_024:
+            content.append(
+                {
+                    "type": "text",
+                    "text": "Image size exceeds 5MB. Please upload a smaller image.",
+                }
+            )
+            continue
+
+        content.append(
+            {
+                "type": "image",
+                "source": {
+                    "type": "base64",
+                    "media_type": media_type,
+                    "data": data,
+                },
+            }
+        )
+
+    message_dict["content"] = content
+    return message_dict
+
+
 def _transform_system_messages(
     messages: list[Message],
 ) -> tuple[list[Message], list[MessagePart]]:
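For reference, `_handle_files` takes the plain dicts produced by `msgs2dicts`, pops each message's `files` list, and rewrites the content into Anthropic-style parts — which also removes the invalid `files` field previously sent to the API. A rough before/after illustration, assuming a single attached PNG (shapes abbreviated):

# input: one user message with an attached image
message_dicts = [
    {"role": "user", "content": "What is in this image?", "files": ["photo.png"]}
]

# _handle_files(message_dicts) returns approximately:
# [
#     {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "What is in this image?"},
#             {"type": "text", "text": "![photo.png](photo.png):"},
#             {"type": "image", "source": {
#                 "type": "base64", "media_type": "image/png", "data": "<base64>",
#             }},
#         ],
#     }
# ]
# note: the "files" key is gone, and string content was wrapped into a parts list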
(Diffs for the remaining 7 changed files are not shown.)
