Skip to content

Commit 14ffdb2

Browse files
committed
Connect new settings api
1 parent d6407f1 commit 14ffdb2

File tree

4 files changed

+65
-43
lines changed

4 files changed

+65
-43
lines changed

shelloracle/config/config.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,58 @@
1+
from collections.abc import MutableMapping
12
from pathlib import Path
2-
from typing import TextIO
33

4-
from tomlkit import document, table, load
4+
import tomlkit
55

66
data_home = Path.home() / "Library/Application Support" / "shelloracle"
77

88

9-
def _default_config():
10-
doc = document()
11-
shelloracle = table()
12-
shelloracle.add("provider", "ollama")
13-
doc.add("shelloracle", shelloracle)
14-
return doc
9+
def _default_config() -> tomlkit.TOMLDocument:
10+
config = tomlkit.document()
11+
shor_table = tomlkit.table()
12+
shor_table.add("provider", "Ollama")
13+
config.add("shelloracle", shor_table)
14+
return config
1515

1616

17-
class SingletonMeta(type):
18-
_instances = {}
17+
class Configuration(MutableMapping):
18+
filepath = data_home / "config.toml"
1919

20-
def __call__(cls, *args, **kwargs):
21-
if cls not in cls._instances:
22-
cls._instances[cls] = super().__call__(*args, **kwargs)
23-
return cls._instances[cls]
20+
def __init__(self) -> None:
21+
self._ensure_config_exists()
2422

23+
def __getitem__(self, item):
24+
with self.filepath.open("r") as file:
25+
config = tomlkit.load(file)
26+
return config[item]
2527

26-
class Configuration(metaclass=SingletonMeta):
27-
filepath = data_home / "config.toml"
28+
def __setitem__(self, key, value):
29+
with self.filepath.open("r") as file:
30+
config = tomlkit.load(file)
31+
config[key] = value
32+
config.multiline = True
33+
with self.filepath.open("w") as file:
34+
tomlkit.dump(config, file)
2835

29-
def __init__(self):
30-
self.file: TextIO | None = None
31-
self._ensure_config_exists()
36+
def __delitem__(self, key):
37+
raise NotImplementedError()
3238

33-
def _ensure_config_exists(self):
34-
if not self.filepath.exists():
35-
data_home.mkdir(exist_ok=True)
36-
self.filepath.write_text(_default_config())
39+
def __iter__(self):
40+
raise NotImplementedError()
41+
42+
def __len__(self) -> int:
43+
raise NotImplementedError()
44+
45+
def _ensure_config_exists(self) -> None:
46+
if self.filepath.exists():
47+
return
48+
data_home.mkdir(exist_ok=True)
49+
config = _default_config()
50+
with self.filepath.open("w") as file:
51+
tomlkit.dump(config, file)
3752

3853
@property
3954
def provider(self) -> str | None:
40-
with self.filepath.open("r") as config_file:
41-
file = load(config_file)
42-
return file.get("shelloracle", {}).get("provider", None)
55+
return self["shelloracle"]["provider"]
56+
4357

58+
global_config = Configuration()

shelloracle/config/setting.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,27 @@
22

33
from typing import TypeVar, Generic, TYPE_CHECKING
44

5+
from . import config
6+
57
if TYPE_CHECKING:
68
from ..provider import Provider
79

810
T = TypeVar("T")
911

1012

1113
class Setting(Generic[T]):
12-
def __init__(self, *, default: T | None = None):
14+
def __init__(self, *, default: T | None = None) -> None:
1315
self.default = default
1416

15-
def __set_name__(self, owner: Provider, name: str) -> None:
17+
def __set_name__(self, owner: type[Provider], name: str) -> None:
1618
self.name = name
19+
# Set the default value in the config dictionary
20+
provider_table = config.global_config.get("provider", {})
21+
provider_table.setdefault(owner.name, {}).setdefault(name, self.default)
22+
config.global_config["provider"] = provider_table
1723

1824
def __get__(self, instance: Provider, owner: type[Provider]) -> T:
19-
return self.default
25+
return config.global_config.get("provider", {}).get(instance.name, {})[self.name]
2026

2127
def __set__(self, instance: Provider, value: T) -> None:
22-
self.default = value
28+
config.global_config.setdefault("provider", {}).setdefault(instance.name, {})[self.name] = value

shelloracle/provider/ollama.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ class GenerateRequest:
4848
history yourself. JSON mode"""
4949

5050

51-
system_prompt = (
52-
"Based on the following user description, generate a corresponding Bash command. Focus solely on interpreting "
53-
"the requirements and translating them into a single, executable Bash command. Ensure accuracy and relevance "
54-
"to the user's description. The output should be a valid Bash command that directly aligns with the user's "
55-
"intent, ready for execution in a command-line environment. Output nothing except for the command. No code "
56-
"block, no English explanation, no start/end tags."
57-
)
58-
59-
6051
class Ollama(Provider):
6152
name = "Ollama"
6253

6354
host = Setting(default="localhost")
6455
port = Setting(default=11434)
6556
model = Setting(default="codellama:13b")
66-
system_prompt = Setting(default=system_prompt)
57+
system_prompt = Setting(
58+
default=(
59+
"Based on the following user description, generate a corresponding Bash command. Focus solely "
60+
"on interpreting the requirements and translating them into a single, executable Bash command. "
61+
"Ensure accuracy and relevance to the user's description. The output should be a valid Bash "
62+
"command that directly aligns with the user's intent, ready for execution in a command-line "
63+
"environment. Output nothing except for the command. No code block, no English explanation, "
64+
"no start/end tags."
65+
)
66+
)
6767

6868
@property
6969
def endpoint(self):

shelloracle/shelloracle.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from prompt_toolkit.application import create_app_session_from_tty
77
from prompt_toolkit.history import FileHistory
88

9-
from .config import Configuration, data_home
9+
from .config import data_home
1010
from .provider import get_provider
1111

12+
from .config import config
13+
1214

1315
async def prompt_user(default_prompt: str | None = None) -> str:
1416
with create_app_session_from_tty():
@@ -47,8 +49,7 @@ async def shell_oracle() -> None:
4749
4850
:returns: None
4951
"""
50-
config = Configuration()
51-
provider = get_provider(config.provider)()
52+
provider = get_provider(config.global_config.provider)()
5253

5354
if not (prompt := get_query_from_pipe()):
5455
default_prompt = os.environ.get("SHOR_DEFAULT_PROMPT")

0 commit comments

Comments
 (0)