update experimental SSA code
TheVinhLuong102 committed Dec 4, 2023
1 parent 5bc471d commit 681a06b
Showing 12 changed files with 793 additions and 139 deletions.
openssa/contrib/custom.py (132 additions, 0 deletions)
@@ -0,0 +1,132 @@
from llama_index import Document, Response, SimpleDirectoryReader
from llama_index.evaluation import DatasetGenerator
from llama_index.llms.base import LLM as RAGLLM
from llama_index.node_parser import SimpleNodeParser

from openssa.core.backend.abstract_backend import AbstractBackend
from openssa.core.slm.base_slm import PassthroughSLM
from openssa.core.ssm.rag_ssm import RAGSSM
from openssa.integrations.llama_index.backend import Backend as LlamaIndexBackend
from openssa.utils.llm_config import LLMConfig

FILE_NAME = "file_name"


class CustomBackend(LlamaIndexBackend): # type: ignore
    def __init__(self, rag_llm: RAGLLM = None) -> None:  # type: ignore
        super().__init__(rag_llm=rag_llm)
        self.llm = rag_llm

    def _do_read_directory(self, storage_dir: str) -> None:
        def filename_fn(filename: str) -> dict:
            return {FILE_NAME: filename}

        # Attach the source file name to each document so citations can reference it later.
        documents = SimpleDirectoryReader(
            self._get_source_dir(storage_dir),
            filename_as_id=True,
            file_metadata=filename_fn,
        ).load_data()
        self.documents = documents
        self._create_index(documents, storage_dir)

    def get_citation_type(self, filename: str) -> str:
        extension = filename.split(".")[-1]
        return extension.strip().lower() if extension else "unknown"

    def get_citations(self, response: Response, source_path: str = "") -> list[dict]:
        citations: list = []
        print("metadata", response.metadata)
        if not response.metadata:
            return citations
        for data in response.metadata.values():
            filename = (
                data.get(FILE_NAME, "").strip() or data.get("filename", "").strip()
            )

            if not filename:
                continue
            filename = filename.split("/")[-1]
            citation_type = self.get_citation_type(filename)
            pages = [data.get("page_label")] if data.get("page_label") else []
            if source_path and not source_path.endswith("/"):
                source_path = source_path + "/"
            source = source_path + filename if source_path else filename
            citation = {"type": citation_type, "pages": pages, "source": source}
            citations.append(citation)
        return citations

    def add_feedback(self, doc: Document) -> None:
        nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents([doc])
        self._index.insert_nodes(nodes)
        self.query_engine = self._index.as_query_engine()

    def persist(self, persist_path: str) -> None:
        print("persist_path", persist_path)
        self._index.storage_context.persist(persist_path)

    def query(self, query: str, source_path: str = "") -> dict:
        """Return a response dict with keys 'content' and 'citations'."""
        if self.query_engine is None:
            return {"content": "No index to query. Please load something first."}
        response: Response = self.query_engine.query(query)
        citations = self.get_citations(response, source_path)
        print("citations", citations)
        return {"content": response.response, "citations": citations}

    async def get_evaluation_data(self) -> list:
        if self.documents:
            # Generate evaluation questions from the five longest documents.
            nodes = self.sort_longest_nodes(self.documents)
            service_context = LLMConfig.get_service_context_openai_35_turbo()
            data_generator = DatasetGenerator(
                nodes=nodes[:5],
                service_context=service_context,
                num_questions_per_chunk=3,
                show_progress=True,
            )
            eval_questions = await data_generator.agenerate_questions_from_nodes(5)
            return eval_questions
        return []

    def sort_longest_nodes(self, documents: list) -> list:
        return sorted(documents, key=lambda doc: len(doc.text), reverse=True)


class CustomSSM(RAGSSM): # type: ignore
    def __init__(
        self,
        custom_rag_backend: AbstractBackend = None,
        s3_source_path: str = "",
        llm: RAGLLM = LLMConfig.get_llm_openai_35_turbo(),  # type: ignore
    ) -> None:
        if custom_rag_backend is None:
            custom_rag_backend = CustomBackend(rag_llm=llm)

        slm = PassthroughSLM()
        self._rag_backend = custom_rag_backend
        self.s3_source_path = s3_source_path
        super().__init__(slm=slm, rag_backend=self._rag_backend)

    def discuss(self, query: str, conversation_id: str = "") -> dict:
        """Return a response dict with keys 'content' and 'citations'."""
        return self._rag_backend.query(query, source_path=self.s3_source_path)

    def add_feedback(self, doc: Document, storage_path: str = "") -> None:
        self._rag_backend.add_feedback(doc)
        self._rag_backend.persist(storage_path)

    async def get_evaluation_data(self) -> list:
        return await self._rag_backend.get_evaluation_data()


if __name__ == "__main__":
    import time

    t1 = time.time()
    ssm = CustomSSM()
    ssm.read_directory("tests/doc", re_index=False)
    t2 = time.time()
    print("time to load", t2 - t1)
    res = ssm.discuss("what is MRI?")
    print(res)
    print("time to query", time.time() - t2)
openssa/contrib/solver.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
from dotenv import load_dotenv

load_dotenv()  # must be called before the project modules below are imported
from openssa.core.ooda_rag.ooda_rag import Solver
from openssa.core.ooda_rag.heuristic import (
    DefaultOODAHeuristic,
    TaskDecompositionHeuristic,
)
from openssa.core.ooda_rag.tools import ReasearchAgentTool
from openssa.utils.aitomatic_llm_config import AitomaticLLMConfig
from openssa.contrib.custom import CustomSSM


class OodaSSA:
    def __init__(
        self,
        task_heuristics,
        llm=AitomaticLLMConfig.get_llama2_70b(),
        model: str = "llama2",
    ):
        self.llm = llm
        self.solver = Solver(
            task_heuristics=task_heuristics,
            ooda_heuristics=DefaultOODAHeuristic(),
            llm=llm,
            model=model,
        )

    def load(self, folder_path: str) -> None:
        # agent = CustomSSM(llm=self.llm)  # TODO: fix this to run
        agent = CustomSSM()
        agent.read_directory(folder_path)
        self.research_documents_tool = ReasearchAgentTool(agent=agent)

    def solve(self, message: str) -> str:
        return self.solver.run(
            message, {"research_documents": self.research_documents_tool}
        )


if __name__ == "__main__":
    heuristic_rules_example = {
        "uncrated picc": [
            "find out the weight of the uncrated PICC",
        ],
        "crated picc": [
            "find out the weight of the crated PICC",
        ],
        "picc": [
            "find out the weight of PICC",
        ],
    }
    task_heuristics = TaskDecompositionHeuristic(heuristic_rules_example)
    ooda_ssa = OodaSSA(task_heuristics)
    print("start reading doc")
    ooda_ssa.load("tests/doc")
    print("finish reading doc")
    print(ooda_ssa.solve("find out the weight of the uncrated PICC"))
    print(ooda_ssa.solve("find out the weight of the crated PICC"))
    print(ooda_ssa.solve("find out the weight of PICC"))