-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Megaservice support for MMRAG - MultimodalRAGQnAWithVideos usecase (
#626) * updates Signed-off-by: Tiep Le <tiep.le@intel.com> * cosmetic Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update redis schema Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update multimodal config and docker compose retriever Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update requirements Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update retriever redis Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * multimodal retriever implementation Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * test for multimodal retriever Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * include prompt preparation for multimodal rag on videos application Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * fix template Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * add test for llava for mm_rag_on_videos Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * update test Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * first update on gateaway Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * fix index not found Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * add LVMSearchedMultimodalDoc Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement gateway for MultimodalRagQnAWithVideos Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * remove INDEX_SCHEMA Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update MultimodalRAGQnAWithVideosGateway with 2 megaservices Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> * revise folder structure to comps/retrievers/langchain/redis_multimodal Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test Signed-off-by: siddhivelankar23 
<siddhi.velankar@intel.com> * add unittest for multimodalrag_qna_with_videos_gateway Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test mmrag qna with videos Signed-off-by: Tiep Le <tiep.le@intel.com> * change port of redis to resolve CI test Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update lvms test Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * update test for multimodal rag qna with videos gateway Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more test to increase coverage Signed-off-by: Tiep Le <tiep.le@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cosmetic Signed-off-by: Tiep Le <tiep.le@intel.com> * add more test Signed-off-by: Tiep Le <tiep.le@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update name of gateway Signed-off-by: Tiep Le <tiep.le@intel.com> --------- Signed-off-by: Tiep Le <tiep.le@intel.com> Signed-off-by: siddhivelankar23 <siddhi.velankar@intel.com> Signed-off-by: sjagtap1803 <siddhant.jagtap@intel.com> Co-authored-by: siddhivelankar23 <siddhi.velankar@intel.com> Co-authored-by: sjagtap1803 <siddhant.jagtap@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
- Loading branch information
1 parent
2705e93
commit 99be1bd
Showing
4 changed files
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
215 changes: 215 additions & 0 deletions
215
tests/cores/mega/test_multimodalrag_with_videos_gateway.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import unittest | ||
from typing import Union | ||
|
||
import requests | ||
from fastapi import Request | ||
|
||
from comps import ( | ||
EmbedDoc, | ||
EmbedMultimodalDoc, | ||
LVMDoc, | ||
LVMSearchedMultimodalDoc, | ||
MultimodalDoc, | ||
MultimodalRAGWithVideosGateway, | ||
SearchedMultimodalDoc, | ||
ServiceOrchestrator, | ||
TextDoc, | ||
opea_microservices, | ||
register_microservice, | ||
) | ||
|
||
|
||
@register_microservice(name="mm_embedding", host="0.0.0.0", port=8083, endpoint="/v1/mm_embedding")
async def mm_embedding_add(request: MultimodalDoc) -> EmbedDoc:
    """Stub multimodal embedding service used by the gateway tests.

    Echoes the ``text`` field of the incoming request and attaches a fixed
    two-element embedding, so downstream stages receive deterministic input.
    """
    # Round-trip through JSON to work on a plain dict view of the request.
    payload = json.loads(request.model_dump_json())
    return {"text": payload["text"], "embedding": [0.12, 0.45]}
|
||
|
||
@register_microservice(name="mm_retriever", host="0.0.0.0", port=8084, endpoint="/v1/mm_retriever")
async def mm_retriever_add(request: EmbedMultimodalDoc) -> SearchedMultimodalDoc:
    """Stub multimodal retriever service used by the gateway tests.

    Returns a canned retrieval result: no retrieved documents, the request's
    ``text`` as ``initial_query``, ``top_n`` of 1, a single metadata entry
    (a base64 image string plus a transcript) and a chat template with
    ``{context}`` / ``{question}`` placeholders.
    """
    # Round-trip through JSON to work on a plain dict view of the request.
    payload = json.loads(request.model_dump_json())
    query = payload["text"]
    return {
        "retrieved_docs": [],
        "initial_query": query,
        "top_n": 1,
        "metadata": [
            {
                "b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC",
                "transcript_for_inference": "yellow image",
            }
        ],
        "chat_template": "The caption of the image is: '{context}'. {question}",
    }
|
||
|
||
@register_microservice(name="lvm", host="0.0.0.0", port=8085, endpoint="/v1/lvm")
async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
    """Stub LVM service used by the gateway tests.

    The reply depends on the request type:

    * ``LVMSearchedMultimodalDoc`` (retriever output): the ``initial_query``
      with ``"opea project!"`` appended.
    * anything else (a direct user request): the ``prompt`` wrapped in an
      ``<image>\\nUSER: ...\\nASSISTANT:`` frame.
    """
    # Round-trip through JSON to work on a plain dict view of the request.
    payload = json.loads(request.model_dump_json())

    if not isinstance(request, LVMSearchedMultimodalDoc):
        print("request is from user.")
        prompt = payload["prompt"]
        return {"text": f"<image>\nUSER: {prompt}\nASSISTANT:"}

    print("request is the output of multimodal retriever")
    return {"text": payload["initial_query"] + "opea project!"}
|
||
|
||
class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
    """Tests for ``MultimodalRAGWithVideosGateway`` and its two orchestrators.

    Two pipelines are built from the stub microservices registered in this
    module:

    * ``service_builder``: mm_embedding -> mm_retriever -> lvm, used for a
      first query.
    * ``follow_up_query_service_builder``: lvm only, used for follow-up
      queries that already carry conversation history.
    """

    # Endpoint the gateway is exercised on over HTTP.
    GATEWAY_URL = "http://0.0.0.0:9898/v1/mmragvideoqna"
    # Upper bound (seconds) on each HTTP call so that an unresponsive gateway
    # fails the test instead of hanging the whole suite.
    REQUEST_TIMEOUT = 60

    @staticmethod
    def _follow_up_messages(system_prompt=None):
        """Return a fresh multi-turn conversation used by the follow-up tests.

        Args:
            system_prompt: Optional text; when given, a ``system`` message
                carrying it is placed in front of the conversation.

        Returns:
            A new list of message dicts (user text + image URL, assistant
            reply, second user turn). A new list is built on every call so
            that tests never share mutable state.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello, "},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    },
                ],
            },
            {"role": "assistant", "content": "opea project! "},
            {"role": "user", "content": "chao, "},
        ]
        if system_prompt is not None:
            messages.insert(0, {"role": "system", "content": system_prompt})
        return messages

    @classmethod
    def setUpClass(cls):
        """Start the stub microservices and wire up both orchestrators and the gateway."""
        cls.mm_embedding = opea_microservices["mm_embedding"]
        cls.mm_retriever = opea_microservices["mm_retriever"]
        cls.lvm = opea_microservices["lvm"]
        cls.mm_embedding.start()
        cls.mm_retriever.start()
        cls.lvm.start()

        # First-query pipeline: embedding -> retriever -> lvm.
        cls.service_builder = ServiceOrchestrator()
        cls.service_builder.add(cls.mm_embedding).add(cls.mm_retriever).add(cls.lvm)
        cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever)
        cls.service_builder.flow_to(cls.mm_retriever, cls.lvm)

        # Follow-up pipeline: the LVM alone.
        cls.follow_up_query_service_builder = ServiceOrchestrator()
        cls.follow_up_query_service_builder.add(cls.lvm)

        # NOTE(review): the HTTP tests below assume that constructing the
        # gateway also starts serving on port 9898 - confirm against the
        # gateway implementation.
        cls.gateway = MultimodalRAGWithVideosGateway(
            cls.service_builder, cls.follow_up_query_service_builder, port=9898
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the stub microservices and the gateway."""
        cls.mm_embedding.stop()
        cls.mm_retriever.stop()
        cls.lvm.stop()
        cls.gateway.stop()

    async def test_service_builder_schedule(self):
        """The full pipeline appends the LVM stub's suffix to the query text."""
        result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
        self.assertEqual(result_dict[self.lvm.name]["text"], "hello, opea project!")

    async def test_follow_up_query_service_builder_schedule(self):
        """The LVM-only pipeline wraps a direct prompt in the USER/ASSISTANT frame."""
        result_dict, _ = await self.follow_up_query_service_builder.schedule(
            initial_inputs={"prompt": "chao, ", "image": "some image"}
        )
        self.assertEqual(result_dict[self.lvm.name]["text"], "<image>\nUSER: chao, \nASSISTANT:")

    def test_multimodal_rag_with_videos_gateway(self):
        """A plain-string ``messages`` payload is answered via the full pipeline."""
        json_data = {"messages": "hello, "}
        response = requests.post(self.GATEWAY_URL, json=json_data, timeout=self.REQUEST_TIMEOUT)
        response = response.json()
        self.assertEqual(response["choices"][-1]["message"]["content"], "hello, opea project!")

    def test_follow_up_mm_rag_with_videos_gateway(self):
        """A multi-turn ``messages`` payload is answered via the follow-up pipeline."""
        json_data = {"messages": self._follow_up_messages(), "max_tokens": 300}
        response = requests.post(self.GATEWAY_URL, json=json_data, timeout=self.REQUEST_TIMEOUT)
        response = response.json()
        self.assertEqual(
            response["choices"][-1]["message"]["content"],
            "<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
        )

    def test_handle_message(self):
        """``_handle_message`` flattens a conversation into a single prompt string."""
        prompt, _ = self.gateway._handle_message(self._follow_up_messages())
        self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n")

    def test_handle_message_with_system_prompt(self):
        """A leading system message is placed at the start of the flattened prompt."""
        prompt, _ = self.gateway._handle_message(self._follow_up_messages(system_prompt="System Prompt"))
        self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n")

    async def test_handle_request(self):
        """``handle_request`` called directly yields the same reply as the HTTP route."""
        json_data = {"messages": self._follow_up_messages(), "max_tokens": 300}
        mock_request = Request(scope={"type": "http"})
        # Pre-populate the request's cached JSON body so no real HTTP body is read.
        mock_request._json = json_data
        res = await self.gateway.handle_request(mock_request)
        res = json.loads(res.json())
        self.assertEqual(
            res["choices"][-1]["message"]["content"],
            "<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
        )
|
||
|
||
# Run the unittest runner when this file is executed directly as a script. | ||
if __name__ == "__main__": | ||
unittest.main() |