Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(aug): Add support for SpaCy augmented class #23

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,631 changes: 1,812 additions & 819 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ google-generativeai = "^0.8.3"
pypdf = ">=5"
langchain = ">=0.3.7"
langchain-community = ">=0.3.7"
spacy = ">=3"

[tool.poetry.extras]
cpu = ["torch", "torchvision"]
Expand Down Expand Up @@ -72,6 +73,8 @@ makim = "1.19.0"
# 'PosixPath' object has no attribute 'endswith'
virtualenv = "<=20.25.1"
python-dotenv = ">=1.0"
# note: Version 3.7.1 requires spaCy >=3.7.2,<3.8.0
en-core-web-md = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1-py3-none-any.whl"}


[[tool.poetry.source]]
Expand Down
2 changes: 2 additions & 0 deletions src/rago/augmented/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from rago.augmented.base import AugmentedBase
from rago.augmented.openai import OpenAIAug
from rago.augmented.sentence_transformer import SentenceTransformerAug
from rago.augmented.spacy import SpacyAug

__all__ = [
'AugmentedBase',
'OpenAIAug',
'SentenceTransformerAug',
'SpacyAug',
]
7 changes: 3 additions & 4 deletions src/rago/augmented/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ class OpenAIAug(AugmentedBase):
"""Class for augmentation with OpenAI embeddings."""

default_model_name = 'text-embedding-3-small'
default_top_k = 2
default_top_k = 3

def _setup(self) -> None:
"""Set up the object with initial parameters."""
if not self.api_key:
raise ValueError('API key for OpenAI is required.')
openai.api_key = self.api_key
self.model_name = self.model_name or self.default_model_name
self.model = openai.OpenAI(api_key=self.api_key)

