Skip to content

Commit d6407f1

Browse files
committed
Get config basics down
1 parent c58e040 commit d6407f1

File tree

8 files changed

+106
-42
lines changed

8 files changed

+106
-42
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ name = "shelloracle"
77
version = "0.0.1"
88
dependencies = [
99
"httpx",
10-
"prompt-toolkit"
10+
"prompt-toolkit",
11+
"tomlkit"
1112
]
1213
authors = [
1314
{ name = "Daniel Copley", email = "djcopley@users.noreply.github.com" },
@@ -24,7 +25,7 @@ classifiers = [
2425
]
2526

2627
[tool.setuptools]
27-
packages = ["shelloracle", "shelloracle.provider"]
28+
packages = ["shelloracle", "shelloracle.config", "shelloracle.provider"]
2829

2930
[project.scripts]
3031
shor = "shelloracle.shelloracle:cli"

shelloracle/config.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

shelloracle/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .config import Configuration, data_home
2+
from .setting import Setting

shelloracle/config/config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from pathlib import Path
2+
from typing import TextIO
3+
4+
from tomlkit import document, table, load
5+
6+
data_home = Path.home() / "Library/Application Support" / "shelloracle"
7+
8+
9+
def _default_config():
10+
doc = document()
11+
shelloracle = table()
12+
shelloracle.add("provider", "ollama")
13+
doc.add("shelloracle", shelloracle)
14+
return doc
15+
16+
17+
class SingletonMeta(type):
18+
_instances = {}
19+
20+
def __call__(cls, *args, **kwargs):
21+
if cls not in cls._instances:
22+
cls._instances[cls] = super().__call__(*args, **kwargs)
23+
return cls._instances[cls]
24+
25+
26+
class Configuration(metaclass=SingletonMeta):
27+
filepath = data_home / "config.toml"
28+
29+
def __init__(self):
30+
self.file: TextIO | None = None
31+
self._ensure_config_exists()
32+
33+
def _ensure_config_exists(self):
34+
if not self.filepath.exists():
35+
data_home.mkdir(exist_ok=True)
36+
self.filepath.write_text(_default_config())
37+
38+
@property
39+
def provider(self) -> str | None:
40+
with self.filepath.open("r") as config_file:
41+
file = load(config_file)
42+
return file.get("shelloracle", {}).get("provider", None)
43+

shelloracle/config/setting.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
from typing import TypeVar, Generic, TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from ..provider import Provider
7+
8+
T = TypeVar("T")
9+
10+
11+
class Setting(Generic[T]):
12+
def __init__(self, *, default: T | None = None):
13+
self.default = default
14+
15+
def __set_name__(self, owner: Provider, name: str) -> None:
16+
self.name = name
17+
18+
def __get__(self, instance: Provider, owner: type[Provider]) -> T:
19+
return self.default
20+
21+
def __set__(self, instance: Provider, value: T) -> None:
22+
self.default = value

shelloracle/provider/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Provider(Protocol):
1414
1515
All LLM backends must implement this interface.
1616
"""
17-
name = ""
17+
name: str
1818

1919
@abstractmethod
2020
def generate(self, prompt: str) -> AsyncGenerator[str, None, None]:
@@ -36,6 +36,6 @@ def get_provider(name: str) -> type[Provider]:
3636
"""
3737
from .ollama import Ollama
3838
providers = {
39-
"ollama": Ollama
39+
"Ollama": Ollama
4040
}
4141
return providers[name]

shelloracle/provider/ollama.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
import json
22
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, asdict
4+
from typing import Any
45

56
import httpx
67

78
from . import Provider, ProviderError
9+
from ..config import Setting
10+
11+
12+
def dataclass_to_json(obj: Any) -> dict[str, Any]:
13+
"""Convert dataclass to a json dict
14+
15+
This function filters out 'None' values.
16+
17+
:param obj: the dataclass to serialize
18+
:return: serialized dataclass
19+
:raises TypeError: if obj is not a dataclass
20+
"""
21+
return {k: v for k, v in asdict(obj).items() if v is not None}
822

923

1024
@dataclass
@@ -33,29 +47,35 @@ class GenerateRequest:
3347
the raw parameter if you are specifying a full templated prompt in your request to the API, and are managing
3448
history yourself. JSON mode"""
3549

36-
def to_json(self):
37-
return json.dumps(asdict(self))
50+
51+
system_prompt = (
52+
"Based on the following user description, generate a corresponding Bash command. Focus solely on interpreting "
53+
"the requirements and translating them into a single, executable Bash command. Ensure accuracy and relevance "
54+
"to the user's description. The output should be a valid Bash command that directly aligns with the user's "
55+
"intent, ready for execution in a command-line environment. Output nothing except for the command. No code "
56+
"block, no English explanation, no start/end tags."
57+
)
3858

3959

4060
class Ollama(Provider):
4161
name = "Ollama"
4262

43-
host = "localhost"
44-
port = 11434
45-
endpoint = f"http://{host}:{port}/api/generate"
63+
host = Setting(default="localhost")
64+
port = Setting(default=11434)
65+
model = Setting(default="codellama:13b")
66+
system_prompt = Setting(default=system_prompt)
4667

47-
model = "codellama:13b"
48-
system_prompt = """Based on the following user description, generate a corresponding Bash command. Focus solely on
49-
interpreting the requirements and translating them into a single, executable Bash command. Ensure accuracy and
50-
relevance to the user's description. The output should be a valid Bash command that directly aligns with the user's
51-
intent, ready for execution in a command-line environment. Output nothing except for the command. No code block, no
52-
English explanation, no start/end tags."""
68+
@property
69+
def endpoint(self):
70+
# computed property because python descriptors need to be bound to an instance before access
71+
return f"http://{self.host}:{self.port}/api/generate"
5372

5473
async def generate(self, prompt: str) -> AsyncGenerator[str, None, None]:
74+
request = GenerateRequest(self.model, prompt, system=system_prompt, stream=True)
75+
data = dataclass_to_json(request)
5576
try:
56-
request = GenerateRequest(self.model, prompt, system=self.system_prompt, stream=True).to_json()
5777
async with httpx.AsyncClient() as client:
58-
async with client.stream("POST", self.endpoint, content=request, timeout=20.0) as stream:
78+
async with client.stream("POST", self.endpoint, json=data, timeout=20.0) as stream:
5979
async for line in stream.aiter_lines():
6080
yield json.loads(line)["response"]
6181
except (httpx.HTTPError, httpx.StreamError) as e:

shelloracle/shelloracle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from prompt_toolkit.application import create_app_session_from_tty
77
from prompt_toolkit.history import FileHistory
88

9-
from .provider import get_provider
109
from .config import Configuration, data_home
10+
from .provider import get_provider
1111

1212

1313
async def prompt_user(default_prompt: str | None = None) -> str:

0 commit comments

Comments
 (0)