-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
12e6de6
commit 3787883
Showing
6 changed files
with
127 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,53 @@ | ||
""" | ||
Embeddings Adapter Module | ||
# embeddings/adapters.py | ||
This module provides a collection of adapter classes that implement the EmbeddingGeneratorInterface | ||
to generate text embeddings using various external services or models, with an emphasis on | ||
extensibility and caching capabilities to enhance performance. | ||
The adapters act as intermediaries between the abstract interface and concrete embedding generation | ||
mechanisms, allowing for flexible integration of different embedding sources without modifying | ||
client code. This modular approach simplifies the process of adding new embedding generators or | ||
switching between them. | ||
Example usage: | ||
from embeddings.adapters import OpenAIEmbeddingGeneratorAdapter | ||
# Initialize the adapter with a specific model | ||
embedding_adapter = OpenAIEmbeddingGeneratorAdapter(model="text-embedding-ada-002") | ||
# Generate embeddings for a list of texts | ||
embeddings = embedding_adapter.generate_embeddings(["Sample text for embedding."]) | ||
print(embeddings) | ||
# The adapter utilizes an internal cache to store and retrieve embeddings, reducing | ||
# the number of external requests and speeding up the process for repeated inputs. | ||
Classes: | ||
- EmbeddingGeneratorInterface: An abstract base class defining the contract for | ||
embedding generators. | ||
- OpenAIEmbeddingGeneratorAdapter: An adapter for the OpenAI Embedding Generator, | ||
with caching support. | ||
- AnotherEmbeddingGenerator: A template for additional embedding generator implementations. | ||
The module demonstrates the use of the Adapter Design Pattern to facilitate the interaction between | ||
high-level operations and external libraries or APIs. It ensures that changes in the embedding | ||
generation services have minimal impact on the application code, promoting maintainability | ||
and scalability. | ||
""" | ||
|
||
|
||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from det.embeddings.cache import EmbeddingsCache | ||
from det.embeddings.generator import OpenAIEmbeddingGenerator | ||
from det.embeddings.generator import ( | ||
EmbeddingGeneratorInterface, | ||
OpenAIEmbeddingGenerator, | ||
) | ||
|
||
|
||
class EmbeddingGeneratorAdapterInterface(ABC): | ||
def __init__(self, model: str): | ||
self.model = model | ||
self.embedding_generator: EmbeddingGeneratorInterface = ( | ||
self._create_embedding_generator() | ||
) | ||
|
||
class EmbeddingGeneratorInterface(ABC): | ||
@abstractmethod | ||
def generate_embeddings(self, texts): | ||
def _create_embedding_generator(self) -> EmbeddingGeneratorInterface: | ||
pass | ||
|
||
@abstractmethod | ||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]: | ||
pass | ||
|
||
class OpenAIEmbeddingGeneratorAdapter(EmbeddingGeneratorInterface): | ||
|
||
class OpenAIEmbeddingGeneratorAdapter(EmbeddingGeneratorAdapterInterface): | ||
def __init__(self, model="text-embedding-ada-002"): | ||
self.embedding_generator = OpenAIEmbeddingGenerator(model=model) | ||
super().__init__(model) | ||
self.embeddings_cache = EmbeddingsCache( | ||
embeddings_generator=self.embedding_generator | ||
) | ||
|
||
def generate_embeddings(self, texts): | ||
def _create_embedding_generator(self) -> EmbeddingGeneratorInterface: | ||
# Ensure the OpenAIEmbeddingGenerator class implements EmbeddingGeneratorInterface | ||
return OpenAIEmbeddingGenerator(model=self.model) | ||
|
||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]: | ||
return self.embeddings_cache.generate_embeddings(texts) | ||
|
||
|
||
class AnotherEmbeddingGenerator(EmbeddingGeneratorInterface): | ||
def generate_embeddings(self, texts): | ||
class AnotherEmbeddingGeneratorAdapter(EmbeddingGeneratorAdapterInterface): | ||
def __init__(self, model): | ||
super().__init__(model) | ||
# Additional setup if necessary | ||
|
||
def _create_embedding_generator(self) -> EmbeddingGeneratorInterface: | ||
# Return an instance of another class that implements EmbeddingGeneratorInterface | ||
pass | ||
|
||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]: | ||
# Implementation for another source of embeddings | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Generated by CodiumAI | ||
from det.helpers import ( | ||
_get_client_class, | ||
get_llm_client, | ||
get_embedding_generator_adapter, | ||
) | ||
from det.llm.llm_openai import OpenAIClient | ||
from det.embeddings.adapters import AnotherEmbeddingGeneratorAdapter | ||
|
||
import pytest | ||
|
||
|
||
class Test_GetClientClass:
    """Tests for the helper factories imported from ``det.helpers``.

    Exercises ``_get_client_class``, ``get_llm_client`` and
    ``get_embedding_generator_adapter``.
    """

    def test_valid_module_path(self):
        """A valid module path and class name yield the ``OpenAIClient`` class."""
        result = _get_client_class("det.llm.llm_openai", "OpenAIClient")

        assert result == OpenAIClient

    def test_invalid_module_path(self):
        """A module path that cannot be imported is expected to raise ``ImportError``."""
        with pytest.raises(ImportError):
            _get_client_class("det.llm.invalid_module", "ProviderClient")

    def test_import_and_instantiate_client(self):
        """``get_llm_client`` returns an ``OpenAIClient`` carrying the requested model."""
        llm_model = "gpt-4"

        client = get_llm_client("OpenAI", llm_model)

        # The returned object must be of the imported client class ...
        assert isinstance(client, OpenAIClient)
        # ... and must have been built with the model that was asked for.
        assert client.model == llm_model

    def test_none_parameters(self):
        """Passing ``None`` for both provider and model is expected to raise ``ValueError``."""
        with pytest.raises(ValueError):
            get_llm_client(None, None)

    def test_invalid_parameters(self):
        """An unknown provider/model pair is expected to raise ``ImportError``."""
        with pytest.raises(ImportError):
            get_llm_client("invalid", "invlaid")

    def test_returns_instance_of_embedding_generator_adapter_class(self):
        """The "Another" provider yields an ``AnotherEmbeddingGeneratorAdapter`` with the given model."""
        embeddings_model = "TestModel"

        result = get_embedding_generator_adapter("Another", embeddings_model)

        assert isinstance(result, AnotherEmbeddingGeneratorAdapter)
        assert result.model == embeddings_model