diff --git a/bondai/models/openai/openai_embedding_model.py b/bondai/models/openai/openai_embedding_model.py
index 8416de5..0b061e5 100644
--- a/bondai/models/openai/openai_embedding_model.py
+++ b/bondai/models/openai/openai_embedding_model.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict
 from bondai.models import EmbeddingModel
 from .openai_models import ModelConfig, OpenAIModelType, OpenAIModelNames
 from .openai_wrapper import create_embedding, count_tokens, get_max_tokens
@@ -7,9 +7,12 @@ class OpenAIEmbeddingModel(EmbeddingModel):
     def __init__(
-        self, model: OpenAIModelNames = OpenAIModelNames.TEXT_EMBEDDING_ADA_002
+        self,
+        model: OpenAIModelNames = OpenAIModelNames.TEXT_EMBEDDING_ADA_002,
+        connection_params: Dict = None,
     ):
         self._model = model.value
+        self._connection_params = connection_params
         if ModelConfig[self._model]["model_type"] != OpenAIModelType.EMBEDDING:
             raise Exception(f"Model {model} is not an embedding model.")
 
 
 
@@ -22,8 +25,9 @@ def max_tokens(self) -> int:
         return get_max_tokens(self._model)
 
     def create_embedding(self, prompt: str) -> List[float] | List[List[float]]:
+        connection_params = self._connection_params or EMBEDDINGS_CONNECTION_PARAMS
         return create_embedding(
-            prompt, self._model, connection_params=EMBEDDINGS_CONNECTION_PARAMS
+            prompt, self._model, connection_params=connection_params
         )
 
     def count_tokens(self, prompt: str) -> int: