-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add PDF Retrieval and Text Splitter (#20)
- Loading branch information
Showing
18 changed files
with
1,396 additions
and
368 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Base classes for retrieval.""" | ||
|
||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from typing import Iterable | ||
|
||
from typeguard import typechecked | ||
|
||
from rago.retrieval.base import RetrievalBase | ||
from rago.retrieval.tools.pdf import extract_text_from_pdf, is_pdf | ||
|
||
|
||
@typechecked
class FilePathRet(RetrievalBase):
    """Retrieval class whose source is a path on the local filesystem."""

    def _validate(self) -> None:
        """
        Validate that the source is a path that exists.

        Raises
        ------
        TypeError
            If ``self.source`` is neither a ``str`` nor a ``Path``.
        FileNotFoundError
            If the path given by ``self.source`` does not exist.
        """
        # Both exceptions derive from ``Exception``, so callers that catch
        # the generic ``Exception`` keep working.
        if not isinstance(self.source, (str, Path)):
            raise TypeError('Argument source should be a string or a Path.')

        if not Path(self.source).exists():
            raise FileNotFoundError("File doesn't exist.")
|
||
|
||
@typechecked
class PDFPathRet(FilePathRet):
    """Retrieval class that reads its content from a PDF file."""

    def _validate(self) -> None:
        """
        Validate that the source exists and is a PDF.

        Runs the parent validation first, then checks the file with
        ``is_pdf``.

        Raises
        ------
        ValueError
            If ``is_pdf`` reports that the source is not a PDF.
        """
        super()._validate()
        # ``ValueError`` derives from ``Exception``, so callers that catch
        # the generic ``Exception`` keep working.
        if not is_pdf(self.source):
            raise ValueError('Given file is not a PDF.')

    def get(self, query: str = '') -> Iterable[str]:
        """
        Extract the text of the PDF and split it into chunks.

        Parameters
        ----------
        query : str
            Not used by this implementation.

        Returns
        -------
        Iterable[str]
            The chunks produced by ``self.splitter.split`` for the
            extracted text.
        """
        # NOTE(review): ``self.splitter`` is not set in this class;
        # presumably it is provided by ``RetrievalBase`` - confirm.
        text = extract_text_from_pdf(self.source)
        return self.splitter.split(text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Package with the text splitter classes.""" | ||
|
||
from rago.retrieval.text_splitter.base import TextSplitterBase | ||
from rago.retrieval.text_splitter.langchain import LangChainTextSplitter | ||
|
||
__all__ = [ | ||
'TextSplitterBase', | ||
'LangChainTextSplitter', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""The base classes for text splitter.""" | ||
|
||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import Any, Iterable | ||
|
||
|
||
class TextSplitterBase:
    """
    The base text splitter class.

    Subclasses customise behaviour by overriding the ``default_*`` class
    attributes and the ``_validate``, ``_setup`` and ``split`` hooks.
    """

    chunk_size: int = 500
    chunk_overlap: int = 100
    splitter_name: str = ''
    splitter: Any = None

    # defaults

    default_chunk_size: int = 500
    default_chunk_overlap: int = 100
    default_splitter_name: str = ''
    default_splitter: Any = None

    def __init__(
        self,
        splitter_name: str = '',
        chunk_size: int = 500,
        chunk_overlap: int = 100,
    ) -> None:
        """
        Initialize the text splitter class.

        Parameters
        ----------
        splitter_name : str
            Name of the splitter; an empty value falls back to
            ``default_splitter_name``.
        chunk_size : int
            Size of each chunk; a falsy value (e.g. ``0``) falls back to
            ``default_chunk_size``.
        chunk_overlap : int
            Overlap between chunks. ``0`` is kept as given (no overlap);
            only ``None`` falls back to ``default_chunk_overlap``.
        """
        self.chunk_size = chunk_size or self.default_chunk_size
        # ``0`` is a legitimate overlap, so do not use a truthiness test
        # here: ``chunk_overlap or default`` would silently replace it.
        self.chunk_overlap = (
            self.default_chunk_overlap
            if chunk_overlap is None
            else chunk_overlap
        )
        self.splitter_name = splitter_name or self.default_splitter_name

        # Validation runs before setup, so ``_setup`` only sees parameters
        # that ``_validate`` accepted.
        self._validate()
        self._setup()

    def _validate(self) -> None:
        """Validate the initial parameters; the base class accepts all."""
        return

    def _setup(self) -> None:
        """Set up the object; the base class has nothing to prepare."""
        return

    # NOTE: this class does not derive from ``abc.ABC``, so the decorator
    # below is not enforced and the base class remains instantiable.
    @abstractmethod
    def split(self, text: str) -> Iterable[str]:
        """
        Split a text into chunks.

        The base implementation returns an empty list.
        """
        return []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Support langchain text splitter.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import List, cast | ||
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
|
||
from rago.retrieval.text_splitter.base import TextSplitterBase | ||
|
||
|
||
class LangChainTextSplitter(TextSplitterBase):
    """Text splitter backed by a LangChain text splitter class."""

    default_splitter_name: str = 'RecursiveCharacterTextSplitter'

    def _validate(self) -> None:
        """
        Validate that ``splitter_name`` names a supported splitter.

        Raises
        ------
        ValueError
            If ``self.splitter_name`` is not one of the supported names.
        """
        valid_splitter_names = ['RecursiveCharacterTextSplitter']

        # ``ValueError`` derives from ``Exception``, so callers that catch
        # the generic ``Exception`` keep working.
        if self.splitter_name not in valid_splitter_names:
            raise ValueError(
                f'The given splitter_name {self.splitter_name} is not valid. '
                f'Valid options are {valid_splitter_names}'
            )

    def _setup(self) -> None:
        """Store the splitter class that matches ``splitter_name``."""
        # The class itself is stored, not an instance; ``split`` builds the
        # instance with the current chunk settings.
        if self.splitter_name == 'RecursiveCharacterTextSplitter':
            self.splitter = RecursiveCharacterTextSplitter

    def split(self, text: str) -> list[str]:
        """
        Split text into smaller chunks for processing.

        A new splitter instance is created on every call from the current
        ``chunk_size`` and ``chunk_overlap`` values.
        """
        text_splitter = self.splitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            is_separator_regex=True,
        )
        return cast(List[str], text_splitter.split_text(text))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tools to support the retrieval classes.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""PDF tools.""" | ||
|
||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
|
||
from pypdf import PdfReader | ||
|
||
|
||
def is_pdf(file_path: str | Path) -> bool: | ||
""" | ||
Check if a file is a PDF by reading its header. | ||
Parameters | ||
---------- | ||
file_path : str | ||
Path to the file to be checked. | ||
Returns | ||
------- | ||
bool | ||
True if the file is a PDF, False otherwise. | ||
""" | ||
try: | ||
with open(file_path, 'rb') as file: | ||
header = file.read(4) | ||
return header == b'%PDF' | ||
except IOError: | ||
return False | ||
|
||
|
||
def extract_text_from_pdf(file_path: str) -> str:
    """
    Extract the text of every page of a PDF file using pypdf.

    Pages whose extracted text is empty are skipped; the remaining page
    texts are joined with a single space. The output is meant to match
    what PyPDFLoader returns.
    """
    page_texts = (
        page.extract_text() for page in PdfReader(file_path).pages
    )
    return ' '.join(chunk for chunk in page_texts if chunk)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.