From 0559d19c8ef0ff3e53625c8d7d25a0eb5c64f7e9 Mon Sep 17 00:00:00 2001 From: Sanchit Vijay Date: Tue, 19 Mar 2024 19:50:02 -0400 Subject: [PATCH 1/4] Update ruff_commit.yml --- .github/workflows/ruff_commit.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ruff_commit.yml b/.github/workflows/ruff_commit.yml index fe64b9d..2ffff92 100644 --- a/.github/workflows/ruff_commit.yml +++ b/.github/workflows/ruff_commit.yml @@ -8,9 +8,8 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - run: pip install ruff - - run: | - ruff check src/ - ruff fix src/ + - run: ruff check src/ + - run: ruff format src/ - uses: stefanzweifel/git-auto-commit-action@v4 with: commit_message: 'style fixes by ruff' From 80bcc6891d259424d63224060056b06ceb483f39 Mon Sep 17 00:00:00 2001 From: Sanchit Vijay Date: Tue, 19 Mar 2024 20:45:49 -0400 Subject: [PATCH 2/4] Update ruff_commit.yml --- .github/workflows/ruff_commit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff_commit.yml b/.github/workflows/ruff_commit.yml index 2ffff92..3a5e8cc 100644 --- a/.github/workflows/ruff_commit.yml +++ b/.github/workflows/ruff_commit.yml @@ -3,7 +3,7 @@ on: push jobs: lint: - runs-on: ubuntu_latest + runs-on: self-hosted steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 From d749618f171572d90ad0c73d3397306288e2aac0 Mon Sep 17 00:00:00 2001 From: Sanchit Vijay Date: Tue, 19 Mar 2024 20:48:39 -0400 Subject: [PATCH 3/4] Update ruff_commit.yml --- .github/workflows/ruff_commit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff_commit.yml b/.github/workflows/ruff_commit.yml index 3a5e8cc..2294c6d 100644 --- a/.github/workflows/ruff_commit.yml +++ b/.github/workflows/ruff_commit.yml @@ -8,7 +8,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - run: pip install ruff - - run: ruff check src/ + # - run: ruff check src/ - run: ruff format src/ - uses: stefanzweifel/git-auto-commit-action@v4 with: From b378f56cd23fe065a535f51ef7bf99c20620343d Mon Sep 17 00:00:00 2001 From: sanchitvj Date: Wed, 20 Mar 2024 00:50:42 +0000 Subject: [PATCH 4/4] style fixes by ruff --- src/grag/components/chroma_client.py | 47 +++++--- src/grag/components/embedding.py | 16 ++- src/grag/components/llm.py | 104 +++++++++-------- src/grag/components/multivec_retriever.py | 31 +++-- src/grag/components/parse_pdf.py | 109 +++++++++--------- src/grag/components/prompt.py | 61 +++++----- src/grag/components/text_splitter.py | 12 +- src/grag/components/utils.py | 28 ++--- src/grag/prompts/__init__.py | 7 +- src/grag/rag/basic_rag.py | 96 ++++++++------- src/tests/components/embedding_test.py | 34 ++++-- src/tests/components/llm_test.py | 28 +++-- .../components/multivec_retriever_test.py | 20 ++-- src/tests/components/parse_pdf_test.py | 22 ++-- src/tests/components/prompt_test.py | 19 ++- src/tests/rag/basic_rag_test.py | 8 +- 16 files changed, 354 insertions(+), 288 deletions(-) diff --git a/src/grag/components/chroma_client.py b/src/grag/components/chroma_client.py index 72b0db7..1d816bd 100644 --- a/src/grag/components/chroma_client.py +++ b/src/grag/components/chroma_client.py @@ -10,7 +10,7 @@ from grag.components.embedding import Embedding from grag.components.utils import get_config -chroma_conf = get_config()['chroma'] +chroma_conf = get_config()["chroma"] class ChromaClient: @@ -37,12 +37,14 @@ class ChromaClient: LangChain wrapper for Chroma collection """ - def __init__(self, - host=chroma_conf['host'], - port=chroma_conf['port'], - collection_name=chroma_conf['collection_name'], - embedding_type=chroma_conf['embedding_type'], - embedding_model=chroma_conf['embedding_model']): + def __init__( + self, + host=chroma_conf["host"], + port=chroma_conf["port"], + collection_name=chroma_conf["collection_name"], + embedding_type=chroma_conf["embedding_type"], + embedding_model=chroma_conf["embedding_model"], + ): """Args: host: IP Address of hosted Chroma Vectorstore, defaults to argument from config file port: port address of hosted Chroma Vectorstore, defaults to argument from config file @@ -56,14 +58,19 @@ def __init__(self, self.embedding_type: str = embedding_type self.embedding_model: str = embedding_model - self.embedding_function = Embedding(embedding_model=self.embedding_model, - embedding_type=self.embedding_type).embedding_function + self.embedding_function = Embedding( + embedding_model=self.embedding_model, embedding_type=self.embedding_type + ).embedding_function self.chroma_client = chromadb.HttpClient(host=self.host, port=self.port) - self.collection = self.chroma_client.get_or_create_collection(name=self.collection_name) - self.langchain_chroma = Chroma(client=self.chroma_client, - collection_name=self.collection_name, - embedding_function=self.embedding_function, ) + self.collection = self.chroma_client.get_or_create_collection( + name=self.collection_name + ) + self.langchain_chroma = Chroma( + client=self.chroma_client, + collection_name=self.collection_name, + embedding_function=self.embedding_function, + ) self.allowed_metadata_types = (str, int, float, bool) def test_connection(self, verbose=True): @@ -78,9 +85,9 @@ def test_connection(self, verbose=True): response = self.chroma_client.heartbeat() if verbose: if response: - print(f'Connection to {self.host}/{self.port} is alive..') + print(f"Connection to {self.host}/{self.port} is alive..") else: - print(f'Connection to {self.host}/{self.port} is not alive !!') + print(f"Connection to {self.host}/{self.port} is not alive !!") return response async def aadd_docs(self, docs: List[Document], verbose=True): @@ -100,7 +107,11 @@ async def aadd_docs(self, docs: List[Document], verbose=True): # else: # await asyncio.gather(*tasks) if verbose: - for doc in atqdm(docs, desc=f'Adding documents to {self.collection_name}', total=len(docs)): + for doc in atqdm( + docs, + desc=f"Adding documents to {self.collection_name}", + total=len(docs), + ): await self.langchain_chroma.aadd_documents([doc]) else: for doc in docs: @@ -117,7 +128,9 @@ def add_docs(self, docs: List[Document], verbose=True): None """ docs = self._filter_metadata(docs) - for doc in (tqdm(docs, desc=f'Adding to {self.collection_name}:') if verbose else docs): + for doc in ( + tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs + ): _id = self.langchain_chroma.add_documents([doc]) def _filter_metadata(self, docs: List[Document]): diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py index ed483c7..7a9d249 100644 --- a/src/grag/components/embedding.py +++ b/src/grag/components/embedding.py @@ -20,11 +20,15 @@ def __init__(self, embedding_type: str, embedding_model: str): self.embedding_type = embedding_type self.embedding_model = embedding_model match self.embedding_type: - case 'sentence-transformers': - self.embedding_function = SentenceTransformerEmbeddings(model_name=self.embedding_model) - case 'instructor-embedding': - self.embedding_instruction = 'Represent the document for retrival' - self.embedding_function = HuggingFaceInstructEmbeddings(model_name=self.embedding_model) + case "sentence-transformers": + self.embedding_function = SentenceTransformerEmbeddings( + model_name=self.embedding_model + ) + case "instructor-embedding": + self.embedding_instruction = "Represent the document for retrival" + self.embedding_function = HuggingFaceInstructEmbeddings( + model_name=self.embedding_model + ) self.embedding_function.embed_instruction = self.embedding_instruction case _: - raise Exception('embedding_type is invalid') + raise Exception("embedding_type is invalid") diff --git a/src/grag/components/llm.py b/src/grag/components/llm.py index 5192881..20db968 100644 --- a/src/grag/components/llm.py +++ b/src/grag/components/llm.py @@ -16,7 +16,7 @@ from .utils import get_config -llm_conf = get_config()['llm'] +llm_conf = get_config()["llm"] print("CUDA: ", torch.cuda.is_available()) @@ -35,20 +35,21 @@ class LLM: n_gpu_layers (int): Number of GPU layers for CPP. """ - def __init__(self, - model_name=llm_conf["model_name"], - device_map=llm_conf["device_map"], - task=llm_conf["task"], - max_new_tokens=llm_conf["max_new_tokens"], - temperature=llm_conf["temperature"], - n_batch=llm_conf["n_batch_gpu_cpp"], - n_ctx=llm_conf["n_ctx_cpp"], - n_gpu_layers=llm_conf["n_gpu_layers_cpp"], - std_out=llm_conf["std_out"], - base_dir=llm_conf["base_dir"], - quantization=llm_conf["quantization"], - pipeline=llm_conf["pipeline"], - ): + def __init__( + self, + model_name=llm_conf["model_name"], + device_map=llm_conf["device_map"], + task=llm_conf["task"], + max_new_tokens=llm_conf["max_new_tokens"], + temperature=llm_conf["temperature"], + n_batch=llm_conf["n_batch_gpu_cpp"], + n_ctx=llm_conf["n_ctx_cpp"], + n_gpu_layers=llm_conf["n_gpu_layers_cpp"], + std_out=llm_conf["std_out"], + base_dir=llm_conf["base_dir"], + quantization=llm_conf["quantization"], + pipeline=llm_conf["pipeline"], + ): self.base_dir = Path(base_dir) self._model_name = model_name self.quantization = quantization @@ -74,7 +75,8 @@ def model_name(self): def model_path(self): """Sets the model name.""" return str( - self.base_dir / self.model_name / f'ggml-model-{self.quantization}.gguf') + self.base_dir / self.model_name / f"ggml-model-{self.quantization}.gguf" + ) @model_name.setter def model_name(self, value): @@ -92,21 +94,24 @@ def hf_pipeline(self, is_local=False): else: hf_model = self.model_name match self.quantization: - case 'Q8': + case "Q8": quantization_config = BitsAndBytesConfig(load_in_8bit=True) - case 'Q4': + case "Q4": quantization_config = BitsAndBytesConfig(load_in_4bit=True) case _: raise ValueError( - f'{self.quantization} is not a valid quantization. Non-local hf_pipeline takes only Q4 and Q8.') + f"{self.quantization} is not a valid quantization. Non-local hf_pipeline takes only Q4 and Q8." + ) try: # Try to load the model without passing the token tokenizer = AutoTokenizer.from_pretrained(hf_model) - model = AutoModelForCausalLM.from_pretrained(hf_model, - quantization_config=quantization_config, - device_map=self.device_map, - torch_dtype=torch.float16, ) + model = AutoModelForCausalLM.from_pretrained( + hf_model, + quantization_config=quantization_config, + device_map=self.device_map, + torch_dtype=torch.float16, + ) except OSError: # LocalTokenNotFoundError: # If loading fails due to an auth token error, then load the token and retry load_dotenv() @@ -114,24 +119,29 @@ def hf_pipeline(self, is_local=False): if not auth_token: raise ValueError("Authentication token not provided.") tokenizer = AutoTokenizer.from_pretrained(hf_model, token=True) - model = AutoModelForCausalLM.from_pretrained(hf_model, - quantization_config=quantization_config, - device_map=self.device_map, - torch_dtype=torch.float16, - token=True) - - pipe = pipeline(self.task, - model=model, - tokenizer=tokenizer, - torch_dtype=torch.bfloat16, - device_map=self.device_map, - max_new_tokens=self.max_new_tokens, - do_sample=True, - top_k=10, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id - ) - llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': self.temperature}) + model = AutoModelForCausalLM.from_pretrained( + hf_model, + quantization_config=quantization_config, + device_map=self.device_map, + torch_dtype=torch.float16, + token=True, + ) + + pipe = pipeline( + self.task, + model=model, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + device_map=self.device_map, + max_new_tokens=self.max_new_tokens, + do_sample=True, + top_k=10, + num_return_sequences=1, + eos_token_id=tokenizer.eos_token_id, + ) + llm = HuggingFacePipeline( + pipeline=pipe, model_kwargs={"temperature": self.temperature} + ) return llm def llama_cpp(self): @@ -149,11 +159,9 @@ def llama_cpp(self): ) return llm - def load_model(self, - model_name=None, - pipeline=None, - quantization=None, - is_local=None): + def load_model( + self, model_name=None, pipeline=None, quantization=None, is_local=None + ): """Loads the model based on the specified pipeline and model name. Args: @@ -172,7 +180,7 @@ def load_model(self, is_local = False match self.pipeline: - case 'llama_cpp': + case "llama_cpp": return self.llama_cpp() - case 'hf': + case "hf": return self.hf_pipeline(is_local=is_local) diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py index a839917..9946a3a 100644 --- a/src/grag/components/multivec_retriever.py +++ b/src/grag/components/multivec_retriever.py @@ -10,7 +10,7 @@ from grag.components.text_splitter import TextSplitter from grag.components.utils import get_config -multivec_retriever_conf = get_config()['multivec_retriever'] +multivec_retriever_conf = get_config()["multivec_retriever"] class Retriever: @@ -30,11 +30,13 @@ class Retriever: """ - def __init__(self, - store_path: str = multivec_retriever_conf['store_path'], - id_key: str = multivec_retriever_conf['id_key'], - namespace: str = multivec_retriever_conf['namespace'], - top_k=1): + def __init__( + self, + store_path: str = multivec_retriever_conf["store_path"], + id_key: str = multivec_retriever_conf["id_key"], + namespace: str = multivec_retriever_conf["namespace"], + top_k=1, + ): """Args: store_path: Path to the local file store, defaults to argument from config file id_key: A key prefix for identifying documents, defaults to argument from config file @@ -53,7 +55,7 @@ def __init__(self, ) self.splitter = TextSplitter() self.top_k: int = top_k - self.retriever.search_kwargs = {'k': self.top_k} + self.retriever.search_kwargs = {"k": self.top_k} def id_gen(self, doc: Document) -> str: """Takes a document and returns a unique id (uuid5) using the namespace and document source. @@ -65,7 +67,7 @@ def id_gen(self, doc: Document) -> str: Returns: string of hexadecimal uuid """ - return uuid.uuid5(self.namespace, doc.metadata['source']).hex + return uuid.uuid5(self.namespace, doc.metadata["source"]).hex def gen_doc_ids(self, docs: List[Document]) -> List[str]: """Takes a list of documents and produces a list of unique id, refer id_gen method for more details. @@ -144,15 +146,12 @@ def get_chunk(self, query: str, with_score=False, top_k=None): """ if with_score: - return self.client.langchain_chroma.similarity_search_with_relevance_scores( - query=query, - **{'k': top_k} if top_k else self.retriever.search_kwargs + query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs ) else: return self.client.langchain_chroma.similarity_search( - query=query, - **{'k': top_k} if top_k else self.retriever.search_kwargs + query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs ) async def aget_chunk(self, query: str, with_score=False, top_k=None): @@ -169,13 +168,11 @@ async def aget_chunk(self, query: str, with_score=False, top_k=None): """ if with_score: return await self.client.langchain_chroma.asimilarity_search_with_relevance_scores( - query=query, - **{'k': top_k} if top_k else self.retriever.search_kwargs + query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs ) else: return await self.client.langchain_chroma.asimilarity_search( - query=query, - **{'k': top_k} if top_k else self.retriever.search_kwargs + query=query, **{"k": top_k} if top_k else self.retriever.search_kwargs ) def get_doc(self, query: str): diff --git a/src/grag/components/parse_pdf.py b/src/grag/components/parse_pdf.py index 00d0fa0..d918c93 100644 --- a/src/grag/components/parse_pdf.py +++ b/src/grag/components/parse_pdf.py @@ -3,39 +3,42 @@ from .utils import get_config -parser_conf = get_config()['parser'] +parser_conf = get_config()["parser"] class ParsePDF: """Parsing and partitioning PDF documents into Text, Table or Image elements. - + Attributes: single_text_out (bool): Whether to combine all text elements into a single output document. strategy (str): The strategy for PDF partitioning; default is "hi_res" for better accuracy. extract_image_block_types (list): Elements to be extracted as image blocks. infer_table_structure (bool): Whether to extract tables during partitioning. - extract_images (bool): Whether to extract images. + extract_images (bool): Whether to extract images. image_output_dir (str): Directory to save extracted images, if any. add_captions_to_text (bool): Whether to include figure captions in text output. Default is True. add_captions_to_blocks (bool): Whether to add captions to table and image blocks. Default is True. add_caption_first (bool): Whether to place captions before their corresponding image or table in the output. Default is True. """ - def __init__(self, - single_text_out=parser_conf['single_text_out'], - strategy=parser_conf['strategy'], - infer_table_structure=parser_conf['infer_table_structure'], - extract_images=parser_conf['extract_images'], - image_output_dir=parser_conf['image_output_dir'], - add_captions_to_text=parser_conf['add_captions_to_text'], - add_captions_to_blocks=parser_conf['add_captions_to_blocks'], - table_as_html=parser_conf['table_as_html'] - - ): + def __init__( + self, + single_text_out=parser_conf["single_text_out"], + strategy=parser_conf["strategy"], + infer_table_structure=parser_conf["infer_table_structure"], + extract_images=parser_conf["extract_images"], + image_output_dir=parser_conf["image_output_dir"], + add_captions_to_text=parser_conf["add_captions_to_text"], + add_captions_to_blocks=parser_conf["add_captions_to_blocks"], + table_as_html=parser_conf["table_as_html"], + ): # Instantialize instance variables with parameters self.strategy = strategy if extract_images: # by default always extract Table - self.extract_image_block_types = ["Image", "Table"] # extracting Image and Table as image blocks + self.extract_image_block_types = [ + "Image", + "Table", + ] # extracting Image and Table as image blocks else: self.extract_image_block_types = ["Table"] self.infer_table_structure = infer_table_structure @@ -63,7 +66,7 @@ def partition(self, path: str): extract_image_block_types=self.extract_image_block_types, infer_table_structure=self.infer_table_structure, extract_image_block_to_payload=False, - extract_image_block_output_dir=self.image_output_dir + extract_image_block_output_dir=self.image_output_dir, ) return partitions @@ -78,43 +81,42 @@ def classify(self, partitions): dict: A dictionary with keys 'Text', 'Tables', and 'Images', each containing a list of corresponding elements. """ # Initialize lists for each type of element - classified_elements = { - 'Text': [], - 'Tables': [], - 'Images': [] - } + classified_elements = {"Text": [], "Tables": [], "Images": []} for i, element in enumerate(partitions): # enumerate, classify and add element + caption (when available) to respective list if element.category == "Table": if self.add_captions_to_blocks and i + 1 < len(partitions): - if partitions[i + 1].category == "FigureCaption": # check for caption + if ( + partitions[i + 1].category == "FigureCaption" + ): # check for caption caption_element = partitions[i + 1] else: caption_element = None - classified_elements['Tables'].append((element, caption_element)) + classified_elements["Tables"].append((element, caption_element)) else: - classified_elements['Tables'].append((element, None)) + classified_elements["Tables"].append((element, None)) elif element.category == "Image": if self.add_captions_to_blocks and i + 1 < len(partitions): - if partitions[i + 1].category == "FigureCaption": # check for caption + if ( + partitions[i + 1].category == "FigureCaption" + ): # check for caption caption_element = partitions[i + 1] else: caption_element = None - classified_elements['Images'].append((element, caption_element)) + classified_elements["Images"].append((element, caption_element)) else: - classified_elements['Images'].append((element, None)) + classified_elements["Images"].append((element, None)) else: if not self.add_captions_to_text: - if element.category != 'FigureCaption': - classified_elements['Text'].append(element) + if element.category != "FigureCaption": + classified_elements["Text"].append(element) else: - classified_elements['Text'].append(element) + classified_elements["Text"].append(element) return classified_elements def text_concat(self, elements) -> str: - for current_element, next_element in zip(elements, elements[1:]): curr_type = current_element.category next_type = next_element.category @@ -122,22 +124,22 @@ def text_concat(self, elements) -> str: # if curr_type in ["FigureCaption", "NarrativeText", "Title", "Address", 'Table', "UncategorizedText", "Formula"]: # full_text += str(current_element) + "\n\n" - if curr_type == "Title" and next_type == 'NarrativeText': - full_text += str(current_element) + '\n' - elif curr_type == 'NarrativeText' and next_type == 'NarrativeText': - full_text += str(current_element) + '\n' + if curr_type == "Title" and next_type == "NarrativeText": + full_text += str(current_element) + "\n" + elif curr_type == "NarrativeText" and next_type == "NarrativeText": + full_text += str(current_element) + "\n" elif curr_type == "ListItem": full_text += "- " + str(current_element) + "\n" - if next_element == 'Title': - full_text += '\n' - elif next_element == 'Title': - full_text = str(current_element) + '\n\n' + if next_element == "Title": + full_text += "\n" + elif next_element == "Title": + full_text = str(current_element) + "\n\n" elif curr_type in ["Header", "Footer", "PageBreak"]: full_text += str(current_element) + "\n\n\n" else: - full_text += '\n' + full_text += "\n" return full_text @@ -151,14 +153,13 @@ def process_text(self, elements): docs (list): A list of Document instances containing the extracted Text content and their metadata. """ if self.single_text_out: - metadata = {'source': self.file_path} # Check for more metadata + metadata = {"source": self.file_path} # Check for more metadata text = "\n\n".join([str(el) for el in elements]) docs = [Document(page_content=text, metadata=metadata)] else: docs = [] for element in elements: - metadata = {'source': self.file_path, - 'category': element.category} + metadata = {"source": self.file_path, "category": element.category} metadata.update(element.metadata.to_dict()) docs.append(Document(page_content=str(element), metadata=metadata)) return docs @@ -170,13 +171,12 @@ def process_tables(self, elements): elements (list): The list of table elements (and optional captions) to be processed. Returns: - docs (list): A list of Document instances containing Tables, their captions and metadata. + docs (list): A list of Document instances containing Tables, their captions and metadata. """ docs = [] for block_element, caption_element in elements: - metadata = {'source': self.file_path, - 'category': block_element.category} + metadata = {"source": self.file_path, "category": block_element.category} metadata.update(block_element.metadata.to_dict()) if self.table_as_html: table_data = block_element.metadata.text_as_html @@ -184,7 +184,9 @@ def process_tables(self, elements): table_data = str(block_element) if caption_element: - if self.add_caption_first: # if there is a caption, add that before the element + if ( + self.add_caption_first + ): # if there is a caption, add that before the element content = "\n\n".join([str(caption_element), table_data]) else: content = "\n\n".join([table_data, str(caption_element)]) @@ -204,8 +206,7 @@ def process_images(self, elements): """ docs = [] for block_element, caption_element in elements: - metadata = {'source': self.file_path, - 'category': block_element.category} + metadata = {"source": self.file_path, "category": block_element.category} metadata.update(block_element.metadata.to_dict()) if caption_element: # if there is a caption, add that before the element if self.add_caption_first: @@ -228,9 +229,7 @@ def load_file(self, path): """ partitions = self.partition(str(path)) classified_elements = self.classify(partitions) - text_docs = self.process_text(classified_elements['Text']) - table_docs = self.process_tables(classified_elements['Tables']) - image_docs = self.process_images(classified_elements['Images']) - return {'Text': text_docs, - 'Tables': table_docs, - 'Images': image_docs} + text_docs = self.process_text(classified_elements["Text"]) + table_docs = self.process_tables(classified_elements["Tables"]) + image_docs = self.process_images(classified_elements["Images"]) + return {"Text": text_docs, "Tables": table_docs, "Images": image_docs} diff --git a/src/grag/components/prompt.py b/src/grag/components/prompt.py index 3c20be9..ecefa71 100644 --- a/src/grag/components/prompt.py +++ b/src/grag/components/prompt.py @@ -9,16 +9,16 @@ Example = Dict[str, Any] SUPPORTED_TASKS = ["QA"] -SUPPORTED_DOC_CHAINS = ["stuff", 'refine'] +SUPPORTED_DOC_CHAINS = ["stuff", "refine"] class Prompt(BaseModel): - name: str = Field(default='custom_prompt') - llm_type: str = Field(default='None') - task: str = Field(default='QA') - source: str = Field(default='NoSource') - doc_chain: str = Field(default='stuff') - language: str = 'en' + name: str = Field(default="custom_prompt") + llm_type: str = Field(default="None") + task: str = Field(default="QA") + source: str = Field(default="NoSource") + doc_chain: str = Field(default="stuff") + language: str = "en" filepath: Optional[str] = Field(default=None, exclude=True) input_keys: List[str] template: str @@ -28,7 +28,7 @@ class Prompt(BaseModel): @classmethod def validate_input_keys(cls, v) -> List[str]: if v is None or v == []: - raise ValueError('input_keys cannot be empty') + raise ValueError("input_keys cannot be empty") return v @field_validator("doc_chain") @@ -36,14 +36,17 @@ def validate_input_keys(cls, v) -> List[str]: def validate_doc_chain(cls, v: str) -> str: if v not in SUPPORTED_DOC_CHAINS: raise ValueError( - f'The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}') + f"The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}" + ) return v @field_validator("task") @classmethod def validate_task(cls, v: str) -> str: if v not in SUPPORTED_TASKS: - raise ValueError(f'The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}') + raise ValueError( + f"The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}" + ) return v # @model_validator(mode='after') @@ -51,21 +54,21 @@ def validate_task(cls, v: str) -> str: # self.prompt = ChatPromptTemplate.from_template(self.template) def __init__(self, **kwargs): super().__init__(**kwargs) - self.prompt = PromptTemplate(input_variables=self.input_keys, template=self.template) - - def save(self, filepath: Union[Path, str, None], overwrite=False) -> Union[None, ValueError]: - dump = self.model_dump_json( - indent=2, - exclude_defaults=True, - exclude_none=True + self.prompt = PromptTemplate( + input_variables=self.input_keys, template=self.template ) + + def save( + self, filepath: Union[Path, str, None], overwrite=False + ) -> Union[None, ValueError]: + dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True) if filepath is None: - filepath = f'{self.name}.json' + filepath = f"{self.name}.json" if overwrite: if self.filepath is None: - return ValueError('filepath does not exist in instance') + return ValueError("filepath does not exist in instance") filepath = self.filepath - with open(filepath, 'w') as f: + with open(filepath, "w") as f: f.write(dump) return None @@ -87,12 +90,16 @@ class FewShotPrompt(Prompt): prefix: str suffix: str example_template: str - prompt: Optional[FewShotPromptTemplate] = Field(exclude=True, repr=False, default=None) + prompt: Optional[FewShotPromptTemplate] = Field( + exclude=True, repr=False, default=None + ) def __init__(self, **kwargs): super().__init__(**kwargs) - eg_formatter = PromptTemplate(input_vars=self.input_keys + self.output_keys, - template=self.example_template) + eg_formatter = PromptTemplate( + input_vars=self.input_keys + self.output_keys, + template=self.example_template, + ) self.prompt = FewShotPromptTemplate( examples=self.examples, example_prompt=eg_formatter, @@ -105,14 +112,14 @@ def __init__(self, **kwargs): @classmethod def validate_output_keys(cls, v) -> List[str]: if v is None or v == []: - raise ValueError('output_keys cannot be empty') + raise ValueError("output_keys cannot be empty") return v - @field_validator('examples') + @field_validator("examples") @classmethod def validate_examples(cls, v) -> List[Dict[str, Any]]: if v is None or v == []: - raise ValueError('examples cannot be empty') + raise ValueError("examples cannot be empty") for eg in v: if not all(key in eg for key in cls.input_keys): raise ValueError(f"input key(s) not in example {eg}") @@ -121,5 +128,5 @@ def validate_examples(cls, v) -> List[Dict[str, Any]]: return v -if __name__ == '__main__': +if __name__ == "__main__": p = Prompt.load("../prompts/Llama-2_QA_1.json") diff --git a/src/grag/components/text_splitter.py b/src/grag/components/text_splitter.py index 4b5b334..cff3c7c 100644 --- a/src/grag/components/text_splitter.py +++ b/src/grag/components/text_splitter.py @@ -2,13 +2,15 @@ from .utils import get_config -text_splitter_conf = get_config()['text_splitter'] +text_splitter_conf = get_config()["text_splitter"] # %% class TextSplitter: def __init__(self): - self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(text_splitter_conf['chunk_size']), - chunk_overlap=int(text_splitter_conf['chunk_overlap']), - length_function=len, - is_separator_regex=False, ) + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=int(text_splitter_conf["chunk_size"]), + chunk_overlap=int(text_splitter_conf["chunk_overlap"]), + length_function=len, + is_separator_regex=False, + ) diff --git a/src/grag/components/utils.py b/src/grag/components/utils.py index ea3c041..cb64258 100644 --- a/src/grag/components/utils.py +++ b/src/grag/components/utils.py @@ -16,7 +16,7 @@ def stuff_docs(docs: List[Document]) -> str: Returns: string of document page content joined by '\n\n' """ - return '\n\n'.join([doc.page_content for doc in docs]) + return "\n\n".join([doc.page_content for doc in docs]) def reformat_text_with_line_breaks(input_text, max_width=110): @@ -31,13 +31,15 @@ def reformat_text_with_line_breaks(input_text, max_width=110): str: The reformatted text with preserved line breaks and adjusted line width. """ # Divide the text into separate lines - original_lines = input_text.split('\n') + original_lines = input_text.split("\n") # Apply wrapping to each individual line - reformatted_lines = [textwrap.fill(line, width=max_width) for line in original_lines] + reformatted_lines = [ + textwrap.fill(line, width=max_width) for line in original_lines + ] # Combine the lines back into a single text block - reformatted_text = '\n'.join(reformatted_lines) + reformatted_text = "\n".join(reformatted_lines) return reformatted_text @@ -49,14 +51,14 @@ def display_llm_output_and_sources(response_from_llm): response_from_llm (dict): The response object from an LLM which includes the result and source documents. """ # Display the main result from the LLM response - print(response_from_llm['result']) + print(response_from_llm["result"]) # Separator for clarity - print('\nSources:') + print("\nSources:") # Loop through each source document and print its source for source in response_from_llm["source_documents"]: - print(source.metadata['source']) + print(source.metadata["source"]) def load_prompt(json_file: str | os.PathLike, return_input_vars=False): @@ -72,9 +74,9 @@ def load_prompt(json_file: str | os.PathLike, return_input_vars=False): """ with open(f"{json_file}", "r") as f: prompt_json = json.load(f) - prompt_template = ChatPromptTemplate.from_template(prompt_json['template']) + prompt_template = ChatPromptTemplate.from_template(prompt_json["template"]) - input_vars = prompt_json['input_variables'] + input_vars = prompt_json["input_variables"] return (prompt_template, input_vars) if return_input_vars else prompt_template @@ -94,7 +96,7 @@ def find_config_path(current_path: Path) -> Path: Raises: FileNotFoundError: If 'config.ini' cannot be found in any of the parent directories. """ - config_path = Path('src/config.ini') + config_path = Path("src/config.ini") while not (current_path / config_path).exists(): current_path = current_path.parent if current_path == current_path.parent: @@ -113,11 +115,11 @@ def get_config() -> ConfigParser: """ # Assuming this script is somewhere inside your project directory script_location = Path(__file__).resolve() - if os.environ.get('CONFIG_PATH'): - config_path = os.environ.get('CONFIG_PATH') + if os.environ.get("CONFIG_PATH"): + config_path = os.environ.get("CONFIG_PATH") else: config_path = find_config_path(script_location) - os.environ['CONFIG_PATH'] = str(config_path) + os.environ["CONFIG_PATH"] = str(config_path) print(f"Loaded config from {config_path}.") # Initialize parser and read config config = ConfigParser(interpolation=ExtendedInterpolation()) diff --git a/src/grag/prompts/__init__.py b/src/grag/prompts/__init__.py index 4c869db..cf6b3bc 100644 --- a/src/grag/prompts/__init__.py +++ b/src/grag/prompts/__init__.py @@ -1,10 +1,7 @@ { - "input_keys": [ - "context", - "question" - ], + "input_keys": ["context", "question"], "template": "[INST] <>\nYou are a helpful, respectful and honest assistant.\n\nAlways answer based only on the provided context. If the question can not be answered from the provided context, just say that you don't know, don't try to make up an answer.\n<>\n\nUse the following pieces of context to answer the question at the end:\n\n\n{context}\n\n\nQuestion: {question}\n\nHelpful Answer: [/INST]", "llm_type": "Llama-2", "source": "https://github.com/hwchase17/langchain-hub/blob/master/prompts/vector_db_qa/prompt.json", - "task": "QA" + "task": "QA", } diff --git a/src/grag/rag/basic_rag.py b/src/grag/rag/basic_rag.py index 1e33ff6..9589920 100644 --- a/src/grag/rag/basic_rag.py +++ b/src/grag/rag/basic_rag.py @@ -13,15 +13,15 @@ class BasicRAG: - def __init__(self, - model_name=None, - doc_chain='stuff', - task='QA', - llm_kwargs=None, - retriever_kwargs=None, - custom_prompt: Union[Prompt, FewShotPrompt, None] = None - ): - + def __init__( + self, + model_name=None, + doc_chain="stuff", + task="QA", + llm_kwargs=None, + retriever_kwargs=None, + custom_prompt: Union[Prompt, FewShotPrompt, None] = None, + ): if retriever_kwargs is None: self.retriever = Retriever() else: @@ -35,16 +35,20 @@ def __init__(self, self.prompt_path = files(prompts) self.custom_prompt = custom_prompt - self._task = 'QA' + self._task = "QA" self.model_name = model_name self.doc_chain = doc_chain self.task = task if self.custom_prompt is None: - self.main_prompt = Prompt.load(self.prompt_path.joinpath(self.main_prompt_name)) - - if self.doc_chain == 'refine': - self.refine_prompt = Prompt.load(self.prompt_path.joinpath(self.refine_prompt_name)) + self.main_prompt = Prompt.load( + self.prompt_path.joinpath(self.main_prompt_name) + ) + + if self.doc_chain == "refine": + self.refine_prompt = Prompt.load( + self.prompt_path.joinpath(self.refine_prompt_name) + ) else: self.main_prompt = self.custom_prompt @@ -56,7 +60,7 @@ def model_name(self): def model_name(self, value): if value is None: self.llm = self.llm_.load_model() - self._model_name = conf['llm']['model_name'] + self._model_name = conf["llm"]["model_name"] else: self._model_name = value self.llm = self.llm_.load_model(model_name=self.model_name) @@ -67,14 +71,17 @@ def doc_chain(self): @doc_chain.setter def doc_chain(self, value): - _allowed_doc_chains = ['refine', 'stuff'] + _allowed_doc_chains = ["refine", "stuff"] if value not in _allowed_doc_chains: - raise ValueError(f'Doc chain {value} is not allowed. Available choices: {_allowed_doc_chains}') + raise ValueError( + f"Doc chain {value} is not allowed. Available choices: {_allowed_doc_chains}" + ) self._doc_chain = value - if value == 'refine': + if value == "refine": if self.custom_prompt is not None: assert len(self.custom_prompt) == 2, ValueError( - f"Refine chain needs 2 custom prompts. {len(self.custom_prompt)} custom prompts were given.") + f"Refine chain needs 2 custom prompts. {len(self.custom_prompt)} custom prompts were given." + ) self.prompt_matcher() @property @@ -83,27 +90,35 @@ def task(self): @task.setter def task(self, value): - _allowed_tasks = ['QA'] + _allowed_tasks = ["QA"] if value not in _allowed_tasks: - raise ValueError(f'Task {value} is not allowed. Available tasks: {_allowed_tasks}') + raise ValueError( + f"Task {value} is not allowed. Available tasks: {_allowed_tasks}" + ) self._task = value self.prompt_matcher() def prompt_matcher(self): - matcher_path = self.prompt_path.joinpath('matcher.json') + matcher_path = self.prompt_path.joinpath("matcher.json") with open(f"{matcher_path}", "r") as f: matcher_dict = json.load(f) try: self.model_type = matcher_dict[self.model_name] except KeyError: - raise ValueError(f'Prompt for {self.model_name} not found in {matcher_path}') + raise ValueError( + f"Prompt for {self.model_name} not found in {matcher_path}" + ) - self.main_prompt_name = f'{self.model_type}_{self.task}_1.json' - self.refine_prompt_name = f'{self.model_type}_{self.task}-refine_1.json' + self.main_prompt_name = f"{self.model_type}_{self.task}_1.json" + self.refine_prompt_name = f"{self.model_type}_{self.task}-refine_1.json" if self.custom_prompt is None: - self.main_prompt = Prompt.load(self.prompt_path.joinpath(self.main_prompt_name)) - if self.doc_chain == 'refine': - self.refine_prompt = Prompt.load(self.prompt_path.joinpath(self.refine_prompt_name)) + self.main_prompt = Prompt.load( + self.prompt_path.joinpath(self.main_prompt_name) + ) + if self.doc_chain == "refine": + self.refine_prompt = Prompt.load( + self.prompt_path.joinpath(self.refine_prompt_name) + ) @staticmethod def stuff_docs(docs: List[Document]) -> str: @@ -113,18 +128,18 @@ def stuff_docs(docs: List[Document]) -> str: Returns: string of document page content joined by '\n\n' """ - return '\n\n'.join([doc.page_content for doc in docs]) + return "\n\n".join([doc.page_content for doc in docs]) @staticmethod def output_parser(call_func): def output_parser_wrapper(*args, **kwargs): response, sources = call_func(*args, **kwargs) - if conf['llm']['std_out'] == 'False': + if conf["llm"]["std_out"] == "False": # if self.llm_.callback_manager is None: print(response) - print('Sources: ') + print("Sources: ") for index, source in enumerate(sources): - print(f'\t{index}: {source}') + print(f"\t{index}: {source}") return response, sources return output_parser_wrapper @@ -145,20 +160,23 @@ def refine_call(self, query: str): responses = [] for index, doc in enumerate(retrieved_docs): if index == 0: - prompt = self.main_prompt.format(context=doc.page_content, - question=query) + prompt = self.main_prompt.format( + context=doc.page_content, question=query + ) response = self.llm.invoke(prompt) responses.append(response) else: - prompt = self.refine_prompt.format(context=doc.page_content, - question=query, - existing_answer=responses[-1]) + prompt = self.refine_prompt.format( + context=doc.page_content, + question=query, + existing_answer=responses[-1], + ) response = self.llm.invoke(prompt) responses.append(response) return responses, sources def __call__(self, query: str): - if self.doc_chain == 'stuff': + if self.doc_chain == "stuff": return self.stuff_call(query) - elif self.doc_chain == 'refine': + elif self.doc_chain == "refine": return self.refine_call(query) diff --git a/src/tests/components/embedding_test.py b/src/tests/components/embedding_test.py index 6d3e45a..1eda90f 100644 --- a/src/tests/components/embedding_test.py +++ b/src/tests/components/embedding_test.py @@ -13,21 +13,37 @@ def cosine_similarity(a, b): # %% embedding_configs = [ - {'embedding_type': 'sentence-transformers', - 'embedding_model': "all-mpnet-base-v2", }, - {'embedding_type': 'instructor-embedding', - 'embedding_model': 'hkunlp/instructor-xl', } + { + "embedding_type": "sentence-transformers", + "embedding_model": "all-mpnet-base-v2", + }, + { + "embedding_type": "instructor-embedding", + "embedding_model": "hkunlp/instructor-xl", + }, ] -@pytest.mark.parametrize('embedding_config', embedding_configs) +@pytest.mark.parametrize("embedding_config", embedding_configs) def test_embeddings(embedding_config): # docs tuple format: (doc, similar to doc, asimilar to doc) - doc_sets = [('The new movie is awesome.', 'The new movie is so great.', 'The video is awful'), - ('The cat sits outside.', 'The dog plays in the garden.', 'The car is parked inside')] + doc_sets = [ + ( + "The new movie is awesome.", + "The new movie is so great.", + "The video is awful", + ), + ( + "The cat sits outside.", + "The dog plays in the garden.", + "The car is parked inside", + ), + ] embedding = Embedding(**embedding_config) for docs in doc_sets: doc_vecs = [embedding.embedding_function.embed_query(doc) for doc in docs] - similarity_scores = [cosine_similarity(doc_vecs[0], doc_vecs[1]), - cosine_similarity(doc_vecs[0], doc_vecs[2])] + similarity_scores = [ + cosine_similarity(doc_vecs[0], doc_vecs[1]), + cosine_similarity(doc_vecs[0], doc_vecs[2]), + ] assert similarity_scores[0] > similarity_scores[1] diff --git a/src/tests/components/llm_test.py b/src/tests/components/llm_test.py index 44de0dd..df0f4d9 100644 --- a/src/tests/components/llm_test.py +++ b/src/tests/components/llm_test.py @@ -3,22 +3,26 @@ import pytest from grag.components.llm import LLM -llama_models = ['Llama-2-7b-chat', - 'Llama-2-13b-chat', - 'Mixtral-8x7B-Instruct-v0.1', - 'gemma-7b-it'] -hf_models = ['meta-llama/Llama-2-7b-chat-hf', - 'meta-llama/Llama-2-13b-chat-hf', - # 'mistralai/Mixtral-8x7B-Instruct-v0.1', - 'google/gemma-7b-it'] -cpp_quantization = ['Q5_K_M', 'Q5_K_M', 'Q4_K_M', 'f16'] -hf_quantization = ['Q8', 'Q4', 'Q4'] # , 'Q4'] +llama_models = [ + "Llama-2-7b-chat", + "Llama-2-13b-chat", + "Mixtral-8x7B-Instruct-v0.1", + "gemma-7b-it", +] +hf_models = [ + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-13b-chat-hf", + # 'mistralai/Mixtral-8x7B-Instruct-v0.1', + "google/gemma-7b-it", +] +cpp_quantization = ["Q5_K_M", "Q5_K_M", "Q4_K_M", "f16"] +hf_quantization = ["Q8", "Q4", "Q4"] # , 'Q4'] params = [(model, quant) for model, quant in zip(hf_models, hf_quantization)] @pytest.mark.parametrize("hf_models, quantization", params) def test_hf_web_pipe(hf_models, quantization): - llm_ = LLM(quantization=quantization, model_name=hf_models, pipeline='hf') + llm_ = LLM(quantization=quantization, model_name=hf_models, pipeline="hf") model = llm_.load_model(is_local=False) response = model.invoke("Who are you?") assert isinstance(response, Text) @@ -30,7 +34,7 @@ def test_hf_web_pipe(hf_models, quantization): @pytest.mark.parametrize("model_name, quantization", params) def test_llamacpp_pipe(model_name, quantization): - llm_ = LLM(quantization=quantization, model_name=model_name, pipeline='llama_cpp') + llm_ = LLM(quantization=quantization, model_name=model_name, pipeline="llama_cpp") model = llm_.load_model() response = model.invoke("Who are you?") assert isinstance(response, Text) diff --git a/src/tests/components/multivec_retriever_test.py b/src/tests/components/multivec_retriever_test.py index 6c94882..3ccb3fb 100644 --- a/src/tests/components/multivec_retriever_test.py +++ b/src/tests/components/multivec_retriever_test.py @@ -1,18 +1,18 @@ # # add code folder to sys path # import os # from pathlib import Path -# +# # from grag.components.multivec_retriever import Retriever # from langchain_community.document_loaders import DirectoryLoader, TextLoader -# +# # # %%% # # data_path = "data/pdf/9809" # data_path = Path(os.getcwd()).parents[1] / 'data' / 'Gutenberg' / 'txt' # "data/Gutenberg/txt" # # %% # retriever = Retriever(top_k=3) -# +# # new_docs = True -# +# # if new_docs: # # loading text files from data_path # loader = DirectoryLoader(data_path, @@ -25,20 +25,20 @@ # # %% # # limit docs for testing # docs = docs[:100] -# +# # # %% # # adding chunks and parent doc # retriever.add_docs(docs) # # %% # # testing retrival # query = 'Thomas H. Huxley' -# +# # # Retrieving the 3 most relevant small chunk # chunk_result = retriever.get_chunk(query) -# +# # # Retrieving the most relevant document # doc_result = retriever.get_doc(query) -# +# # # Ensuring that the length of chunk is smaller than length of doc # chunk_len = [len(chunk.page_content) for chunk in chunk_result] # print(f'Length of chunks : {chunk_len}') @@ -46,13 +46,13 @@ # print(f'Length of doc : {doc_len}') # len_test = [c_len < d_len for c_len, d_len in zip(chunk_len, doc_len)] # print(f'Is len of chunk less than len of doc?: {len_test} ') -# +# # # Ensuring both the chunk and document match the source # chunk_sources = [chunk.metadata['source'] for chunk in chunk_result] # doc_sources = [doc.metadata['source'] for doc in doc_result] # source_test = [source[0] == source[1] for source in zip(chunk_sources, doc_sources)] # print(f'Does source of chunk and doc match? : {source_test}') -# +# # # # Ensuring both the chunk and document match the source # # source_test = chunk_result.metadata['source'] == doc_result.metadata['source'] # # print(f'Does source of chunk and doc match? : {source_test}') diff --git a/src/tests/components/parse_pdf_test.py b/src/tests/components/parse_pdf_test.py index 668f8f1..7110e21 100644 --- a/src/tests/components/parse_pdf_test.py +++ b/src/tests/components/parse_pdf_test.py @@ -1,39 +1,39 @@ # # add code folder to sys path # import time # from pathlib import Path -# +# # from grag.components.parse_pdf import ParsePDF # from grag.components.utils import get_config -# +# # data_path = Path(get_config()['data']['data_path']) # # %% # data_path = data_path / 'test' / 'pdf' # "data/test/pdf" -# -# +# +# # def main(filename): # file_path = data_path / filename # pdf_parser = ParsePDF() # start_time = time.time() # docs_dict = pdf_parser.load_file(file_path) # time_taken = time.time() - start_time -# +# # print(f'Parsed pdf in {time_taken:2f} secs') -# +# # print('******** TEXT ********') # for doc in docs_dict['Text']: # print(doc) -# +# # print('******** TABLES ********') # for text_doc in docs_dict['Tables']: # print(text_doc) -# +# # print('******** IMAGES ********') # for doc in docs_dict['Images']: # print(doc) -# +# # return docs_dict -# -# +# +# # if __name__ == "__main__": # filename = 'he_pdsw12.pdf' # print(f'Parsing: {filename}') diff --git a/src/tests/components/prompt_test.py b/src/tests/components/prompt_test.py index fee7edc..3414f59 100644 --- a/src/tests/components/prompt_test.py +++ b/src/tests/components/prompt_test.py @@ -2,16 +2,18 @@ from importlib_resources import files question = "What is the capital of France" -context = "Paris is the capital and most populous city of France. With an official estimated population of 2,102,650 \ +context = ( + "Paris is the capital and most populous city of France. With an official estimated population of 2,102,650 \ residents as of 1 January 2023 in an area of more than 105 km2 (41 sq mi), Paris is the fourth-most populated \ city in the European Union and the 30th most densely populated city in the world in 2022. Since the 17th century, \ Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy. \ For its leading role in the arts and sciences, as well as its early and extensive system of street lighting, in the \ 19th century, it became known as the City of Light." +) def test_prompt_files(): - prompt_files = list(files('grag.prompts').glob('*.json')) + prompt_files = list(files("grag.prompts").glob("*.json")) for file in prompt_files: if file.name.startswith("matcher"): prompt_files.remove(file) @@ -21,18 +23,15 @@ def test_prompt_files(): def test_custom_prompt(): - template = '''Answer the following question based on the given context. + template = """Answer the following question based on the given context. question: {question} context: {context} answer: - ''' - correct_prompt = f'''Answer the following question based on the given context. + """ + correct_prompt = f"""Answer the following question based on the given context. question: {question} context: {context} answer: - ''' - custom_prompt = Prompt( - input_keys={"context", "question"}, - template=template - ) + """ + custom_prompt = Prompt(input_keys={"context", "question"}, template=template) assert custom_prompt.format(question=question, context=context) == correct_prompt diff --git a/src/tests/rag/basic_rag_test.py b/src/tests/rag/basic_rag_test.py index 72991d1..2249028 100644 --- a/src/tests/rag/basic_rag_test.py +++ b/src/tests/rag/basic_rag_test.py @@ -4,8 +4,8 @@ def test_rag_stuff(): - rag = BasicRAG(doc_chain='stuff') - response, sources = rag('What is simulated annealing?') + rag = BasicRAG(doc_chain="stuff") + response, sources = rag("What is simulated annealing?") assert isinstance(response, Text) assert isinstance(sources, List) assert all(isinstance(s, str) for s in sources) @@ -13,8 +13,8 @@ def test_rag_stuff(): def test_rag_refine(): - rag = BasicRAG(doc_chain='refine') - response, sources = rag('What is simulated annealing?') + rag = BasicRAG(doc_chain="refine") + response, sources = rag("What is simulated annealing?") # assert isinstance(response, Text) assert isinstance(response, List) assert all(isinstance(s, str) for s in response)