Fix RAG performance issues (#132)
* Fix RAG performance issues

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel authored Jun 8, 2024
1 parent 8a670ee commit 70c23d1
Showing 6 changed files with 39 additions and 22 deletions.
2 changes: 1 addition & 1 deletion comps/embeddings/langchain/local_embedding.py
@@ -17,11 +17,11 @@
)
@opea_telemetry
def embedding(input: TextDoc) -> EmbedDoc1024:
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    embed_vector = embeddings.embed_query(input.text)
    res = EmbedDoc1024(text=input.text, embedding=embed_vector)
    return res


if __name__ == "__main__":
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    opea_microservices["opea_service@local_embedding"].start()
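The performance fix here is hoisting model construction out of the request path. Before, HuggingFaceBgeEmbeddings (and with it the BGE model weights) was built on every call to embedding(); after, it is built once in the __main__ block and reused by every request. A minimal sketch of the pattern, assuming the langchain_community import path and leaving out the OPEA microservice plumbing:

from langchain_community.embeddings import HuggingFaceBgeEmbeddings

embeddings = None  # assigned once at startup, shared by all requests

def embed(text: str) -> list[float]:
    # per-request work only: no model loading here
    return embeddings.embed_query(text)

if __name__ == "__main__":
    # one-time, expensive initialization (downloads/loads the BGE weights)
    embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    print(len(embed("hello")))  # 1024 dimensions for bge-large-en-v1.5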
4 changes: 2 additions & 2 deletions comps/guardrails/langchain/guardrails_tgi_gaudi.py
@@ -49,8 +49,6 @@ def get_unsafe_dict(model_id="meta-llama/LlamaGuard-7b"):
)
@traceable(run_type="llm")
def safety_guard(input: TextDoc) -> TextDoc:
    # chat engine for server-side prompt templating
    llm_engine_hf = ChatHuggingFace(llm=llm_guard)
    response_input_guard = llm_engine_hf.invoke([{"role": "user", "content": input.text}]).content
    if "unsafe" in response_input_guard:
        unsafe_dict = get_unsafe_dict(llm_engine_hf.model_id)
@@ -75,5 +73,7 @@ def safety_guard(input: TextDoc) -> TextDoc:
        temperature=0.01,
        repetition_penalty=1.03,
    )
    # chat engine for server-side prompt templating
    llm_engine_hf = ChatHuggingFace(llm=llm_guard)
print("guardrails - router] LLM initialized.")
    opea_microservices["opea_service@guardrails_tgi_gaudi"].start()
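The guardrails service gets the same hoisting: the ChatHuggingFace wrapper around the TGI guard endpoint is moved from safety_guard(), where it was rebuilt on every request (and construction can involve resolving the model id against the Hugging Face Hub), to the startup block, so the client is created once and each request only pays for the invoke() call.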
18 changes: 9 additions & 9 deletions comps/retrievers/langchain/retriever_redis.py
@@ -22,6 +22,15 @@
)
@traceable(run_type="retriever")
def retrieve(input: EmbedDoc768) -> SearchedDoc:
    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
    searched_docs = []
    for r in search_res:
        searched_docs.append(TextDoc(text=r.page_content))
    result = SearchedDoc(retrieved_docs=searched_docs, initial_query=input.text)
    return result


if __name__ == "__main__":
    # Create vectorstore
    if tei_embedding_endpoint:
        # create embeddings using TEI endpoint service
@@ -36,13 +45,4 @@ def retrieve(input: EmbedDoc768) -> SearchedDoc:
        redis_url=REDIS_URL,
        schema=INDEX_SCHEMA,
    )
    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
    searched_docs = []
    for r in search_res:
        searched_docs.append(TextDoc(text=r.page_content))
    result = SearchedDoc(retrieved_docs=searched_docs, initial_query=input.text)
    return result


if __name__ == "__main__":
    opea_microservices["opea_service@retriever_redis"].start()
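Same fix, applied to the retriever: the embedder and the Redis vectorstore handle used to be rebuilt inside retrieve() for every query; now retrieve() only runs the similarity search against a vector_db created once under __main__. A sketch of the resulting shape, using the same LangChain Redis API as the diff (env handling and the OPEA service decorators omitted; the schema path is a placeholder):

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Redis

vector_db = None  # created once at startup, reused by every request

def retrieve(query_embedding: list[float]) -> list[str]:
    # per-request work only: vector similarity search
    docs = vector_db.similarity_search_by_vector(embedding=query_embedding)
    return [d.page_content for d in docs]

if __name__ == "__main__":
    # one-time setup: embedder plus connection to an existing Redis index
    embedder = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
    vector_db = Redis.from_existing_index(
        embedder,
        index_name="rag-redis",
        redis_url="redis://localhost:6379",
        schema="redis_schema.yaml",  # placeholder schema path
    )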
8 changes: 5 additions & 3 deletions tests/test_llms_text-generation_tgi.sh
@@ -26,10 +26,10 @@ function start_service() {

    # check whether tgi is fully ready
    n=0
    until [[ "$n" -ge 100 ]]; do
        docker logs test-comps-llm-tgi-endpoint > test-comps-llm-tgi-endpoint.log
    until [[ "$n" -ge 100 ]] || [[ $ready == true ]]; do
        docker logs test-comps-llm-tgi-endpoint > ${WORKPATH}/tests/test-comps-llm-tgi-endpoint.log
        n=$((n+1))
        if grep -q Connected test-comps-llm-tgi-endpoint.log; then
        if grep -q Connected ${WORKPATH}/tests/test-comps-llm-tgi-endpoint.log; then
            break
        fi
        sleep 5s
@@ -43,6 +43,8 @@ function validate_microservice() {
        -X POST \
        -d '{"query":"What is Deep Learning?"}' \
        -H 'Content-Type: application/json'
    docker logs test-comps-llm-tgi-endpoint
    docker logs test-comps-llm-tgi-server
}

function stop_docker() {
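Two tweaks to the TGI test: the readiness loop can now also exit early via [[ $ready == true ]] (note that ready is not assigned anywhere in the lines shown, so it presumably comes from elsewhere in the script or is reserved for later use), and the polled log is written under ${WORKPATH}/tests/ instead of the current directory, so the grep -q Connected check reads the same file no matter where the script is invoked from. validate_microservice also dumps both containers' logs after the curl probe, which makes CI failures diagnosable straight from the job output; the reranks test below receives the same log-dumping addition.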
2 changes: 2 additions & 0 deletions tests/test_reranks_langchain.sh
@@ -33,6 +33,8 @@ function validate_microservice() {
        -X POST \
        -d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
        -H 'Content-Type: application/json'
    docker logs test-comps-reranking-tei-server
    docker logs test-comps-reranking-tei-endpoint
}

function stop_docker() {
27 changes: 20 additions & 7 deletions tests/test_retrievers_langchain.sh
@@ -8,19 +8,23 @@ WORKPATH=$(dirname "$PWD")
ip_address=$(hostname -I | awk '{print $1}')
function build_docker_images() {
    cd $WORKPATH
    docker build --no-cache -t opea/retriever-redis:comps -f comps/retrievers/langchain/docker/Dockerfile .
    docker build --no-cache -t opea/retriever-redis:comps --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/langchain/docker/Dockerfile .
}

function start_service() {
    # redis
    docker run -d --name test-redis-vector-db -p 5010:6379 -p 5011:8001 -e HTTPS_PROXY=$https_proxy -e HTTP_PROXY=$http_proxy redis/redis-stack:7.2.0-v9
    sleep 10s

    # tei endpoint
    tei_endpoint=5008
    model="BAAI/bge-large-en-v1.5"
    revision="refs/pr/5"
    docker run -d --name="test-comps-retriever-tei-endpoint" -p $tei_endpoint:80 -v ./data:/data --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.2 --model-id $model --revision $revision
    model="BAAI/bge-base-en-v1.5"
    docker run -d --name="test-comps-retriever-tei-endpoint" -p $tei_endpoint:80 -v ./data:/data --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.2 --model-id $model
    sleep 30s
    export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:${tei_endpoint}"

    # redis retriever
    export REDIS_URL="redis://${ip_address}:6379"
    export REDIS_URL="redis://${ip_address}:5010"
    export INDEX_NAME="rag-redis"
    retriever_port=5009
    unset http_proxy
@@ -38,11 +42,20 @@ function validate_microservice() {
        -X POST \
        -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" \
        -H 'Content-Type: application/json'
    docker logs test-comps-retriever-redis-server
    docker logs test-comps-retriever-tei-endpoint
}

function stop_docker() {
    cid=$(docker ps -aq --filter "name=test-comps-retrievers*")
    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi
    cid_retrievers=$(docker ps -aq --filter "name=test-comps-retrievers*")
    if [[ ! -z "$cid_retrievers" ]]; then
        docker stop $cid_retrievers && docker rm $cid_retrievers && sleep 1s
    fi

    cid_redis=$(docker ps -aq --filter "name=test-redis-vector-db")
    if [[ ! -z "$cid_redis" ]]; then
        docker stop $cid_redis && docker rm $cid_redis && sleep 1s
    fi
}

function main() {
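The retrievers test carries most of the changes. The image build now forwards http_proxy/https_proxy as build args so it works behind a proxy. A dedicated redis-stack container is started with remapped ports (host 5010 to container 6379), and REDIS_URL is corrected to the remapped port 5010 rather than 6379, which on a shared CI host may not be this test's own Redis. The TEI model switches from bge-large-en-v1.5 pinned to a review revision to plain bge-base-en-v1.5, whose 768-dimensional vectors match the EmbedDoc768 type the retriever consumes. Finally, stop_docker cleans up the Redis container explicitly, since the name=test-comps-retrievers* filter does not match test-redis-vector-db.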
