Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Apr 7, 2024
1 parent 42af3c2 commit b88b142
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 24 deletions.
18 changes: 9 additions & 9 deletions src/shelloracle/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import os
import sys
from collections.abc import Mapping, Iterator
from pathlib import Path
Expand All @@ -21,14 +20,14 @@
class Configuration(Mapping):
"""ShellOracle application configuration"""

if "SHELLORACLE_CONFIG" in os.environ:
filepath = Path(os.environ["SHELLORACLE_CONFIG"]).absolute()
else:
filepath = shelloracle_home / "config.toml"
def __init__(self, config) -> None:
self._config = config

def __init__(self) -> None:
with self.filepath.open("rb") as config_file:
self._config = tomllib.load(config_file)
@classmethod
def from_file(cls, filepath: Path) -> Configuration:
with filepath.open("rb") as config_file:
config = tomllib.load(config_file)
return cls(config)

def __getitem__(self, key) -> Any:
return self._config[key]
Expand Down Expand Up @@ -72,5 +71,6 @@ def get_config() -> Configuration:
"""
global _config
if _config is None:
_config = Configuration()
filepath = shelloracle_home / "config.toml"
_config = Configuration.from_file(filepath)
return _config
59 changes: 59 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import MagicMock

import pytest

from shelloracle.config import Configuration


@pytest.fixture(autouse=True)
def tmp_shelloracle_home(monkeypatch, tmp_path):
monkeypatch.setattr("shelloracle.config.shelloracle_home", tmp_path)
return tmp_path


@pytest.fixture
def mock_yaspin(monkeypatch):
mock = MagicMock()
monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock)
return mock


@pytest.fixture
def mock_config(monkeypatch):
config = MagicMock()
monkeypatch.setattr("shelloracle.config._config", config)
return config


@pytest.fixture
def default_config_dict():
config = {
"shelloracle": {
"provider": "Ollama",
"spinner_style": "earth"
},
"provider": {
"Ollama": {
"host": "localhost",
"port": 11434,
"model": "dolphin-mistral"
},
"OpenAI": {
"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"model": "gpt-3.5-turbo"
},
"LocalAI": {
"host": "localhost",
"port": 8080,
"model": "mostral-openorca"
}
}
}
return config


@pytest.fixture
def default_config(monkeypatch, default_config_dict):
configuration = Configuration(default_config_dict)
monkeypatch.setattr("shelloracle.config._config", configuration)
return configuration
59 changes: 59 additions & 0 deletions tests/providers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

from shelloracle.config import Configuration


@pytest.fixture
def ollama_config(monkeypatch):
config_dict = {
"shelloracle": {
"provider": "Ollama",
},
"provider": {
"Ollama": {
"host": "localhost",
"port": 11434,
"model": "dolphin-mistral"
}
}
}
config = Configuration(config_dict)
monkeypatch.setattr("shelloracle.config.get_config", lambda: config)
return config


@pytest.fixture
def openai_config(monkeypatch):
config_dict = {
"shelloracle": {
"provider": "OpenAI",
},
"provider": {
"OpenAI": {
"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"model": "gpt-3.5-turbo"
}
}
}
config = Configuration(config_dict)
monkeypatch.setattr("shelloracle.config.get_config", lambda: config)
return config


@pytest.fixture
def localai_config(monkeypatch):
config_dict = {
"shelloracle": {
"provider": "LocalAI",
},
"provider": {
"LocalAI": {
"host": "localhost",
"port": 8080,
"model": "mostral-openorca"
}
}
}
config = Configuration(config_dict)
monkeypatch.setattr("shelloracle.config.get_config", lambda: config)
return config
71 changes: 71 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

import logging
from unittest.mock import MagicMock

import pytest
import tomlkit

from shelloracle.config import get_config, Configuration


class TestGetConfig:
@pytest.fixture
def from_file_mock(self, monkeypatch):
mock = MagicMock()
monkeypatch.setattr("shelloracle.config.Configuration.from_file", mock)
return mock

def test_config_exists(self, monkeypatch, from_file_mock):
empty_config = Configuration({})
monkeypatch.setattr("shelloracle.config._config", empty_config)
config = get_config()
from_file_mock.assert_not_called()
assert config == empty_config

def test_no_config(self, monkeypatch, from_file_mock, tmp_path):
monkeypatch.setattr("shelloracle.config._config", None)
monkeypatch.setattr("shelloracle.config.shelloracle_home", tmp_path)
get_config()
from_file_mock.assert_called_once_with(tmp_path / "config.toml")


class TestConfig:
def test_from_file(self, tmp_path, default_config_dict):
config_path = tmp_path / "config.toml"
with config_path.open("w") as file:
tomlkit.dump(default_config_dict, file)
configuration = Configuration.from_file(config_path)
assert configuration.provider == default_config_dict["shelloracle"]["provider"]

def test_getitem(self, default_config_dict, default_config):
for key in default_config_dict:
assert default_config_dict[key] == default_config[key]

def test_len(self, default_config_dict, default_config):
assert len(default_config_dict) == len(default_config)

def test_iter(self, default_config_dict, default_config):
assert list(iter(default_config_dict)) == list(iter(default_config))

def test_str(self, default_config_dict, default_config):
assert str(default_config) == f"Configuration({default_config_dict})"

def test_repr(self, default_config):
assert repr(default_config) == str(default_config)

def test_provider(self, default_config):
assert default_config.provider == "Ollama"

def test_spinner_style(self, default_config):
assert default_config.spinner_style == "earth"

def test_no_spinner_style(self, caplog, default_config_dict):
del default_config_dict["shelloracle"]["spinner_style"]
config = Configuration(default_config_dict)
assert config.spinner_style is None

def test_spinner_invalid_style(self, caplog, default_config_dict):
default_config_dict["shelloracle"]["spinner_style"] = "invalid"
config = Configuration(default_config_dict)
assert config.spinner_style is None
16 changes: 1 addition & 15 deletions tests/test_shelloracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,14 @@

import os
import sys
from unittest.mock import MagicMock, call
from unittest.mock import call

import pytest
from yaspin.spinners import Spinners

from shelloracle.shelloracle import get_query_from_pipe, spinner


@pytest.fixture
def mock_yaspin(monkeypatch):
mock = MagicMock()
monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock)
return mock


@pytest.fixture
def mock_config(monkeypatch):
config = MagicMock()
monkeypatch.setattr("shelloracle.config._config", config)
return config


@pytest.mark.parametrize("spinner_style,expected", [(None, call()), ("earth", call(Spinners.earth))])
def test_spinner(spinner_style, expected, mock_config, mock_yaspin):
mock_config.spinner_style = spinner_style
Expand Down

0 comments on commit b88b142

Please sign in to comment.