diff --git a/pyproject.toml b/pyproject.toml index 897ab02..f7c2d4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,9 +101,13 @@ exclude_lines = [ [tool.ruff] line-length = 88 indent-width = 4 +extend-exclude = ["tests", "others"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "D"] +ignore = ["D104"] +exclude = ["__about__.py"] + [tool.ruff.format] quote-style = "double" diff --git a/src/config.ini b/src/config.ini index 452ac04..54990bf 100644 --- a/src/config.ini +++ b/src/config.ini @@ -25,6 +25,14 @@ embedding_model : hkunlp/instructor-xl store_path : ${data:data_path}/vectordb allow_reset : True +[deeplake] +collection_name : arxiv +# embedding_type : sentence-transformers +# embedding_model : "all-mpnet-base-v2" +embedding_type : instructor-embedding +embedding_model : hkunlp/instructor-xl +store_path : ${data:data_path}/vectordb + [text_splitter] chunk_size : 5000 chunk_overlap : 400 @@ -51,4 +59,4 @@ table_as_html : True data_path : ${root:root_path}/data [root] -root_path : /home/ubuntu/volume_2k/Capstone_5 \ No newline at end of file +root_path : /home/ubuntu/CapStone/Capstone_5 diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py index 7a9d249..eab107f 100644 --- a/src/grag/components/embedding.py +++ b/src/grag/components/embedding.py @@ -1,3 +1,9 @@ +"""Class for embedding. + +This module provides: +- Embedding +""" + from langchain_community.embeddings import HuggingFaceInstructEmbeddings from langchain_community.embeddings.sentence_transformer import ( SentenceTransformerEmbeddings, @@ -6,6 +12,7 @@ class Embedding: """A class for vector embeddings. 
+ Supports: huggingface sentence transformers -> model_type = 'sentence-transformers' huggingface instructor embeddings -> model_type = 'instructor-embedding' @@ -17,6 +24,7 @@ class Embedding: """ def __init__(self, embedding_type: str, embedding_model: str): + """Initialize the embedding with embedding_type and embedding_model.""" self.embedding_type = embedding_type self.embedding_model = embedding_model match self.embedding_type: diff --git a/src/grag/components/llm.py b/src/grag/components/llm.py index 20db968..6e7296c 100644 --- a/src/grag/components/llm.py +++ b/src/grag/components/llm.py @@ -1,3 +1,5 @@ +"""Class for LLM.""" + import os from pathlib import Path @@ -50,6 +52,7 @@ def __init__( quantization=llm_conf["quantization"], pipeline=llm_conf["pipeline"], ): + """Initialize the LLM class using the given parameters.""" self.base_dir = Path(base_dir) self._model_name = model_name self.quantization = quantization diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py index 18ed752..97684dd 100644 --- a/src/grag/components/multivec_retriever.py +++ b/src/grag/components/multivec_retriever.py @@ -1,3 +1,9 @@ +"""Class for retriever. + +This module provides: +- Retriever +""" + import asyncio import uuid from typing import List @@ -13,9 +19,11 @@ class Retriever: - """A class for multi vector retriever, it connects to a vector database and a local file store. - It is used to return most similar chunks from a vector store but has the additional funcationality - to return a linked document, chunk, etc. + """A class for multi vector retriever. + + It connects to a vector database and a local file store. + It is used to return most similar chunks from a vector store but has the additional functionality to return a + linked document, chunk, etc. 
Attributes: store_path: Path to the local file store @@ -36,7 +44,9 @@ def __init__( namespace: str = multivec_retriever_conf["namespace"], top_k=1, ): - """Args: + """Initialize the Retriever. + + Args: store_path: Path to the local file store, defaults to argument from config file id_key: A key prefix for identifying documents, defaults to argument from config file namespace: A namespace for producing unique id, defaults to argument from congig file @@ -58,6 +68,7 @@ def __init__( def id_gen(self, doc: Document) -> str: """Takes a document and returns a unique id (uuid5) using the namespace and document source. + This ensures that a single document always gets the same unique id. Args: @@ -81,7 +92,9 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]: return [self.id_gen(doc) for doc in docs] def split_docs(self, docs: List[Document]) -> List[Document]: - """Takes a list of documents and splits them into smaller chunks using TextSplitter from compoenents.text_splitter + """Takes a list of documents and splits them into smaller chunks. + + Using TextSplitter from components.text_splitter Also adds the unique parent document id into metadata Args: @@ -101,8 +114,7 @@ def split_docs(self, docs: List[Document]) -> List[Document]: return chunks def add_docs(self, docs: List[Document]): - """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database - and adds the parent document into the file store. + """Adds given documents into the vector database also adds the parent document into the file store. Args: docs: List of langchain_core.documents.Document @@ -117,8 +129,7 @@ def add_docs(self, docs: List[Document]): self.retriever.docstore.mset(list(zip(doc_ids, docs))) async def aadd_docs(self, docs: List[Document]): - """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database - and adds the parent document into the file store. 
+ """Adds given documents into the vector database also adds the parent document into the file store. Args: docs: List of langchain_core.documents.Document diff --git a/src/grag/components/parse_pdf.py b/src/grag/components/parse_pdf.py index d918c93..dc30f8a 100644 --- a/src/grag/components/parse_pdf.py +++ b/src/grag/components/parse_pdf.py @@ -1,3 +1,9 @@ +"""Classes for parsing files. + +This module provides: +- ParsePDF +""" + from langchain_core.documents import Document from unstructured.partition.pdf import partition_pdf @@ -32,7 +38,7 @@ def __init__( add_captions_to_blocks=parser_conf["add_captions_to_blocks"], table_as_html=parser_conf["table_as_html"], ): - # Instantialize instance variables with parameters + """Initialize instance variables with parameters.""" self.strategy = strategy if extract_images: # by default always extract Table self.extract_image_block_types = [ @@ -72,7 +78,8 @@ def partition(self, path: str): def classify(self, partitions): """Classifies the partitioned elements into Text, Tables, and Images list in a dictionary. - Add captions for each element (if available). + + Also adds captions for each element (if available). Parameters: partitions (list): The list of partitioned elements from the PDF document. @@ -117,6 +124,8 @@ def classify(self, partitions): return classified_elements def text_concat(self, elements) -> str: + """Context aware concatenates all elements into a single string.""" + full_text = "" for current_element, next_element in zip(elements, elements[1:]): curr_type = current_element.category next_type = next_element.category diff --git a/src/grag/components/prompt.py b/src/grag/components/prompt.py index ecefa71..4364c06 100644 --- a/src/grag/components/prompt.py +++ b/src/grag/components/prompt.py @@ -1,3 +1,10 @@ +"""Classes for prompts. 
+ +This module provides: +- Prompt - for generic prompts +- FewShotPrompt - for few-shot prompts +""" + import json from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -13,6 +20,20 @@ class Prompt(BaseModel): + """A class for generic prompts. + + Attributes: + name (str): The prompt name (Optional, defaults to "custom_prompt") + llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") + task (str): The task (Optional, defaults to QA) + source (str): The source of the prompt (Optional, defaults to "NoSource") + doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") + language (str): The language of the prompt (Optional, defaults to "en") + filepath (str): The filepath of the prompt (Optional) + input_keys (List[str]): The input keys for the prompt + template (str): The template for the prompt + """ + name: str = Field(default="custom_prompt") llm_type: str = Field(default="None") task: str = Field(default="QA") @@ -27,6 +48,7 @@ class Prompt(BaseModel): @field_validator("input_keys") @classmethod def validate_input_keys(cls, v) -> List[str]: + """Validate the input_keys field.""" if v is None or v == []: raise ValueError("input_keys cannot be empty") return v @@ -34,6 +56,7 @@ def validate_input_keys(cls, v) -> List[str]: @field_validator("doc_chain") @classmethod def validate_doc_chain(cls, v: str) -> str: + """Validate the doc_chain field.""" if v not in SUPPORTED_DOC_CHAINS: raise ValueError( f"The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}" @@ -43,6 +66,7 @@ def validate_doc_chain(cls, v: str) -> str: @field_validator("task") @classmethod def validate_task(cls, v: str) -> str: + """Validate the task field.""" if v not in SUPPORTED_TASKS: raise ValueError( f"The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}" @@ -53,6 +77,7 @@ def validate_task(cls, v: str) -> str: # def load_template(self): # 
self.prompt = ChatPromptTemplate.from_template(self.template) def __init__(self, **kwargs): + """Initialize the prompt.""" super().__init__(**kwargs) self.prompt = PromptTemplate( input_variables=self.input_keys, template=self.template @@ -61,6 +86,7 @@ def __init__(self, **kwargs): def save( self, filepath: Union[Path, str, None], overwrite=False ) -> Union[None, ValueError]: + """Saves the prompt class into a json file.""" dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True) if filepath is None: filepath = f"{self.name}.json" @@ -74,17 +100,37 @@ def save( @classmethod def load(cls, filepath: Union[Path, str]): + """Loads a json file and returns a Prompt class.""" with open(f"{filepath}", "r") as f: prompt_json = json.load(f) _prompt = cls(**prompt_json) _prompt.filepath = str(filepath) return _prompt - def format(self, **kwargs): + def format(self, **kwargs) -> str: + """Formats the prompt with provided keys and returns a string.""" return self.prompt.format(**kwargs) class FewShotPrompt(Prompt): + """A class for few-shot prompts. 
+ + Attributes: + name (str): The prompt name (Optional, defaults to "custom_prompt") (Parent Class) + llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") (Parent Class) + task (str): The task (Optional, defaults to QA) (Parent Class) + source (str): The source of the prompt (Optional, defaults to "NoSource") (Parent Class) + doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") (Parent Class) + language (str): The language of the prompt (Optional, defaults to "en") (Parent Class) + filepath (str): The filepath of the prompt (Optional) (Parent Class) + input_keys (List[str]): The input keys for the prompt (Parent Class) + output_keys (List[str]): The output keys for the prompt + prefix (str): The template prefix for the prompt + suffix (str): The template suffix for the prompt + example_template (str): The template for formatting the examples + examples (List[Dict[str, Any]]): The list of examples, each example is a dictionary with respective keys + """ + output_keys: List[str] examples: List[Dict[str, Any]] prefix: str @@ -95,6 +141,7 @@ class FewShotPrompt(Prompt): ) def __init__(self, **kwargs): + """Initialize the prompt.""" super().__init__(**kwargs) eg_formatter = PromptTemplate( input_vars=self.input_keys + self.output_keys, @@ -111,6 +158,7 @@ def __init__(self, **kwargs): @field_validator("output_keys") @classmethod def validate_output_keys(cls, v) -> List[str]: + """Validate the output_keys field.""" if v is None or v == []: raise ValueError("output_keys cannot be empty") return v @@ -118,6 +166,7 @@ def validate_output_keys(cls, v) -> List[str]: @field_validator("examples") @classmethod def validate_examples(cls, v) -> List[Dict[str, Any]]: + """Validate the examples field.""" if v is None or v == []: raise ValueError("examples cannot be empty") for eg in v: diff --git a/src/grag/components/text_splitter.py b/src/grag/components/text_splitter.py index cff3c7c..d04c9a5 100644 --- 
a/src/grag/components/text_splitter.py +++ b/src/grag/components/text_splitter.py @@ -1,3 +1,9 @@ +"""Class for splitting/chunking text. + +This module provides: +- TextSplitter +""" + from langchain.text_splitter import RecursiveCharacterTextSplitter from .utils import get_config @@ -7,10 +13,23 @@ # %% class TextSplitter: - def __init__(self): + """Class for recursively chunking text; it prioritizes double newlines, then single newlines, and so on. + + Attributes: + chunk_size: maximum size of chunk + chunk_overlap: chunk overlap size + """ + + def __init__( + self, + chunk_size: int = text_splitter_conf["chunk_size"], + chunk_overlap: int = text_splitter_conf["chunk_overlap"], + ): + """Initialize TextSplitter.""" self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=int(text_splitter_conf["chunk_size"]), - chunk_overlap=int(text_splitter_conf["chunk_overlap"]), + chunk_size=int(chunk_size), + chunk_overlap=int(chunk_overlap), length_function=len, is_separator_regex=False, ) + """Initialize TextSplitter using chunk_size and chunk_overlap""" diff --git a/src/grag/components/utils.py b/src/grag/components/utils.py index cb64258..2e34dc9 100644 --- a/src/grag/components/utils.py +++ b/src/grag/components/utils.py @@ -1,3 +1,12 @@ +"""Utils functions. + +This module provides: +- stuff_docs: concats langchain documents into string +- load_prompt: loads json prompt to langchain prompt +- find_config_path: finds the path of the 'config.ini' file by traversing up the directory tree from the current path. +- get_config: retrieves and parses the configuration settings from the 'config.ini' file. +""" + import json import os import textwrap @@ -10,7 +19,9 @@ def stuff_docs(docs: List[Document]) -> str: - """Args: + r"""Concatenates langchain documents into a string using '\n\n' separator. 
+ + Args: docs: List of langchain_core.documents.Document Returns: @@ -20,8 +31,7 @@ def stuff_docs(docs: List[Document]) -> str: def reformat_text_with_line_breaks(input_text, max_width=110): - """Reformat the given text to ensure each line does not exceed a specific width, - preserving existing line breaks. + """Reformat the given text to ensure each line does not exceed a specific width, preserving existing line breaks. Args: input_text (str): The text to be reformatted. @@ -62,7 +72,7 @@ def display_llm_output_and_sources(response_from_llm): def load_prompt(json_file: str | os.PathLike, return_input_vars=False): - """Loads a prompt template from json file and returns a langchain ChatPromptTemplate + """Loads a prompt template from json file and returns a langchain ChatPromptTemplate. Args: json_file: path to the prompt template json file. diff --git a/src/grag/rag/basic_rag.py b/src/grag/rag/basic_rag.py index a99ecdd..be055ba 100644 --- a/src/grag/rag/basic_rag.py +++ b/src/grag/rag/basic_rag.py @@ -1,3 +1,9 @@ +"""Class for Basic RAG. + +This module provides: +- BasicRAG +""" + import json from typing import List, Union @@ -13,6 +19,17 @@ class BasicRAG: + """Class for Basic RAG. 
+ + Attributes: + model_name (str): Name of the llm model + doc_chain (str): Name of the document chain, ("stuff", "refine"), defaults to "stuff" + task (str): Name of task, defaults to "QA" + llm_kwargs (dict): Keyword arguments for LLM class + retriever_kwargs (dict): Keyword arguments for Retriever class + custom_prompt (Prompt): Prompt, defaults to None + """ + def __init__( self, model_name=None, @@ -20,8 +37,11 @@ def __init__( task="QA", llm_kwargs=None, retriever_kwargs=None, - custom_prompt: Union[Prompt, FewShotPrompt, None] = None, + custom_prompt: Union[ + Prompt, FewShotPrompt, List[Prompt, FewShotPrompt], None + ] = None, ): + """Initialize BasicRAG.""" if retriever_kwargs is None: self.retriever = Retriever() else: @@ -54,6 +74,7 @@ def __init__( @property def model_name(self): + """Return the name of the model.""" return self._model_name @model_name.setter @@ -67,6 +88,7 @@ def model_name(self, value): @property def doc_chain(self): + """Returns the doc_chain.""" return self._doc_chain @doc_chain.setter @@ -86,6 +108,7 @@ def doc_chain(self, value): @property def task(self): + """Returns the task.""" return self._task @task.setter @@ -99,6 +122,7 @@ def task(self, value): self.prompt_matcher() def prompt_matcher(self): + """Matches relevant prompt using model, task and doc_chain.""" matcher_path = self.prompt_path.joinpath("matcher.json") with open(f"{matcher_path}", "r") as f: matcher_dict = json.load(f) @@ -122,7 +146,9 @@ def prompt_matcher(self): @staticmethod def stuff_docs(docs: List[Document]) -> str: - """Args: + r"""Concatenates docs into a string separated by '\n\n'. 
+ + Args: docs: List of langchain_core.documents.Document Returns: @@ -132,6 +158,8 @@ def stuff_docs(docs: List[Document]) -> str: @staticmethod def output_parser(call_func): + """Decorator to format llm output.""" + def output_parser_wrapper(*args, **kwargs): response, sources = call_func(*args, **kwargs) if conf["llm"]["std_out"] == "False": @@ -146,6 +174,7 @@ def output_parser_wrapper(*args, **kwargs): @output_parser def stuff_call(self, query: str): + """Call function for stuff chain.""" retrieved_docs = self.retriever.get_chunk(query) context = self.stuff_docs(retrieved_docs) prompt = self.main_prompt.format(context=context, question=query) @@ -155,6 +184,7 @@ def stuff_call(self, query: str): @output_parser def refine_call(self, query: str): + """Call function for refine chain.""" retrieved_docs = self.retriever.get_chunk(query) sources = [doc.metadata["source"] for doc in retrieved_docs] responses = [] @@ -176,6 +206,7 @@ def refine_call(self, query: str): return responses, sources def __call__(self, query: str): + """Call function for the class.""" if self.doc_chain == "stuff": return self.stuff_call(query) elif self.doc_chain == "refine":