diff --git a/src/shelloracle/providers/__init__.py b/src/shelloracle/providers/__init__.py index c9d2d50..da09a33 100644 --- a/src/shelloracle/providers/__init__.py +++ b/src/shelloracle/providers/__init__.py @@ -77,8 +77,9 @@ def _providers() -> dict[str, type[Provider]]: from shelloracle.providers.localai import LocalAI from shelloracle.providers.ollama import Ollama from shelloracle.providers.openai import OpenAI + from shelloracle.providers.xai import XAI - return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI} + return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI} def get_provider(name: str) -> type[Provider]: diff --git a/src/shelloracle/providers/xai.py b/src/shelloracle/providers/xai.py new file mode 100644 index 0000000..65d7a71 --- /dev/null +++ b/src/shelloracle/providers/xai.py @@ -0,0 +1,38 @@ +from collections.abc import AsyncIterator + +from openai import APIError, AsyncOpenAI + +from shelloracle.providers import Provider, ProviderError, Setting, system_prompt + + +class XAI(Provider): + name = "XAI" + + api_key = Setting(default="") + model = Setting(default="grok-beta") + + def __init__(self): + if not self.api_key: + msg = "No API key provided" + raise ProviderError(msg) + self.client = AsyncOpenAI( + api_key=self.api_key, + base_url="https://api.x.ai/v1", + ) + + async def generate(self, prompt: str) -> AsyncIterator[str]: + try: + stream = await self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + stream=True, + ) + async for chunk in stream: + if chunk.choices[0].delta.content is not None: + yield chunk.choices[0].delta.content + except APIError as e: + msg = f"Something went wrong while querying XAI: {e}" + raise ProviderError(msg) from e diff --git a/tests/providers/test_xai.py b/tests/providers/test_xai.py new file mode 100644 index 0000000..2101988 --- /dev/null +++ b/tests/providers/test_xai.py @@ -0,0 +1,41 @@ +import pytest + +from shelloracle.providers.xai import XAI + + +class TestOpenAI: + @pytest.fixture + def xai_config(self, set_config): + config = { + "shelloracle": {"provider": "XAI"}, + "provider": { + "XAI": { + "api_key": "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "model": "grok-beta", + } + }, + } + set_config(config) + + @pytest.fixture + def xai_instance(self, xai_config): + return XAI() + + def test_name(self): + assert XAI.name == "XAI" + + def test_api_key(self, xai_instance): + assert ( + xai_instance.api_key + == "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + ) + + def test_model(self, xai_instance): + assert xai_instance.model == "grok-beta" + + @pytest.mark.asyncio + async def test_generate(self, mock_asyncopenai, xai_instance): + result = "" + async for response in xai_instance.generate(""): + result += response + assert result == "head -c 100 /dev/urandom | hexdump -C"