Skip to content

Commit

Permalink
refactor: tests for the helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
thompsonson committed Feb 29, 2024
1 parent 12e6de6 commit 3787883
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 57 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/setup-poetry.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Setup Poetry Environment and check formatting and style
name: Setup Poetry, check formatting and style, and run the tests

on: [push, pull_request]

Expand Down Expand Up @@ -33,3 +33,6 @@ jobs:

- name: Run Style Checks
run: make check_style

- name: Run the tests
run: make test
85 changes: 34 additions & 51 deletions det/embeddings/adapters.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,53 @@
"""
Embeddings Adapter Module
# embeddings/adapters.py
This module provides a collection of adapter classes that implement the EmbeddingGeneratorInterface
to generate text embeddings using various external services or models, with an emphasis on
extensibility and caching capabilities to enhance performance.
The adapters act as intermediaries between the abstract interface and concrete embedding generation
mechanisms, allowing for flexible integration of different embedding sources without modifying
client code. This modular approach simplifies the process of adding new embedding generators or
switching between them.
Example usage:
from embeddings.adapters import OpenAIEmbeddingGeneratorAdapter
# Initialize the adapter with a specific model
embedding_adapter = OpenAIEmbeddingGeneratorAdapter(model="text-embedding-ada-002")
# Generate embeddings for a list of texts
embeddings = embedding_adapter.generate_embeddings(["Sample text for embedding."])
print(embeddings)
# The adapter utilizes an internal cache to store and retrieve embeddings, reducing
# the number of external requests and speeding up the process for repeated inputs.
Classes:
- EmbeddingGeneratorInterface: An abstract base class defining the contract for
embedding generators.
- OpenAIEmbeddingGeneratorAdapter: An adapter for the OpenAI Embedding Generator,
with caching support.
- AnotherEmbeddingGenerator: A template for additional embedding generator implementations.
The module demonstrates the use of the Adapter Design Pattern to facilitate the interaction between
high-level operations and external libraries or APIs. It ensures that changes in the embedding
generation services have minimal impact on the application code, promoting maintainability
and scalability.
"""


from abc import ABC, abstractmethod
from typing import List

from det.embeddings.cache import EmbeddingsCache
from det.embeddings.generator import OpenAIEmbeddingGenerator
from det.embeddings.generator import (
EmbeddingGeneratorInterface,
OpenAIEmbeddingGenerator,
)


class EmbeddingGeneratorAdapterInterface(ABC):
def __init__(self, model: str):
self.model = model
self.embedding_generator: EmbeddingGeneratorInterface = (
self._create_embedding_generator()
)

class EmbeddingGeneratorInterface(ABC):
@abstractmethod
def generate_embeddings(self, texts):
def _create_embedding_generator(self) -> EmbeddingGeneratorInterface:
pass

@abstractmethod
def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
pass

class OpenAIEmbeddingGeneratorAdapter(EmbeddingGeneratorInterface):

class OpenAIEmbeddingGeneratorAdapter(EmbeddingGeneratorAdapterInterface):
def __init__(self, model="text-embedding-ada-002"):
self.embedding_generator = OpenAIEmbeddingGenerator(model=model)
super().__init__(model)
self.embeddings_cache = EmbeddingsCache(
embeddings_generator=self.embedding_generator
)

def generate_embeddings(self, texts):
def _create_embedding_generator(self) -> EmbeddingGeneratorInterface:
# Ensure the OpenAIEmbeddingGenerator class implements EmbeddingGeneratorInterface
return OpenAIEmbeddingGenerator(model=self.model)

def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
return self.embeddings_cache.generate_embeddings(texts)


class AnotherEmbeddingGenerator(EmbeddingGeneratorInterface):
def generate_embeddings(self, texts):
class AnotherEmbeddingGeneratorAdapter(EmbeddingGeneratorAdapterInterface):
def __init__(self, model):
super().__init__(model)
# Additional setup if necessary

def _create_embedding_generator(self) -> EmbeddingGeneratorInterface:
# Return an instance of another class that implements EmbeddingGeneratorInterface
pass

def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
# Implementation for another source of embeddings
pass
4 changes: 2 additions & 2 deletions det/embeddings/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
import os
import pickle

from det.embeddings.generator import EmbeddingGenerator
from det.embeddings.generator import EmbeddingGeneratorInterface


class EmbeddingsCache:
def __init__(
self,
embeddings_generator: EmbeddingGenerator,
embeddings_generator: EmbeddingGeneratorInterface,
cache_file_path="embeddings_cache.pkl",
):
self.cache_file_path = cache_file_path
Expand Down
6 changes: 3 additions & 3 deletions det/embeddings/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from openai import OpenAI


