Skip to content

Commit

Permalink
Merge pull request #57 from arjbingly/tests
Browse files Browse the repository at this point in the history
Basic RAG test and DeepLake issue resolved
  • Loading branch information
arjbingly authored Mar 27, 2024
2 parents 34e7d73 + f433944 commit 3025d59
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
1 change: 1 addition & 0 deletions projects/Basic-RAG/BasicRAG_stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)

rag = BasicRAG(doc_chain="stuff", retriever=retriever)

if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,6 @@ docstring-code-format = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.mypy]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion src/config.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[llm]
model_name : Llama-2-7b-chat
model_name : Llama-2-13b-chat
# meta-llama/Llama-2-70b-chat-hf Mixtral-8x7B-Instruct-v0.1
quantization : Q5_K_M
pipeline : llama_cpp
Expand Down
2 changes: 1 addition & 1 deletion src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=1,
top_k=int(multivec_retriever_conf["top_k"]),
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.
Expand Down
16 changes: 10 additions & 6 deletions src/tests/rag/basic_rag_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from typing import Text, List
from typing import List, Text

from grag.components.multivec_retriever import Retriever
from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG

client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)


def test_rag_stuff():
rag = BasicRAG(doc_chain="stuff")
response, sources = rag("What is simulated annealing?")
rag = BasicRAG(doc_chain="stuff", retriever=retriever)
response, sources = rag("What is Flash Attention?")
assert isinstance(response, Text)
assert isinstance(sources, List)
assert all(isinstance(s, str) for s in sources)
del rag.llm


def test_rag_refine():
rag = BasicRAG(doc_chain="refine")
response, sources = rag("What is simulated annealing?")
# assert isinstance(response, Text)
rag = BasicRAG(doc_chain="refine", retriever=retriever)
response, sources = rag("What is Flash Attention?")
assert isinstance(response, List)
assert all(isinstance(s, str) for s in response)
assert isinstance(sources, List)
Expand Down

0 comments on commit 3025d59

Please sign in to comment.