From 32382068aad46c3cd7d8beac0fea1cb37a698d20 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Mon, 6 May 2024 09:38:41 -0700 Subject: [PATCH] Add support for generative-octoai --- test/collection/test_config.py | 16 ++++++++++++++ weaviate/collections/classes/config.py | 29 +++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/test/collection/test_config.py b/test/collection/test_config.py index 3c951283b..c9bd28ff5 100644 --- a/test/collection/test_config.py +++ b/test/collection/test_config.py @@ -592,6 +592,22 @@ def test_config_with_vectorizer_and_properties( Configure.Generative.mistral(temperature=0.5, max_tokens=100, model="model"), {"generative-mistral": {"temperature": 0.5, "maxTokens": 100, "model": "model"}}, ), + ( + Configure.Generative.octoai( + model="mistral-7b-instruct", + temperature=0.5, + base_url="https://text.octoai.run", + max_tokens=123, + ), + { + "generative-openai": { + "model": "mistral-7b-instruct", + "maxTokens": 123, + "temperature": 0.5, + "baseURL": "https://text.octoai.run", + } + }, + ), ( Configure.Generative.openai( model="gpt-4", diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 995c05c0d..a98b5ca47 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -160,12 +160,13 @@ class GenerativeSearches(str, Enum): Weaviate module backed by AWS Bedrock generative models. """ - OPENAI = "generative-openai" - COHERE = "generative-cohere" - PALM = "generative-palm" AWS = "generative-aws" ANYSCALE = "generative-anyscale" + COHERE = "generative-cohere" MISTRAL = "generative-mistral" + OCTOAI = "generative-octoai" + OPENAI = "generative-openai" + PALM = "generative-palm" class Rerankers(str, Enum): @@ -368,6 +369,16 @@ class _GenerativeAnyscale(_GenerativeConfigCreate): model: Optional[str] +class _GenerativeOctoai(_GenerativeConfigCreate): + generative: GenerativeSearches = Field( + default=GenerativeSearches.OCTOAI, frozen=True, exclude=True + ) + baseURL: Optional[str] + temperature: Optional[float] + maxTokens: Optional[int] + model: Optional[str] + + class _GenerativeMistral(_GenerativeConfigCreate): generative: GenerativeSearches = Field( default=GenerativeSearches.MISTRAL, frozen=True, exclude=True @@ -490,6 +501,18 @@ def mistral( ) -> _GenerativeConfigCreate: return _GenerativeMistral(model=model, temperature=temperature, maxTokens=max_tokens) + @staticmethod + def octoai( + *, + base_url: Optional[str] = None, + max_tokens: Optional[int] = None, + model: Optional[str] = None, + temperature: Optional[float] = None, + ) -> _GenerativeConfigCreate: + return _GenerativeOctoai( + baseURL=base_url, maxTokens=max_tokens, model=model, temperature=temperature + ) + @staticmethod def openai( model: Optional[str] = None,