class EmbeddingGenerator(ABC):
class EmbeddingGeneratorInterface(ABC):
"""
Abstract base class for embedding generators.
"""
Expand All @@ -61,7 +61,7 @@ def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
raise NotImplementedError("This method should be implemented by subclasses.")


class OpenAIEmbeddingGenerator(EmbeddingGenerator):
class OpenAIEmbeddingGenerator(EmbeddingGeneratorInterface):
"""
Embedding generator using OpenAI's API.
"""
Expand All @@ -86,7 +86,7 @@ def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
return [embedding.embedding for embedding in response.data]


class AnotherEmbeddingGenerator(EmbeddingGenerator):
class AnotherEmbeddingGenerator(EmbeddingGeneratorInterface):
"""
Placeholder class for another embedding generator.
"""
Expand Down
5 changes: 5 additions & 0 deletions det/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def _get_client_class(module_path: str, class_name: str):


def get_llm_client(llm_provider: str, llm_model: str):
if not llm_provider:
raise ValueError(f"Could not import class for {llm_provider}")
if not llm_model:
raise ValueError(f"Model is not given: {llm_model}")

module_path = f"det.llm.llm_{llm_provider.lower()}"
class_name = f"{llm_provider}Client"

Expand Down
79 changes: 79 additions & 0 deletions tests/unit/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Generated by CodiumAI
from det.helpers import (
_get_client_class,
get_llm_client,
get_embedding_generator_adapter,
)
from det.llm.llm_openai import OpenAIClient
from det.embeddings.adapters import AnotherEmbeddingGeneratorAdapter

import pytest


class Test_GetClientClass:
    """Unit tests for the provider-lookup helpers imported from ``det.helpers``.

    Despite the class name, the cases cover all three imported helpers:
    ``_get_client_class``, ``get_llm_client`` and
    ``get_embedding_generator_adapter``.
    """

    def test_valid_module_path(self):
        """A valid module path and class name resolve to the class object."""
        # Arrange
        module_path = "det.llm.llm_openai"
        class_name = "OpenAIClient"

        # Act
        result = _get_client_class(module_path, class_name)

        # Assert
        assert result == OpenAIClient

    def test_invalid_module_path(self):
        """An unknown module path raises ImportError."""
        # Arrange
        module_path = "det.llm.invalid_module"
        class_name = "ProviderClient"

        # Act and Assert
        with pytest.raises(ImportError):
            _get_client_class(module_path, class_name)

    def test_import_and_instantiate_client(self):
        """``get_llm_client`` returns a client of the provider's class.

        The returned client must also expose the requested model name.
        """
        # NOTE(review): constructing OpenAIClient may need API credentials in
        # the environment -- confirm this runs in CI without them.
        llm_provider = "OpenAI"
        llm_model = "gpt-4"

        client = get_llm_client(llm_provider, llm_model)

        assert isinstance(client, OpenAIClient)
        assert client.model == llm_model

    def test_none_parameters(self):
        """Passing None for both provider and model raises ValueError."""
        llm_provider = None
        llm_model = None

        with pytest.raises(ValueError):
            get_llm_client(llm_provider, llm_model)

    def test_missing_provider(self):
        """A missing provider alone is rejected with ValueError."""
        # Isolates the provider guard: the model is valid here.
        with pytest.raises(ValueError):
            get_llm_client(None, "gpt-4")

    def test_missing_model(self):
        """A missing model alone is rejected with ValueError."""
        # Isolates the model guard, which the both-None case never reaches
        # because the provider check fires first.
        with pytest.raises(ValueError):
            get_llm_client("OpenAI", None)

    def test_invalid_parameters(self):
        """An unknown provider name raises ImportError."""
        llm_provider = "invalid"
        llm_model = "invlaid"

        with pytest.raises(ImportError):
            get_llm_client(llm_provider, llm_model)

    def test_returns_instance_of_embedding_generator_adapter_class(self):
        """The adapter helper returns the provider's adapter for the model."""
        # Arrange
        embeddings_provider = "Another"
        embeddings_model = "TestModel"

        # Act
        result = get_embedding_generator_adapter(embeddings_provider, embeddings_model)

        # Assert
        assert isinstance(result, AnotherEmbeddingGeneratorAdapter)
        assert result.model == embeddings_model

0 comments on commit 3787883

Please sign in to comment.