Skip to content

Commit

Permalink
add support to openai embedding model for passing connection params i…
Browse files Browse the repository at this point in the history
…n constructor
  • Loading branch information
krohling committed Dec 15, 2023
1 parent 3723f8d commit 77ea942
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions bondai/models/openai/openai_embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict
from bondai.models import EmbeddingModel
from .openai_models import ModelConfig, OpenAIModelType, OpenAIModelNames
from .openai_wrapper import create_embedding, count_tokens, get_max_tokens
Expand All @@ -7,9 +7,12 @@

class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(
self, model: OpenAIModelNames = OpenAIModelNames.TEXT_EMBEDDING_ADA_002
self,
model: OpenAIModelNames = OpenAIModelNames.TEXT_EMBEDDING_ADA_002,
connection_params: Dict = None,
):
self._model = model.value
self._connection_params = connection_params
if ModelConfig[self._model]["model_type"] != OpenAIModelType.EMBEDDING:
raise Exception(f"Model {model} is not an embedding model.")

Expand All @@ -22,8 +25,9 @@ def max_tokens(self) -> int:
return get_max_tokens(self._model)

def create_embedding(self, prompt: str) -> List[float] | List[List[float]]:
connection_params = self._connection_params or EMBEDDINGS_CONNECTION_PARAMS
return create_embedding(
prompt, self._model, connection_params=EMBEDDINGS_CONNECTION_PARAMS
prompt, self._model, connection_params=connection_params
)

def count_tokens(self, prompt: str) -> int:
Expand Down

0 comments on commit 77ea942

Please sign in to comment.