Skip to content

Commit

Permalink
Add schema to Redis initialization & Improve LVM-TGI For Multimodal Retriever Microservice (#638)
Browse files Browse the repository at this point in the history

* add schema to Redis initialization

Signed-off-by: Tiep Le <tiep.le@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update lvm_tgi

Signed-off-by: Tiep Le <tiep.le@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tileintel and pre-commit-ci[bot] authored Sep 9, 2024
1 parent fb4b8d2 commit 23cc3ea
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 11 deletions.
52 changes: 43 additions & 9 deletions comps/lvms/lvm_tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@

import os
import time
from typing import Union

from fastapi.responses import StreamingResponse
from huggingface_hub import AsyncInferenceClient
from langchain_core.prompts import PromptTemplate
from template import ChatTemplate

from comps import (
CustomLogger,
LVMDoc,
LVMSearchedMultimodalDoc,
ServiceType,
TextDoc,
opea_microservices,
Expand All @@ -32,19 +36,49 @@
output_datatype=TextDoc,
)
@register_statistics(names=["opea_service@lvm_tgi"])
async def lvm(request: LVMDoc):
async def lvm(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
if logflag:
logger.info(request)
start = time.time()
stream_gen_time = []
img_b64_str = request.image
prompt = request.prompt
max_new_tokens = request.max_new_tokens
streaming = request.streaming
repetition_penalty = request.repetition_penalty
temperature = request.temperature
top_k = request.top_k
top_p = request.top_p

if isinstance(request, LVMSearchedMultimodalDoc):
if logflag:
logger.info("[LVMSearchedMultimodalDoc ] input from retriever microservice")
retrieved_metadatas = request.metadata
img_b64_str = retrieved_metadatas[0]["b64_img_str"]
initial_query = request.initial_query
context = retrieved_metadatas[0]["transcript_for_inference"]
prompt = initial_query
if request.chat_template is None:
prompt = ChatTemplate.generate_multimodal_rag_on_videos_prompt(initial_query, context)
else:
prompt_template = PromptTemplate.from_template(request.chat_template)
input_variables = prompt_template.input_variables
if sorted(input_variables) == ["context", "question"]:
prompt = prompt_template.format(question=initial_query, context=context)
else:
logger.info(
f"[ LVMSearchedMultimodalDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
max_new_tokens = request.max_new_tokens
streaming = request.streaming
repetition_penalty = request.repetition_penalty
temperature = request.temperature
top_k = request.top_k
top_p = request.top_p
if logflag:
logger.info(f"prompt generated for [LVMSearchedMultimodalDoc ] input from retriever microservice: {prompt}")

else:
img_b64_str = request.image
prompt = request.prompt
max_new_tokens = request.max_new_tokens
streaming = request.streaming
repetition_penalty = request.repetition_penalty
temperature = request.temperature
top_k = request.top_k
top_p = request.top_p

image = f"data:image/png;base64,{img_b64_str}"
image_prompt = f"![]({image})\n{prompt}\nASSISTANT:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,9 @@ def format_redis_conn_from_env():

# Vector Index Configuration
INDEX_NAME = os.getenv("INDEX_NAME", "test-index")

# Redis index schema file. Defaults to the schema.yml shipped alongside this
# module; override the file name via the REDIS_SCHEMA environment variable.
REDIS_SCHEMA = os.getenv("REDIS_SCHEMA", "schema.yml")
# Absolute path to the schema file, resolved relative to this module's directory.
INDEX_SCHEMA = os.path.join(os.path.dirname(os.path.abspath(__file__)), REDIS_SCHEMA)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

from langchain_community.vectorstores import Redis
from multimodal_config import INDEX_NAME, REDIS_URL
from multimodal_config import INDEX_NAME, INDEX_SCHEMA, REDIS_URL

from comps import (
EmbedMultimodalDoc,
Expand Down Expand Up @@ -89,5 +89,5 @@ def retrieve(
if __name__ == "__main__":

embeddings = BridgeTowerEmbedding()
vector_db = Redis(embedding=embeddings, index_name=INDEX_NAME, redis_url=REDIS_URL)
vector_db = Redis(embedding=embeddings, index_name=INDEX_NAME, index_schema=INDEX_SCHEMA, redis_url=REDIS_URL)
opea_microservices["opea_service@multimodal_retriever_redis"].start()
19 changes: 19 additions & 0 deletions comps/retrievers/langchain/redis_multimodal/schema.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Redis vector index schema for the multimodal retriever (langchain-community
# Redis vectorstore format: top-level keys are field types, each holding a
# sequence of field definitions).
# NOTE(review): field names must match the metadata keys written at ingestion
# time — b64_img_str and transcript_for_inference are read back by lvm_tgi.py.
text:
- name: content
- name: b64_img_str
- name: video_id
- name: source_video
- name: embedding_type
- name: title
- name: transcript_for_inference
numeric:
- name: time_of_frame_ms
vector:
# The vector field's attributes must be nested under its sequence item —
# they configure the content_vector field, not further top-level fields.
- name: content_vector
  algorithm: HNSW       # ANN graph index; trades memory for query speed
  datatype: FLOAT32
  dims: 512             # must equal the embedding model's output dimension
  distance_metric: COSINE

0 comments on commit 23cc3ea

Please sign in to comment.