def get_embedding(
Expand Down Expand Up @@ -58,13 +57,13 @@ def search(
self.db.embed(document_encoded)
scores, indices = self.db.search(query_encoded, top_k=top_k)

retrieved_docs = [documents[i] for i in indices]

self.logs['indices'] = indices
self.logs['scores'] = scores
self.logs['search_params'] = {
'query_encoded': query_encoded,
'top_k': top_k,
}

retrieved_docs = [documents[i] for i in indices if i >= 0]

return retrieved_docs
2 changes: 1 addition & 1 deletion src/rago/augmented/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SentenceTransformerAug(AugmentedBase):
"""Class for augmentation with Hugging Face."""

default_model_name = 'paraphrase-MiniLM-L12-v2'
default_top_k = 2
default_top_k = 3

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
Expand Down
66 changes: 66 additions & 0 deletions src/rago/augmented/spacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Classes for augmentation with SpaCy embeddings."""

from __future__ import annotations

from typing import TYPE_CHECKING, List, cast

import numpy as np
import spacy

from typeguard import typechecked

from rago.augmented.base import AugmentedBase

if TYPE_CHECKING:
import numpy.typing as npt

from torch import Tensor


@typechecked
class SpacyAug(AugmentedBase):
    """Class for augmentation with SpaCy embeddings.

    Texts are embedded with the document vector (``doc.vector``) produced
    by a spaCy pipeline, and the similarity search itself is delegated to
    the vector database stored in ``self.db``.
    """

    default_model_name = 'en_core_web_md'
    default_top_k = 3

    def _setup(self) -> None:
        """Set up the object by loading the spaCy pipeline."""
        self.model = spacy.load(self.model_name)

    def get_embedding(
        self, content: List[str]
    ) -> npt.NDArray[np.float64] | Tensor:
        """Retrieve the embeddings for the given texts using SpaCy.

        Parameters
        ----------
        content : List[str]
            Texts to embed.

        Returns
        -------
        numpy.ndarray
            Array built from one ``doc.vector`` per text, in input order.
        """
        model = cast(spacy.language.Language, self.model)
        # ``Language.pipe`` streams the texts through the pipeline in
        # batches and yields the docs in input order, which avoids the
        # per-call overhead of invoking the pipeline once per text.
        return np.array([doc.vector for doc in model.pipe(content)])

    def search(
        self, query: str, documents: list[str], top_k: int = 0
    ) -> list[str]:
        """Search an encoded query into vector database.

        Parameters
        ----------
        query : str
            Text to search for.
        documents : list[str]
            Candidate texts; they are embedded and stored in ``self.db``
            on every call.
        top_k : int
            Number of results to request. A falsy value falls back to
            ``self.top_k``, then ``default_top_k``, then ``1``.

        Returns
        -------
        list[str]
            The documents selected by the vector database.

        Raises
        ------
        Exception
            If ``self.db`` is missing or falsy.
        """
        if not hasattr(self, 'db') or not self.db:
            raise Exception('Vector database (db) is not initialized.')

        # Encode the documents and query
        document_encoded = self.get_embedding(documents)
        query_encoded = self.get_embedding([query])
        top_k = top_k or self.top_k or self.default_top_k or 1

        self.db.embed(document_encoded)
        scores, indices = self.db.search(query_encoded, top_k=top_k)

        # Keep the raw search output available for later inspection.
        self.logs['indices'] = indices
        self.logs['scores'] = scores
        self.logs['search_params'] = {
            'query_encoded': query_encoded,
            'top_k': top_k,
        }

        # Negative indices are skipped; presumably the database uses them
        # as "no result" padding when fewer than ``top_k`` documents are
        # available -- TODO confirm against the db implementation.
        return [documents[i] for i in indices if i >= 0]
Binary file added tests/data/pdf/2407.13797.pdf
Binary file not shown.
Binary file added tests/data/pdf/2407.20116.pdf
Binary file not shown.
40 changes: 39 additions & 1 deletion tests/test_retrieval_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,52 @@

from pathlib import Path

import pytest

from rago.augmented import SpacyAug
from rago.retrieval import PDFPathRet

PDF_DATA_PATH = Path(__file__).parent / 'data' / 'pdf'


def test_retrieval_pdf_extraction() -> None:
def test_retrieval_pdf_extraction_basic() -> None:
    """Check that the sample PDF yields at least 100 text chunks."""
    retriever = PDFPathRet(PDF_DATA_PATH / '1.pdf')
    extracted = retriever.get()

    assert len(extracted) >= 100


@pytest.mark.parametrize(
    'pdf_path,expected',
    [
        ('2407.13797.pdf', ''),
        ('2407.20116.pdf', ''),
    ],
)
def test_retrieval_pdfs_extraction_aug_spacy(
    pdf_path: str, expected: str
) -> None:
    """Test PDF chunk extraction followed by a SpaCy augmented search."""
    # NOTE: ``expected`` is supplied by the parametrization but is not
    # used by any assertion below.
    retriever = PDFPathRet(PDF_DATA_PATH / pdf_path)
    chunks = retriever.get()

    min_total_chunks = 100  # arbitrary number
    max_chunk_size = retriever.splitter.chunk_size

    assert len(chunks) >= min_total_chunks
    # every chunk must stay below the splitter's configured chunk size
    assert all(len(chunk) < max_chunk_size for chunk in chunks)

    query = 'What are the key barriers to implementing vitamin D?'
    aug_top_k = 3

    spacy_aug = SpacyAug(top_k=aug_top_k)
    found = spacy_aug.search(query, documents=chunks)

    assert found
    assert len(found) == aug_top_k

    lowered = [text.lower() for text in found]
    assert all('vitamin' in text for text in lowered)
    assert all('vitamin d' in text for text in lowered)

    # the retrieved chunks must all be distinct
    assert len(set(found)) == aug_top_k
30 changes: 30 additions & 0 deletions tests/test_spacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tests for Rago package using Spacy."""

from rago.augmented import SpacyAug
from rago.retrieval import StringRet


def test_aug_spacy(animals_data: list[str]) -> None:
    """Test the retrieval and augmentation steps using SpacyAug."""
    logs = {
        'augmented': {},
    }
    top_k = 3
    query = 'Is there any animal larger than a dinosaur?'

    retriever = StringRet(animals_data)
    spacy_aug = SpacyAug(
        top_k=top_k,
        logs=logs['augmented'],
    )

    retrieved = retriever.get()
    augmented = spacy_aug.search(query, retrieved)

    assert spacy_aug.top_k == top_k
    assert len(augmented) == top_k
    assert any('blue whale' in text.lower() for text in augmented)

    # check if logs have been used
    assert logs['augmented']
Loading