Skip to content

Commit

Permalink
feat: Added Ollama engine via OpenAI API (#51)
Browse files Browse the repository at this point in the history
* feat: Added Ollama engine via OpenAI API

* fix: Added PR remarks and test
  • Loading branch information
AtakanTekparmak authored Jul 13, 2024
1 parent a15e7b6 commit c94a6a7
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 5 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,20 @@ We are grateful for all the help we got from our contributors!
<br />
<sub><b>tboen1</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/nihalnayak">
<img src="https://avatars.githubusercontent.com/u/5679782?v=4" width="100;" alt="nihalnayak"/>
<br />
<sub><b>Nihal Nayak</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/AtakanTekparmak">
<img src="https://avatars.githubusercontent.com/u/59488384?v=4" width="100;" alt="AtakanTekparmak"/>
<br />
<sub><b>Atakan Tekparmak</b></sub>
</a>
</td>
</tr>
<tbody>
Expand Down
15 changes: 15 additions & 0 deletions tests/test_engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from textgrad.engine import get_engine

def test_ollama_engine():
    """Check that get_engine("ollama-<model>") builds an OpenAI-compatible engine
    pointed at the Ollama base URL, with the model string stripped of the prefix.
    """
    # Import the canonical constant rather than hard-coding the URL: the
    # engine honours an OLLAMA_BASE_URL environment override, so a literal
    # here would make the test fail whenever that variable is set.
    from textgrad.engine.openai import OLLAMA_BASE_URL

    MODEL_STRING = "test-model-string"

    # Initialise the engine via the "ollama-" prefixed engine name.
    engine = get_engine("ollama-" + MODEL_STRING)

    assert engine
    assert engine.model_string == MODEL_STRING
    assert engine.base_url == OLLAMA_BASE_URL
8 changes: 8 additions & 0 deletions textgrad/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
elif engine_name in ["command-r-plus", "command-r", "command", "command-light"]:
from .cohere import ChatCohere
return ChatCohere(model_string=engine_name, **kwargs)
elif engine_name.startswith("ollama"):
from .openai import ChatOpenAI, OLLAMA_BASE_URL
model_string = engine_name.replace("ollama-", "")
return ChatOpenAI(
model_string=model_string,
base_url=OLLAMA_BASE_URL,
**kwargs
)
elif "vllm" in engine_name:
from .vllm import ChatVLLM
engine_name = engine_name.replace("vllm-", "")
Expand Down
29 changes: 24 additions & 5 deletions textgrad/engine/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes

# Base URL of a local Ollama server's OpenAI-compatible endpoint.
# An OLLAMA_BASE_URL environment variable overrides the default; `or`
# (rather than getenv's default argument) keeps the original behaviour
# where an empty-string variable also falls back to the default.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") or 'http://localhost:11434/v1'

class ChatOpenAI(EngineLM, CachedEngine):
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
Expand All @@ -26,23 +32,36 @@ def __init__(
model_string: str="gpt-3.5-turbo-0613",
system_prompt: str=DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool=False,
base_url: str=None,
**kwargs):
"""
:param model_string:
:param system_prompt:
:param base_url: Used to support Ollama
"""
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")

super().__init__(cache_path=cache_path)

self.system_prompt = system_prompt
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
self.base_url = base_url

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
)
if not base_url:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)
elif base_url and base_url == OLLAMA_BASE_URL:
self.client = OpenAI(
base_url=base_url,
api_key="ollama"
)
else:
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")

self.model_string = model_string
self.is_multimodal = is_multimodal

Expand Down

0 comments on commit c94a6a7

Please sign in to comment.