diff --git a/src/shelloracle/config.py b/src/shelloracle/config.py index d20905a..1d3be68 100644 --- a/src/shelloracle/config.py +++ b/src/shelloracle/config.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os import sys from collections.abc import Mapping, Iterator from pathlib import Path @@ -21,14 +20,14 @@ class Configuration(Mapping): """ShellOracle application configuration""" - if "SHELLORACLE_CONFIG" in os.environ: - filepath = Path(os.environ["SHELLORACLE_CONFIG"]).absolute() - else: - filepath = shelloracle_home / "config.toml" + def __init__(self, config) -> None: + self._config = config - def __init__(self) -> None: - with self.filepath.open("rb") as config_file: - self._config = tomllib.load(config_file) + @classmethod + def from_file(cls, filepath: Path) -> Configuration: + with filepath.open("rb") as config_file: + config = tomllib.load(config_file) + return cls(config) def __getitem__(self, key) -> Any: return self._config[key] @@ -72,5 +71,6 @@ def get_config() -> Configuration: """ global _config if _config is None: - _config = Configuration() + filepath = shelloracle_home / "config.toml" + _config = Configuration.from_file(filepath) return _config diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2c4bf2f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock + +import pytest + +from shelloracle.config import Configuration + + +@pytest.fixture(autouse=True) +def tmp_shelloracle_home(monkeypatch, tmp_path): + monkeypatch.setattr("shelloracle.config.shelloracle_home", tmp_path) + return tmp_path + + +@pytest.fixture +def mock_yaspin(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock) + return mock + + +@pytest.fixture +def mock_config(monkeypatch): + config = MagicMock() + monkeypatch.setattr("shelloracle.config._config", config) + return config + + +@pytest.fixture +def default_config_dict(): + config = { + "shelloracle": { + "provider": "Ollama", + "spinner_style": "earth" + }, + "provider": { + "Ollama": { + "host": "localhost", + "port": 11434, + "model": "dolphin-mistral" + }, + "OpenAI": { + "api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "model": "gpt-3.5-turbo" + }, + "LocalAI": { + "host": "localhost", + "port": 8080, + "model": "mostral-openorca" + } + } + } + return config + + +@pytest.fixture +def default_config(monkeypatch, default_config_dict): + configuration = Configuration(default_config_dict) + monkeypatch.setattr("shelloracle.config._config", configuration) + return configuration diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py new file mode 100644 index 0000000..9e9120b --- /dev/null +++ b/tests/providers/conftest.py @@ -0,0 +1,59 @@ +import pytest + +from shelloracle.config import Configuration + + +@pytest.fixture +def ollama_config(monkeypatch): + config_dict = { + "shelloracle": { + "provider": "Ollama", + }, + "provider": { + "Ollama": { + "host": "localhost", + "port": 11434, + "model": "dolphin-mistral" + } + } + } + config = Configuration(config_dict) + monkeypatch.setattr("shelloracle.config.get_config", lambda: config) + return config + + +@pytest.fixture +def openai_config(monkeypatch): + config_dict = { + "shelloracle": { + "provider": "OpenAI", + }, + "provider": { + "OpenAI": { + "api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "model": "gpt-3.5-turbo" + } + } + } + config = Configuration(config_dict) + monkeypatch.setattr("shelloracle.config.get_config", lambda: config) + return config + + +@pytest.fixture +def localai_config(monkeypatch): + config_dict = { + "shelloracle": { + "provider": "LocalAI", + }, + "provider": { + "LocalAI": { + "host": "localhost", + "port": 8080, + "model": "mostral-openorca" + } + } + } + config = Configuration(config_dict) + monkeypatch.setattr("shelloracle.config.get_config", lambda: config) + return config diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..2c93520 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import logging +from unittest.mock import MagicMock + +import pytest +import tomlkit + +from shelloracle.config import get_config, Configuration + + +class TestGetConfig: + @pytest.fixture + def from_file_mock(self, monkeypatch): + mock = MagicMock() + monkeypatch.setattr("shelloracle.config.Configuration.from_file", mock) + return mock + + def test_config_exists(self, monkeypatch, from_file_mock): + empty_config = Configuration({}) + monkeypatch.setattr("shelloracle.config._config", empty_config) + config = get_config() + from_file_mock.assert_not_called() + assert config == empty_config + + def test_no_config(self, monkeypatch, from_file_mock, tmp_path): + monkeypatch.setattr("shelloracle.config._config", None) + monkeypatch.setattr("shelloracle.config.shelloracle_home", tmp_path) + get_config() + from_file_mock.assert_called_once_with(tmp_path / "config.toml") + + +class TestConfig: + def test_from_file(self, tmp_path, default_config_dict): + config_path = tmp_path / "config.toml" + with config_path.open("w") as file: + tomlkit.dump(default_config_dict, file) + configuration = Configuration.from_file(config_path) + assert configuration.provider == default_config_dict["shelloracle"]["provider"] + + def test_getitem(self, default_config_dict, default_config): + for key in default_config_dict: + assert default_config_dict[key] == default_config[key] + + def test_len(self, default_config_dict, default_config): + assert len(default_config_dict) == len(default_config) + + def test_iter(self, default_config_dict, default_config): + assert list(iter(default_config_dict)) == list(iter(default_config)) + + def test_str(self, default_config_dict, default_config): + assert str(default_config) == f"Configuration({default_config_dict})" + + def test_repr(self, default_config): + assert repr(default_config) == str(default_config) + + def test_provider(self, default_config): + assert default_config.provider == "Ollama" + + def test_spinner_style(self, default_config): + assert default_config.spinner_style == "earth" + + def test_no_spinner_style(self, caplog, default_config_dict): + del default_config_dict["shelloracle"]["spinner_style"] + config = Configuration(default_config_dict) + assert config.spinner_style is None + + def test_spinner_invalid_style(self, caplog, default_config_dict): + default_config_dict["shelloracle"]["spinner_style"] = "invalid" + config = Configuration(default_config_dict) + assert config.spinner_style is None diff --git a/tests/test_shelloracle.py b/tests/test_shelloracle.py index ca98c47..c7fc29e 100644 --- a/tests/test_shelloracle.py +++ b/tests/test_shelloracle.py @@ -2,7 +2,7 @@ import os import sys -from unittest.mock import MagicMock, call +from unittest.mock import call import pytest from yaspin.spinners import Spinners @@ -10,20 +10,6 @@ from shelloracle.shelloracle import get_query_from_pipe, spinner -@pytest.fixture -def mock_yaspin(monkeypatch): - mock = MagicMock() - monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock) - return mock - - -@pytest.fixture -def mock_config(monkeypatch): - config = MagicMock() - monkeypatch.setattr("shelloracle.config._config", config) - return config - - @pytest.mark.parametrize("spinner_style,expected", [(None, call()), ("earth", call(Spinners.earth))]) def test_spinner(spinner_style, expected, mock_config, mock_yaspin): mock_config.spinner_style = spinner_style