Skip to content

Commit ac3c173

Browse files
jrmi and ErikBjare authored
refactor: group all provider/model related logic into llm directory (#254)
* Move provider files in provider directory * Move all specific code into abstraction * Fix first round of feedback * Second round of feedbacks * fix(llm): improve type hints and handle missing files - Add 'unknown' as valid provider type in ModelMeta - Handle missing 'files' key in Anthropic message dict - Organize imports and improve type hints * fix: remove files field from anthropic messages + refactor - Fix: Remove 'files' field from message dict in anthropic provider - Refactor: Extract file processing logic in both providers - Split handle_files into smaller functions - Make internal functions private with underscore prefix - Improve code organization and readability --------- Co-authored-by: Erik Bjäreholt <erik@bjareho.lt>
1 parent 0fbafd0 commit ac3c173

File tree

12 files changed

+263
-153
lines changed

12 files changed

+263
-153
lines changed

gptme/chat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .llm import reply
1616
from .logmanager import Log, LogManager, prepare_messages
1717
from .message import Message
18-
from .models import get_model
18+
from .llm.models import get_model
1919
from .prompts import get_workspace_prompt
2020
from .readline import add_history
2121
from .tools import ToolUse, execute_msg, has_tool
@@ -56,8 +56,12 @@ def chat(
5656
# init
5757
init(model, interactive, tool_allowlist)
5858

59-
if model and model.startswith("openai/o1") and stream:
60-
logger.info("Disabled streaming for OpenAI's O1 (not supported)")
59+
if not get_model().supports_streaming and stream:
60+
logger.info(
61+
"Disabled streaming for '%s/%s' model (not supported)",
62+
get_model().provider,
63+
get_model().model,
64+
)
6165
stream = False
6266

6367
console.log(f"Using logdir {path_with_tilde(logdir)}")

gptme/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
print_msg,
1717
toml_to_msgs,
1818
)
19-
from .models import get_model
19+
from .llm.models import get_model
2020
from .tools import ToolUse, execute_msg, loaded_tools
2121
from .tools.base import ConfirmFunc
2222
from .useredit import edit_text_with_editor

gptme/init.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from rich.logging import RichHandler
66

77
from .config import config_path, get_config, set_config_value
8-
from .llm import init_llm
9-
from .models import (
8+
from .llm import get_model_from_api_key, guess_model_from_config, init_llm
9+
from .llm.models import (
1010
PROVIDERS,
1111
Provider,
1212
get_recommended_model,
@@ -39,18 +39,11 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)
3939

4040
if not model: # pragma: no cover
4141
# auto-detect depending on if OPENAI_API_KEY or ANTHROPIC_API_KEY is set
42-
if config.get_env("OPENAI_API_KEY"):
43-
console.log("Found OpenAI API key, using OpenAI provider")
44-
model = "openai"
45-
elif config.get_env("ANTHROPIC_API_KEY"):
46-
console.log("Found Anthropic API key, using Anthropic provider")
47-
model = "anthropic"
48-
elif config.get_env("OPENROUTER_API_KEY"):
49-
console.log("Found OpenRouter API key, using OpenRouter provider")
50-
model = "openrouter"
51-
# ask user for API key
52-
elif interactive:
53-
model, _ = ask_for_api_key()
42+
model = guess_model_from_config()
43+
44+
# ask user for API key
45+
if not model and interactive:
46+
model, _ = ask_for_api_key()
5447

5548
# fail
5649
if not model:
@@ -90,12 +83,8 @@ def init_logging(verbose):
9083

9184
def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover
9285
api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip()
93-
if api_key.startswith("sk-ant-"):
94-
return api_key, "anthropic", "ANTHROPIC_API_KEY"
95-
elif api_key.startswith("sk-or-"):
96-
return api_key, "openrouter", "OPENROUTER_API_KEY"
97-
elif api_key.startswith("sk-"):
98-
return api_key, "openai", "OPENAI_API_KEY"
86+
if (found_model_tuple := get_model_from_api_key(api_key)) is not None:
87+
return found_model_tuple
9988
else:
10089
console.print("Invalid API key format. Please try again.")
10190
return _prompt_api_key()

gptme/llm.py renamed to gptme/llm/__init__.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from rich import print
99

10-
from .config import get_config
11-
from .constants import PROMPT_ASSISTANT
10+
from ..config import get_config
11+
from ..constants import PROMPT_ASSISTANT
1212
from .llm_anthropic import chat as chat_anthropic
1313
from .llm_anthropic import get_client as get_anthropic_client
1414
from .llm_anthropic import init as init_anthropic
@@ -17,15 +17,15 @@
1717
from .llm_openai import get_client as get_openai_client
1818
from .llm_openai import init as init_openai
1919
from .llm_openai import stream as stream_openai
20-
from .message import Message, format_msgs, len_tokens
20+
from ..message import Message, format_msgs, len_tokens
2121
from .models import (
2222
MODELS,
2323
PROVIDERS_OPENAI,
2424
Provider,
2525
get_summary_model,
2626
)
27-
from .tools import ToolUse
28-
from .util import console
27+
from ..tools import ToolUse
28+
from ..util import console
2929

3030
logger = logging.getLogger(__name__)
3131

@@ -234,3 +234,38 @@ def _summarize_helper(s: str, tok_max_start=400, tok_max_end=400) -> str:
234234
else:
235235
summary = _summarize_str(s)
236236
return summary
237+
238+
239+
def guess_model_from_config() -> Provider | None:
240+
"""
241+
Guess the model to use from the configuration.
242+
"""
243+
244+
config = get_config()
245+
246+
if config.get_env("OPENAI_API_KEY"):
247+
console.log("Found OpenAI API key, using OpenAI provider")
248+
return "openai"
249+
elif config.get_env("ANTHROPIC_API_KEY"):
250+
console.log("Found Anthropic API key, using Anthropic provider")
251+
return "anthropic"
252+
elif config.get_env("OPENROUTER_API_KEY"):
253+
console.log("Found OpenRouter API key, using OpenRouter provider")
254+
return "openrouter"
255+
256+
return None
257+
258+
259+
def get_model_from_api_key(api_key: str) -> tuple[str, Provider, str] | None:
260+
"""
261+
Guess the model from the API key prefix.
262+
"""
263+
264+
if api_key.startswith("sk-ant-"):
265+
return api_key, "anthropic", "ANTHROPIC_API_KEY"
266+
elif api_key.startswith("sk-or-"):
267+
return api_key, "openrouter", "OPENROUTER_API_KEY"
268+
elif api_key.startswith("sk-"):
269+
return api_key, "openai", "OPENAI_API_KEY"
270+
271+
return None

gptme/llm_anthropic.py renamed to gptme/llm/llm_anthropic.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
1+
import base64
2+
import logging
13
from collections.abc import Generator
2-
from typing import TYPE_CHECKING, Literal, TypedDict
4+
from pathlib import Path
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Literal,
9+
TypedDict,
10+
)
311

412
from typing_extensions import Required
513

6-
from .constants import TEMPERATURE, TOP_P
7-
from .message import Message, len_tokens, msgs2dicts
14+
from ..constants import TEMPERATURE, TOP_P
15+
from ..message import Message, len_tokens, msgs2dicts
816

9-
if TYPE_CHECKING:
10-
from anthropic import Anthropic
17+
logger = logging.getLogger(__name__)
1118

1219

20+
if TYPE_CHECKING:
21+
from anthropic import Anthropic # fmt: skip
22+
1323
anthropic: "Anthropic | None" = None
1424

25+
ALLOWED_FILE_EXTS = ["jpg", "jpeg", "png", "gif"]
26+
1527

1628
def init(config):
1729
global anthropic
@@ -38,7 +50,9 @@ class MessagePart(TypedDict, total=False):
3850
def chat(messages: list[Message], model: str) -> str:
3951
assert anthropic, "LLM not initialized"
4052
messages, system_messages = _transform_system_messages(messages)
41-
messages_dicts = msgs2dicts(messages, provider="anthropic")
53+
54+
messages_dicts = _handle_files(msgs2dicts(messages))
55+
4256
response = anthropic.beta.prompt_caching.messages.create(
4357
model=model,
4458
messages=messages_dicts, # type: ignore
@@ -56,7 +70,9 @@ def chat(messages: list[Message], model: str) -> str:
5670
def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
5771
assert anthropic, "LLM not initialized"
5872
messages, system_messages = _transform_system_messages(messages)
59-
messages_dicts = msgs2dicts(messages, provider="anthropic")
73+
74+
messages_dicts = _handle_files(msgs2dicts(messages))
75+
6076
with anthropic.beta.prompt_caching.messages.stream(
6177
model=model,
6278
messages=messages_dicts, # type: ignore
@@ -68,6 +84,69 @@ def stream(messages: list[Message], model: str) -> Generator[str, None, None]:
6884
yield from stream.text_stream
6985

7086

87+
def _handle_files(message_dicts: list[dict]) -> list[dict]:
88+
return [_process_file(message_dict) for message_dict in message_dicts]
89+
90+
91+
def _process_file(message_dict: dict) -> dict:
92+
message_content = message_dict["content"]
93+
94+
# combines a content message with a list of files
95+
content: list[dict[str, Any]] = (
96+
message_content
97+
if isinstance(message_content, list)
98+
else [{"type": "text", "text": message_content}]
99+
)
100+
101+
for f in message_dict.pop("files", []):
102+
f = Path(f)
103+
ext = f.suffix[1:]
104+
if ext not in ALLOWED_FILE_EXTS:
105+
logger.warning("Unsupported file type: %s", ext)
106+
continue
107+
if ext == "jpg":
108+
ext = "jpeg"
109+
media_type = f"image/{ext}"
110+
111+
content.append(
112+
{
113+
"type": "text",
114+
"text": f"![{f.name}]({f.name}):",
115+
}
116+
)
117+
118+
# read file
119+
data_bytes = f.read_bytes()
120+
data = base64.b64encode(data_bytes).decode("utf-8")
121+
122+
# check that the file is not too large
123+
# anthropic limit is 5MB, seems to measure the base64-encoded size instead of raw bytes
124+
# TODO: use compression to reduce file size
125+
# print(f"{len(data)=}")
126+
if len(data) > 5 * 1_024 * 1_024:
127+
content.append(
128+
{
129+
"type": "text",
130+
"text": "Image size exceeds 5MB. Please upload a smaller image.",
131+
}
132+
)
133+
continue
134+
135+
content.append(
136+
{
137+
"type": "image",
138+
"source": {
139+
"type": "base64",
140+
"media_type": media_type,
141+
"data": data,
142+
},
143+
}
144+
)
145+
146+
message_dict["content"] = content
147+
return message_dict
148+
149+
71150
def _transform_system_messages(
72151
messages: list[Message],
73152
) -> tuple[list[Message], list[MessagePart]]:

0 commit comments

Comments (0)