Skip to content

Commit

Permalink
Add support for generative-octoai
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkkul committed May 6, 2024
1 parent bb2625a commit 3238206
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
16 changes: 16 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,22 @@ def test_config_with_vectorizer_and_properties(
Configure.Generative.mistral(temperature=0.5, max_tokens=100, model="model"),
{"generative-mistral": {"temperature": 0.5, "maxTokens": 100, "model": "model"}},
),
(
Configure.Generative.octoai(
model="mistral-7b-instruct",
temperature=0.5,
base_url="https://text.octoai.run",
max_tokens=123,
),
{
"generative-openai": {
"model": "mistral-7b-instruct",
"maxTokens": 123,
"temperature": 0.5,
"baseURL": "https://text.octoai.run",
}
},
),
(
Configure.Generative.openai(
model="gpt-4",
Expand Down
29 changes: 26 additions & 3 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,13 @@ class GenerativeSearches(str, Enum):
Weaviate module backed by AWS Bedrock generative models.
"""

OPENAI = "generative-openai"
COHERE = "generative-cohere"
PALM = "generative-palm"
AWS = "generative-aws"
ANYSCALE = "generative-anyscale"
COHERE = "generative-cohere"
MISTRAL = "generative-mistral"
OCTOAI = "generative-octoai"
OPENAI = "generative-openai"
PALM = "generative-palm"


class Rerankers(str, Enum):
Expand Down Expand Up @@ -368,6 +369,16 @@ class _GenerativeAnyscale(_GenerativeConfigCreate):
model: Optional[str]


class _GenerativeOctoai(_GenerativeConfigCreate):
generative: GenerativeSearches = Field(
default=GenerativeSearches.OCTOAI, frozen=True, exclude=True
)
baseURL: Optional[str]
temperature: Optional[float]
maxTokens: Optional[int]
model: Optional[str]


class _GenerativeMistral(_GenerativeConfigCreate):
generative: GenerativeSearches = Field(
default=GenerativeSearches.MISTRAL, frozen=True, exclude=True
Expand Down Expand Up @@ -490,6 +501,18 @@ def mistral(
) -> _GenerativeConfigCreate:
return _GenerativeMistral(model=model, temperature=temperature, maxTokens=max_tokens)

@staticmethod
def octoai(
*,
base_url: Optional[str] = None,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
) -> _GenerativeConfigCreate:
return _GenerativeOctoai(
baseURL=base_url, maxTokens=max_tokens, model=model, temperature=temperature
)

@staticmethod
def openai(
model: Optional[str] = None,
Expand Down

0 comments on commit 3238206

Please sign in to comment.