minimal rag example script
francoishernandez committed Dec 9, 2024
1 parent 58a0265 commit 46ebf29
Showing 3 changed files with 170 additions and 0 deletions.
Binary file added recipes/rag-example/OJ_L_202401689_EN_TXT.pdf
Binary file not shown.
46 changes: 46 additions & 0 deletions recipes/rag-example/README.md
@@ -0,0 +1,46 @@
# Minimal RAG example

## Context

This recipe is intended as a minimal example of Retrieval-Augmented Generation (RAG) with Eole models.
It relies on additional tools, such as Langchain's [loaders](https://python.langchain.com/docs/integrations/document_loaders/pymupdf/) and [splitters](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/recursive_text_splitter/), as well as [ChromaDB](https://docs.trychroma.com/getting-started) for vector search.

The example uses the rather hard-to-digest ["EU AI Act" full text](https://digital-strategy.ec.europa.eu/en/policies/regulatory-framework-ai#:~:text=The%20AI%20Act%20(Regulation%20(EU,regarding%20specific%20uses%20of%20AI.) for the sake of the exercise.

This is a quickly assembled proof of concept and is not expected to give perfect answers.
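
At a high level, `test_rag.py` wires these pieces together roughly as follows (a condensed sketch of the script's own calls; the actual script adds batching, deduplication of existing IDs, and logging):

```python
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb

# Load the PDF and split it into overlapping chunks
docs = PyMuPDFLoader("./OJ_L_202401689_EN_TXT.pdf").load()
splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
chunks = splitter.split_documents(documents=docs)

# Index the chunks in a persistent Chroma collection
client = chromadb.PersistentClient(path="chromadb_data")
collection = client.get_or_create_collection(name="test-eu")
collection.upsert(
    ids=[str(i) for i in range(len(chunks))],
    documents=[chunk.page_content for chunk in chunks],
    metadatas=[chunk.metadata for chunk in chunks],
)

# Retrieve the chunks most relevant to a question
results = collection.query(query_texts=["What does the act say about biometrics?"], n_results=5)
```

The retrieved chunks are then formatted into a prompt and passed to the Eole inference engine, as described in the Usage section below.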

## Usage

### 1. Convert the model you want to use (Llama-3.1-8B by default)

**Set environment variables**

```bash
export EOLE_MODEL_DIR=<where_to_store_models>
export HF_TOKEN=<your_hf_token>
```

**Download and convert model**

```bash
eole convert HF --model_dir meta-llama/Meta-Llama-3.1-8B --output $EOLE_MODEL_DIR/llama3.1-8b --token $HF_TOKEN
```

### 2. Adapt and run the script

If needed, adjust the `model_path` in `PredictConfig`:
```python
...
config = PredictConfig(
model_path=os.path.expandvars("${EOLE_MODEL_DIR}/llama3.1-8b"), # <------ change if needed
src="dummy",
max_length=500,
...
```
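
The script also ships with commented-out bitsandbytes-style NF4 quantization settings for tighter GPU memory budgets; enabling them looks roughly like this (same values as the commented block in `test_rag.py`):
```python
...
config = PredictConfig(
    model_path=os.path.expandvars("${EOLE_MODEL_DIR}/llama3.1-8b"),
    src="dummy",
    max_length=500,
    gpu_ranks=[0],
    quant_type="bnb_NF4",
    quant_layers=[
        "gate_up_proj", "down_proj", "up_proj",
        "linear_values", "linear_query", "linear_keys",
        "final_linear", "w_in", "w_out",
    ],
...
```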

**Run the script**
```bash
python3 test_rag.py
```

Note: You can test various queries by changing the `QUERY` variable.
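
To compare several queries in one run, you can also loop over them once the collection and inference engine are set up. Below is a rough sketch reusing the objects defined in `test_rag.py` (`collection`, `engine`, `PROMPT`); the script itself runs a single query:

```python
QUERIES = [
    "What are the main obligations of importers?",
    "What is the maximum fine for potential offenders?",
]
for query in QUERIES:
    # Retrieve the best chunk and its neighbours, then build the prompt
    results = collection.query(query_texts=[query], n_results=5)
    best_id = int(results["ids"][0][0])
    context_docs = collection.get(ids=[str(best_id - 1), str(best_id), str(best_id + 1)])
    prompt = PROMPT.format(question=query, context="\n".join(context_docs["documents"]))
    # Run inference and restore real newlines in the answer
    _, _, predictions = engine.infer_list([prompt])
    print(predictions[0][0].replace("⦅newline⦆", "\n"))
```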
124 changes: 124 additions & 0 deletions recipes/rag-example/test_rag.py
@@ -0,0 +1,124 @@
# flake8: noqa

import os
from rich import print
from tqdm import tqdm
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb

from eole.utils.logging import init_logger
from eole.config.run import PredictConfig
from eole.inference_engine import InferenceEnginePY

# Set up logging
logger = init_logger()

# 1. Load and Split the Document
logger.info("Loading and splitting the document...")
loader = PyMuPDFLoader("./OJ_L_202401689_EN_TXT.pdf")
docs = loader.load()

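# Split into ~1500-character chunks with a 150-character overlap so retrieved passages stay self-contained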
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
documents = text_splitter.split_documents(documents=docs)
print(f"[INFO] Total chunks: {len(documents)}")

# 2. Set Up ChromaDB Client and Collection
logger.info("Setting up ChromaDB client...")
chroma_client = chromadb.PersistentClient(path="chromadb_data")
collection = chroma_client.get_or_create_collection(name="test-eu")

# 3. Insert Documents into Collection
logger.info("Checking for existing data in the collection...")

# Retrieve all existing IDs from the collection
existing_ids = set(collection.get(ids=None)["ids"]) # Fetches all IDs in the collection
logger.info(f"Found {len(existing_ids)} existing documents in the collection.")
batch_size = 100
for i in tqdm(
range(0, len(documents), batch_size),
desc="Upserting batches in the vector database",
):
batch_ids = [str(k) for k in range(i, min(len(documents), i + batch_size))]
new_ids = [
id_ for id_ in batch_ids if id_ not in existing_ids
] # Filter out existing IDs

if new_ids: # Only upsert if there are new IDs
new_docs = [documents[int(id_)] for id_ in new_ids]
collection.upsert(
ids=new_ids,
documents=[doc.page_content for doc in new_docs],
metadatas=[doc.metadata for doc in new_docs],
)
logger.info(f"Upserted {len(new_ids)} new documents.")
else:
logger.info(f"Skipping batch {i // batch_size + 1}, all IDs already exist.")

# print(collection.peek(10))

# 4. Query the Collection
QUERY = "What is the general position around using biometrics and facial recognition in public places?"
# QUERY = "Are there any derogations for specific actors?"
# QUERY = "What are the main obligations of importers?"
# QUERY = "What are the main risks and penalties incurred?"
# QUERY = "What is the maximum fine for potential offenders?"
# QUERY = "What are the main prohibited practices coverd by the act?"
# QUERY = "What are the main accepted practices covered by the act?"

print(f"[INFO] Querying collection with: {QUERY}")
results = collection.query(query_texts=[QUERY], n_results=5)

best_id = int(results["ids"][0][0])
print(f"[INFO] Best result ID: {best_id}")
context_docs = collection.get(ids=[str(best_id - 1), str(best_id), str(best_id + 1)])

# 5. Prepare the Prompt for Inference
PROMPT = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. The answer should be understandable outside of its context.
The context comes from this document: "Regulation (EU) 2024/1689 of the European Parliament and of the Council of 13 June 2024 laying down harmonised rules on artificial intelligence and amending Regulations (EC) No 300/2008, (EU) No 167/2013, (EU) No 168/2013, (EU) 2018/858, (EU) 2018/1139 and (EU) 2019/2144 and Directives 2014/90/EU, (EU) 2016/797 and (EU) 2020/1828 (Artificial Intelligence Act) Text with EEA relevance."
Question: {question}
Context:  {context}
Answer:"""

context = "\n".join(context_docs["documents"])
prompt = PROMPT.format(question=QUERY, context=context)
logger.info("Generated Prompt:")
print(prompt)


# 6. Perform Inference
logger.info("Running inference...")
config = PredictConfig(
model_path=os.path.expandvars("${EOLE_MODEL_DIR}/llama3.1-8b"),
src="dummy",
max_length=500,
gpu_ranks=[0],
# Uncomment to activate bnb quantization
# quant_type="bnb_NF4",
# quant_layers=[
# "gate_up_proj",
# "down_proj",
# "up_proj",
# "linear_values",
# "linear_query",
# "linear_keys",
# "final_linear",
# "w_in",
# "w_out",
# ],
top_p=0.3,
temperature=0.35,
beam_size=5,
seed=42,
batch_size=1,
batch_type="sents",
)

engine = InferenceEnginePY(config)

_, _, predictions = engine.infer_list([prompt])

# 7. Display the Prediction
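# Eole outputs the ⦅newline⦆ placeholder for line breaks; restore real newlines before printing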
answer = predictions[0][0].replace("⦅newline⦆", "\n")
logger.info("Final Answer:")
print(answer)
