Skip to content

Commit

Permalink
feat: Add PDF Retrieval and Text Splitter (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab authored Nov 14, 2024
1 parent ae1a7c8 commit 6e80756
Show file tree
Hide file tree
Showing 18 changed files with 1,396 additions and 368 deletions.
1,403 changes: 1,091 additions & 312 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ torchvision = [
langdetect = ">=1"
openai = "^1.52.2"
google-generativeai = "^0.8.3"
pypdf = ">=5"
langchain = ">=0.3.7"
langchain-community = ">=0.3.7"

[tool.poetry.extras]
cpu = ["torch", "torchvision"]
Expand Down
2 changes: 2 additions & 0 deletions src/rago/retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from __future__ import annotations

from rago.retrieval.base import RetrievalBase, StringRet
from rago.retrieval.file import PDFPathRet

__all__ = [
'RetrievalBase',
'StringRet',
'PDFPathRet',
]
50 changes: 40 additions & 10 deletions src/rago/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,60 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, cast
from typing import Any, Iterable, cast

from typeguard import typechecked

from rago.retrieval.text_splitter import (
LangChainTextSplitter,
TextSplitterBase,
)


@typechecked
class RetrievalBase:
"""Base Retrieval class."""

content: Any

def __init__(self, sources: Any) -> None:
self.sources = sources
source: Any
splitter: TextSplitterBase

def __init__(
self,
source: Any,
splitter: TextSplitterBase = LangChainTextSplitter(
'RecursiveCharacterTextSplitter'
),
) -> None:
"""Initialize the Retrieval class."""
self.source = source
self.splitter = splitter
self._validate()
self._setup()

def _validate(self) -> None:
"""Validate if the source is valid, otherwise raises an exception."""
return None

def _setup(self) -> None:
"""Set up the object with the giving initial parameters."""
return None

@abstractmethod
def get(self, query: str = '') -> Any:
"""Get the data from the sources."""
...
def get(self, query: str = '') -> Iterable[str]:
"""Get the data from the source."""
return []


@typechecked
class StringRet(RetrievalBase):
"""String Retrieval class."""
"""
String Retrieval class.
This is a very generic class that assumes that the input (source) is
already a list of strings.
"""

def get(self, query: str = '') -> list[str]:
def get(self, query: str = '') -> Iterable[str]:
"""Get the data from the sources."""
return cast(list[str], self.sources)
return cast(list[str], self.source)
41 changes: 41 additions & 0 deletions src/rago/retrieval/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Base classes for retrieval."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable

from typeguard import typechecked

from rago.retrieval.base import RetrievalBase
from rago.retrieval.tools.pdf import extract_text_from_pdf, is_pdf


@typechecked
class FilePathRet(RetrievalBase):
"""File Retrieval class."""

def _validate(self) -> None:
"""Validate if the source is valid, otherwise raises an exception."""
if not isinstance(self.source, (str, Path)):
raise Exception('Argument source should be an string or a Path.')

source_path = Path(self.source)
if not source_path.exists():
raise Exception("File doesn't exist.")


@typechecked
class PDFPathRet(FilePathRet):
"""PDFPathRet Retrieval class."""

def _validate(self) -> None:
"""Validate if the source is valid, otherwise raises an exception."""
super()._validate()
if not is_pdf(self.source):
raise Exception('Given file is not a PDF.')

def get(self, query: str = '') -> Iterable[str]:
"""Get the data from the source."""
text = extract_text_from_pdf(self.source)
return self.splitter.split(text)
9 changes: 9 additions & 0 deletions src/rago/retrieval/text_splitter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Package for classes about text splitter."""

from rago.retrieval.text_splitter.base import TextSplitterBase
from rago.retrieval.text_splitter.langchain import LangChainTextSplitter

__all__ = [
'TextSplitterBase',
'LangChainTextSplitter',
]
49 changes: 49 additions & 0 deletions src/rago/retrieval/text_splitter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""The base classes for text splitter."""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Iterable


class TextSplitterBase:
"""The base text splitter class."""

chunk_size: int = 500
chunk_overlap: int = 100
splitter_name: str = ''
splitter: Any = None

# defaults

default_chunk_size: int = 500
default_chunk_overlap: int = 100
default_splitter_name: str = ''
default_splitter: Any = None

def __init__(
self,
splitter_name: str = '',
chunk_size: int = 500,
chunk_overlap: int = 100,
) -> None:
"""Initialize the text splitter class."""
self.chunk_size = chunk_size or self.default_chunk_size
self.chunk_overlap = chunk_overlap or self.default_chunk_overlap
self.splitter_name = splitter_name or self.default_splitter_name

self._validate()
self._setup()

def _validate(self) -> None:
"""Validate if the initial parameters are valid."""
return

def _setup(self) -> None:
"""Set up the object according to the given parameters."""
return

@abstractmethod
def split(self, text: str) -> Iterable[str]:
"""Split a text into chunks."""
return []
40 changes: 40 additions & 0 deletions src/rago/retrieval/text_splitter/langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Support langchain text splitter."""

from __future__ import annotations

from typing import List, cast

from langchain.text_splitter import RecursiveCharacterTextSplitter

from rago.retrieval.text_splitter.base import TextSplitterBase


class LangChainTextSplitter(TextSplitterBase):
"""LangChain Text Splitter class."""

default_splitter_name: str = 'RecursiveCharacterTextSplitter'

def _validate(self) -> None:
"""Validate if the initial parameters are valid."""
valid_splitter_names = ['RecursiveCharacterTextSplitter']

if self.splitter_name not in valid_splitter_names:
raise Exception(
f'The given splitter_name {self.splitter_name} is not valid. '
f'Valid options are {valid_splitter_names}'
)

def _setup(self) -> None:
"""Set up the object according to the given parameters."""
if self.splitter_name == 'RecursiveCharacterTextSplitter':
self.splitter = RecursiveCharacterTextSplitter

def split(self, text: str) -> list[str]:
"""Split text into smaller chunks for processing."""
text_splitter = self.splitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
length_function=len,
is_separator_regex=True,
)
return cast(List[str], text_splitter.split_text(text))
1 change: 1 addition & 0 deletions src/rago/retrieval/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tools for support retrieval classes."""
44 changes: 44 additions & 0 deletions src/rago/retrieval/tools/pdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""PDF tools."""

from __future__ import annotations

from pathlib import Path

from pypdf import PdfReader


def is_pdf(file_path: str | Path) -> bool:
"""
Check if a file is a PDF by reading its header.
Parameters
----------
file_path : str
Path to the file to be checked.
Returns
-------
bool
True if the file is a PDF, False otherwise.
"""
try:
with open(file_path, 'rb') as file:
header = file.read(4)
return header == b'%PDF'
except IOError:
return False


def extract_text_from_pdf(file_path: str) -> str:
"""
Extract text from a PDF file using pypdf.
The result is the same as the one returned by PyPDFLoader.
"""
reader = PdfReader(file_path)
pages = []
for page in reader.pages:
text = page.extract_text()
if text:
pages.append(text)
return ' '.join(pages)
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ def env() -> dict[str, str]:
dotenv_file = Path(__file__).parent / '.env'
load_dotenv(dotenv_file)
return dotenv_values(dotenv_file)


@pytest.fixture
def animals_data() -> list[str]:
"""Fixture for loading the animals dataset."""
data_path = Path(__file__).parent / 'data' / 'animals.txt'
with open(data_path) as f:
data = [line.strip() for line in f.readlines() if line.strip()]
return data
Binary file added tests/data/pdf/1.pdf
Binary file not shown.
11 changes: 0 additions & 11 deletions tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import os

from pathlib import Path

import pytest

from rago import Rago
Expand All @@ -12,15 +10,6 @@
from rago.retrieval import StringRet


@pytest.fixture
def animals_data() -> list[str]:
"""Fixture for loading the animals dataset."""
data_path = Path(__file__).parent / 'data' / 'animals.txt'
with open(data_path) as f:
data = [line.strip() for line in f.readlines() if line.strip()]
return data


@pytest.fixture
def api_key(env) -> str:
"""Fixture for Gemini API key from environment."""
Expand Down
13 changes: 0 additions & 13 deletions tests/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
"""Tests for rago package."""

from pathlib import Path

import pytest

from rago import Rago
from rago.augmented import SentenceTransformerAug
from rago.generation import HuggingFaceGen
from rago.retrieval import StringRet


@pytest.fixture
def animals_data() -> list[str]:
"""Create a fixture with animals dataset."""
data_path = Path(__file__).parent / 'data' / 'animals.txt'
with open(data_path) as f:
data = [line.strip() for line in f.readlines() if line.strip()]
return data


def test_aug_sentence_transformer(animals_data: list[str]) -> None:
"""Test RAG with hugging face."""
query = 'Is there any animals larger than a dinosaur?'
Expand Down
11 changes: 0 additions & 11 deletions tests/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import os

from pathlib import Path

import pytest

from rago import Rago
Expand All @@ -12,15 +10,6 @@
from rago.retrieval import StringRet


@pytest.fixture
def animals_data() -> list[str]:
"""Create a fixture with animals dataset."""
data_path = Path(__file__).parent / 'data' / 'animals.txt'
with open(data_path) as f:
data = [line.strip() for line in f.readlines() if line.strip()]
return data


@pytest.mark.skip_on_ci
def test_llama(env, animals_data: list[str], device: str = 'auto') -> None:
"""Test RAG with hugging face."""
Expand Down
11 changes: 0 additions & 11 deletions tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import os

from pathlib import Path

import pytest

from rago import Rago
Expand All @@ -12,15 +10,6 @@
from rago.retrieval import StringRet


@pytest.fixture
def animals_data() -> list[str]:
"""Fixture for loading the animals dataset."""
data_path = Path(__file__).parent / 'data' / 'animals.txt'
with open(data_path) as f:
data = [line.strip() for line in f.readlines() if line.strip()]
return data


@pytest.fixture
def api_key(env) -> str:
"""Fixture for OpenAI API key from environment."""
Expand Down
Loading

0 comments on commit 6e80756

Please sign in to comment.