From c94a6a72c6520271ab042349dc281c41fe59513a Mon Sep 17 00:00:00 2001
From: Atakan Tekparmak <59488384+AtakanTekparmak@users.noreply.github.com>
Date: Sat, 13 Jul 2024 03:44:55 +0200
Subject: [PATCH] feat: Added Ollama engine via OpenAI api (#51)
* feat: Added Ollama engine via OpenAI api
* fix: Added PR remarks and test
---
README.md | 14 ++++++++++++++
tests/test_engines.py | 15 +++++++++++++++
textgrad/engine/__init__.py | 8 ++++++++
textgrad/engine/openai.py | 29 ++++++++++++++++++++++++-----
4 files changed, 61 insertions(+), 5 deletions(-)
create mode 100644 tests/test_engines.py
diff --git a/README.md b/README.md
index 84b30b2..75d3122 100644
--- a/README.md
+++ b/README.md
@@ -384,6 +384,20 @@ We are grateful for all the help we got from our contributors!
tboen1
+
+
+
+
+
+ Nihal Nayak
+
+ |
+
+
+
+
+ Atakan Tekparmak
+
|
diff --git a/tests/test_engines.py b/tests/test_engines.py
new file mode 100644
index 0000000..ddd34d1
--- /dev/null
+++ b/tests/test_engines.py
@@ -0,0 +1,15 @@
+import pytest
+
+from textgrad.engine import get_engine
+
+def test_ollama_engine():
+ # Declare test constants
+ OLLAMA_BASE_URL = 'http://localhost:11434/v1'
+ MODEL_STRING = "test-model-string"
+
+ # Initialise the engine
+ engine = get_engine("ollama-" + MODEL_STRING)
+
+ assert engine
+ assert engine.model_string == MODEL_STRING
+ assert engine.base_url == OLLAMA_BASE_URL
\ No newline at end of file
diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index bd3ab20..d710c50 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -55,6 +55,14 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
elif engine_name in ["command-r-plus", "command-r", "command", "command-light"]:
from .cohere import ChatCohere
return ChatCohere(model_string=engine_name, **kwargs)
+ elif engine_name.startswith("ollama"):
+ from .openai import ChatOpenAI, OLLAMA_BASE_URL
+ model_string = engine_name.replace("ollama-", "")
+ return ChatOpenAI(
+ model_string=model_string,
+ base_url=OLLAMA_BASE_URL,
+ **kwargs
+ )
elif "vllm" in engine_name:
from .vllm import ChatVLLM
engine_name = engine_name.replace("vllm-", "")
diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py
index 0a922b4..723f04a 100644
--- a/textgrad/engine/openai.py
+++ b/textgrad/engine/openai.py
@@ -17,6 +17,12 @@
from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes
+# Default base URL for the local Ollama OpenAI-compatible endpoint
+OLLAMA_BASE_URL = 'http://localhost:11434/v1'
+
+# Allow the user to override the default via the OLLAMA_BASE_URL environment variable
+if os.getenv("OLLAMA_BASE_URL"):
+ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
class ChatOpenAI(EngineLM, CachedEngine):
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
@@ -26,10 +32,12 @@ def __init__(
model_string: str="gpt-3.5-turbo-0613",
system_prompt: str=DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool=False,
+ base_url: str=None,
**kwargs):
"""
:param model_string:
:param system_prompt:
+    :param base_url: Optional custom API base URL; set to the Ollama endpoint to route requests to a local Ollama server
"""
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
@@ -37,12 +45,23 @@ def __init__(
super().__init__(cache_path=cache_path)
self.system_prompt = system_prompt
- if os.getenv("OPENAI_API_KEY") is None:
- raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
+ self.base_url = base_url
- self.client = OpenAI(
- api_key=os.getenv("OPENAI_API_KEY"),
- )
+ if not base_url:
+ if os.getenv("OPENAI_API_KEY") is None:
+ raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
+
+ self.client = OpenAI(
+ api_key=os.getenv("OPENAI_API_KEY")
+ )
+ elif base_url and base_url == OLLAMA_BASE_URL:
+ self.client = OpenAI(
+ base_url=base_url,
+ api_key="ollama"
+ )
+ else:
+ raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")
+
self.model_string = model_string
self.is_multimodal = is_multimodal