Skip to content

Commit

Permalink
refactor: Refactor bases classes for generation and retrieval (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab authored Oct 31, 2024
1 parent b418c9b commit 2a73121
Show file tree
Hide file tree
Showing 21 changed files with 745 additions and 817 deletions.
3 changes: 0 additions & 3 deletions .env.tpl

This file was deleted.

1,105 changes: 498 additions & 607 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ torchvision = [
{version = ">=0.20.0", markers="extra=='gpu' and extra!='cpu'"},
]
langdetect = ">=1"
openai = "0.28"
openai = "^1.52.2"
google-generativeai = "^0.8.3"

[tool.poetry.extras]
Expand Down
4 changes: 2 additions & 2 deletions src/rago/augmented/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

from rago.augmented.base import AugmentedBase
from rago.augmented.gemini_aug import GeminiAug
from rago.augmented.gemini import GeminiAug
from rago.augmented.hugging_face import HuggingFaceAug
from rago.augmented.openai_aug import OpenAIAug
from rago.augmented.openai import OpenAIAug

__all__ = [
'AugmentedBase',
Expand Down
58 changes: 51 additions & 7 deletions src/rago/augmented/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,73 @@
class AugmentedBase:
"""Define the base structure for Augmented classes."""

api_key: str = ''
model: Optional[Any]
model_name: str = ''
db: Any
k: int = -1
documents: list[str]
k: int = 0
temperature: float = 0.5
prompt_template: str = ''
result_separator = '\n'
output_max_length: int = 500

# default values to be overwritten by the derived classes
default_model_name: str = ''
default_k: int = 0
default_temperature: float = 0.5
default_prompt_template: str = (
'Retrieve {k} entries from the context that better answer the '
'following query:\n```\n{query}\n```\n\ncontext:\n```\n{context}\n```'
)
default_result_separator = '\n'
default_output_max_length: int = 500

@abstractmethod
def __init__(
self,
documents: list[str] = [],
model_name: str = '',
api_key: str = '',
db: DBBase = FaissDB(),
k: int = -1,
k: int = 0,
temperature: float = 0.5,
prompt_template: str = '',
result_separator: str = '\n',
output_max_length: int = 500,
) -> None:
"""Initialize AugmentedBase."""
self.k = k
self.documents = documents
self.db = db
self.api_key = api_key

self.k = k or self.default_k
self.model_name = model_name or self.default_model_name
self.temperature = temperature or self.default_temperature
self.result_separator = (
result_separator or self.default_result_separator
)
self.prompt_template = prompt_template or self.default_prompt_template
self.output_max_length = (
output_max_length or self.default_output_max_length
)

self.model = None

self._validate()
self._setup()

def _validate(self) -> None:
"""Raise an error if the initial parameters are not valid."""
return

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
return

@abstractmethod
def search(
self,
query: str,
documents: Any,
k: int = -1,
k: int = 0,
) -> list[str]:
"""Search an encoded query into vector database."""
...
27 changes: 13 additions & 14 deletions src/rago/augmented/gemini_aug.py → src/rago/augmented/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,29 @@
class GeminiAug(AugmentedBase):
"""GeminiAug class for query augmentation using Gemini API."""

def __init__(
self,
model_name: str = 'gemini-1.5-flash',
k: int = 1,
api_key: str = '',
) -> None:
"""Initialize the GeminiAug class."""
self.model_name = model_name
self.k = k
genai.configure(api_key=api_key)
default_model_name: str = 'gemini-1.5-flash'
default_k: int = 1

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
genai.configure(api_key=self.api_key)

def search(
self, query: str, documents: list[str], k: int = 1
self, query: str, documents: list[str], k: int = 0
) -> list[str]:
"""Augment the query by expanding or rephrasing it using Gemini."""
prompt = f"Retrieval: '{query}'\nContext: {' '.join(documents)}"
k = k or self.k
prompt = self.prompt_template.format(
query=query, context=' '.join(documents), k=k
)

response = genai.GenerativeModel(self.model_name).generate_content(
prompt
)

augmented_query = (
augmented_query = str(
response.text.strip()
if hasattr(response, 'text')
else response[0].text.strip()
)
return [augmented_query] * self.k
return augmented_query.split(self.result_separator)[:k]
34 changes: 11 additions & 23 deletions src/rago/augmented/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,24 @@
from typeguard import typechecked

from rago.augmented.base import AugmentedBase
from rago.db import DBBase, FaissDB


@typechecked
class HuggingFaceAug(AugmentedBase):
"""Class for augmentation with Hugging Face."""

model: Any
k: int = -1
db: DBBase

def __init__(
self,
name: str = 'paraphrase',
db: DBBase = FaissDB(),
k: int = -1,
) -> None:
"""Initialize HuggingFaceAug."""
if name == 'paraphrase':
self.model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
else:
raise Exception(
'The Augmented class name {name} is not supported.'
)

self.db = db
self.k = k

def search(self, query: str, documents: Any, k: int = -1) -> list[str]:
default_model_name = 'paraphrase-MiniLM-L6-v2'
default_k = 2

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
self.model = SentenceTransformer(self.model_name)

def search(self, query: str, documents: Any, k: int = 0) -> list[str]:
"""Search an encoded query into vector database."""
if not self.model:
raise Exception('The model was not created.')

document_encoded = self.model.encode(documents)
query_encoded = self.model.encode([query])
k = k if k > 0 else self.k
Expand Down
48 changes: 48 additions & 0 deletions src/rago/augmented/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""OpenAIAug class for query augmentation using OpenAI API."""

from __future__ import annotations

from typing import cast

import openai

from typeguard import typechecked

from rago.augmented.base import AugmentedBase


@typechecked
class OpenAIAug(AugmentedBase):
"""OpenAIAug class for query augmentation using OpenAI API."""

default_model_name = 'gpt-3.5-turbo'
default_k = 2
default_result_separator = '\n'

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
self.model = openai.OpenAI(api_key=self.api_key)

def search(
self, query: str, documents: list[str], k: int = 0
) -> list[str]:
"""Augment the query by expanding or rephrasing it using OpenAI."""
k = k or self.k
prompt = self.prompt_template.format(
query=query, context=' '.join(documents), k=k
)

if not self.model:
raise Exception('The model was not created.')

response = self.model.chat.completions.create(
model=self.model_name,
messages=[{'role': 'user', 'content': prompt}],
max_tokens=self.output_max_length,
temperature=self.temperature,
)

augmented_query = cast(
str, response.choices[0].message.content.strip()
)
return augmented_query.split(self.result_separator)[:k]
35 changes: 0 additions & 35 deletions src/rago/augmented/openai_aug.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/rago/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from __future__ import annotations

from rago.generation.base import GenerationBase
from rago.generation.gemini_ai import GeminiAIGen
from rago.generation.gemini import GeminiGen
from rago.generation.hugging_face import HuggingFaceGen
from rago.generation.llama import LlamaGen
from rago.generation.openai_gpt import OpenAIGPTGen
from rago.generation.openai import OpenAIGen

__all__ = [
'GenerationBase',
'HuggingFaceGen',
'LlamaGen',
'OpenAIGPTGen',
'GeminiAIGen',
'OpenAIGen',
'GeminiGen',
]
38 changes: 34 additions & 4 deletions src/rago/generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,54 @@ class GenerationBase:
device_name: str = 'cpu'
device: torch.device
model: Any
model_name: str = ''
tokenizer: Any
temperature: float = 0.5
output_max_length: int = 500
prompt_template: str = (
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
)

# default parameters that can be overwritten by the derived class
default_device_name: str = 'cpu'
default_model_name: str = ''
default_temperature: float = 0.5
default_output_max_length: int = 500
default_prompt_template: str = (
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
)

@abstractmethod
def __init__(
self,
model_name: str = '',
api_key: str = '',
temperature: float = 0.5,
prompt_template: str = '',
output_max_length: int = 500,
device: str = 'auto',
) -> None:
"""Initialize GenerationBase.
"""Initialize Generation class.
Parameters
----------
model_name : str
The name of the model to use.
api_key : str
temperature : float
prompt_template: str
output_max_length : int
Maximum length of the generated output.
device: str (default=auto)
"""
self.api_key = api_key
self.model_name = model_name
self.output_max_length = output_max_length
self.temperature = temperature
self.model_name = model_name or self.default_model_name
self.output_max_length = (
output_max_length or self.default_output_max_length
)
self.temperature = temperature or self.default_temperature

self.prompt_template = prompt_template or self.default_prompt_template

if self.device_name not in ['cpu', 'cuda', 'auto']:
raise Exception(
Expand All @@ -60,6 +79,17 @@ def __init__(
)
self.device = torch.device(self.device_name)

self._validate()
self._setup()

def _validate(self) -> None:
"""Raise an error if the initial parameters are not valid."""
return

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
return

@abstractmethod
def generate(
self,
Expand Down
Loading

0 comments on commit 2a73121

Please sign in to comment.