From c94a6a72c6520271ab042349dc281c41fe59513a Mon Sep 17 00:00:00 2001 From: Atakan Tekparmak <59488384+AtakanTekparmak@users.noreply.github.com> Date: Sat, 13 Jul 2024 03:44:55 +0200 Subject: [PATCH] feat: Added Ollama engine via OpenAI api (#51) * feat: Added Ollama engine via OpenAI api * fix: Added PR remarks and test --- README.md | 14 ++++++++++++++ tests/test_engines.py | 15 +++++++++++++++ textgrad/engine/__init__.py | 8 ++++++++ textgrad/engine/openai.py | 29 ++++++++++++++++++++++++----- 4 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 tests/test_engines.py diff --git a/README.md b/README.md index 84b30b2..75d3122 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,20 @@ We are grateful for all the help we got from our contributors!
tboen1 + + + + nihalnayak +
+ Nihal Nayak +
+ + + + AtakanTekparmak +
+ Atakan Tekparmak +
diff --git a/tests/test_engines.py b/tests/test_engines.py new file mode 100644 index 0000000..ddd34d1 --- /dev/null +++ b/tests/test_engines.py @@ -0,0 +1,15 @@ +import pytest + +from textgrad.engine import get_engine + +def test_ollama_engine(): + # Declare test constants + OLLAMA_BASE_URL = 'http://localhost:11434/v1' + MODEL_STRING = "test-model-string" + + # Initialise the engine + engine = get_engine("ollama-" + MODEL_STRING) + + assert engine + assert engine.model_string == MODEL_STRING + assert engine.base_url == OLLAMA_BASE_URL \ No newline at end of file diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py index bd3ab20..d710c50 100644 --- a/textgrad/engine/__init__.py +++ b/textgrad/engine/__init__.py @@ -55,6 +55,14 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM: elif engine_name in ["command-r-plus", "command-r", "command", "command-light"]: from .cohere import ChatCohere return ChatCohere(model_string=engine_name, **kwargs) + elif engine_name.startswith("ollama"): + from .openai import ChatOpenAI, OLLAMA_BASE_URL + model_string = engine_name.replace("ollama-", "") + return ChatOpenAI( + model_string=model_string, + base_url=OLLAMA_BASE_URL, + **kwargs + ) elif "vllm" in engine_name: from .vllm import ChatVLLM engine_name = engine_name.replace("vllm-", "") diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py index 0a922b4..723f04a 100644 --- a/textgrad/engine/openai.py +++ b/textgrad/engine/openai.py @@ -17,6 +17,12 @@ from .base import EngineLM, CachedEngine from .engine_utils import get_image_type_from_bytes +# Default base URL for OLLAMA +OLLAMA_BASE_URL = 'http://localhost:11434/v1' + +# Check if the user set the OLLAMA_BASE_URL environment variable +if os.getenv("OLLAMA_BASE_URL"): + OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") class ChatOpenAI(EngineLM, CachedEngine): DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant." @@ -26,10 +32,12 @@ def __init__( model_string: str="gpt-3.5-turbo-0613", system_prompt: str=DEFAULT_SYSTEM_PROMPT, is_multimodal: bool=False, + base_url: str=None, **kwargs): """ :param model_string: :param system_prompt: + :param base_url: Used to support Ollama """ root = platformdirs.user_cache_dir("textgrad") cache_path = os.path.join(root, f"cache_openai_{model_string}.db") @@ -37,12 +45,23 @@ def __init__( super().__init__(cache_path=cache_path) self.system_prompt = system_prompt - if os.getenv("OPENAI_API_KEY") is None: - raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.") + self.base_url = base_url - self.client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - ) + if not base_url: + if os.getenv("OPENAI_API_KEY") is None: + raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.") + + self.client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY") + ) + elif base_url and base_url == OLLAMA_BASE_URL: + self.client = OpenAI( + base_url=base_url, + api_key="ollama" + ) + else: + raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.") + self.model_string = model_string self.is_multimodal = is_multimodal