-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from video-db/ashish/agents
Ashish/agents
- Loading branch information
Showing
17 changed files
with
436 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,5 @@ venv | |
.ruff_cache | ||
.pytest_cache | ||
*.egg-info | ||
package-lock.json | ||
*.mjs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
import logging | ||
|
||
from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult | ||
from spielberg.core.session import Session, MsgStatus, VideoContent | ||
from spielberg.tools.videodb_tool import VideoDBTool | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
INTRO_VIDEO_ID = os.getenv("INTRO_VIDEO_ID") | ||
OUTRO_VIDEO_ID = os.getenv("OUTRO_VIDEO_ID") | ||
BRAND_IMAGE_ID = os.getenv("BRAND_IMAGE_ID") | ||
|
||
|
||
class BrandkitAgent(BaseAgent): | ||
def __init__(self, session: Session, **kwargs): | ||
self.agent_name = "brandkit" | ||
self.description = ( | ||
"Agent to add brand kit elements (intro video, outro video and brand image) to the given video in VideoDB," | ||
"if user has not given those optional param of intro video, outro video and brand image always try with sending them as None so that defaults are picked from env" | ||
) | ||
self.parameters = self.get_parameters() | ||
super().__init__(session=session, **kwargs) | ||
|
||
def __call__( | ||
self, | ||
collection_id: str, | ||
video_id: str, | ||
intro_video_id: str = None, | ||
outro_video_id: str = None, | ||
brand_image_id: str = None, | ||
*args, | ||
**kwargs, | ||
) -> AgentResponse: | ||
""" | ||
Generate stream of video after adding branding elements. | ||
:param str collection_id: collection id in which videos are available. | ||
:param str video_id: video id on which branding is required. | ||
:param str intro_video_id: VideoDB video id of intro video, defaults to INTRO_VIDEO_ID | ||
:param str outro_video_id: video id of outro video, defaults to OUTRO_VIDEO_ID | ||
:param str brand_image_id: image id of brand image for overlay over video, defaults to BRAND_IMAGE_ID | ||
:param args: Additional positional arguments. | ||
:param kwargs: Additional keyword arguments. | ||
:return: The response containing information about the generated brand stream. | ||
:rtype: AgentResponse | ||
""" | ||
try: | ||
self.output_message.actions.append("Processing brandkit request..") | ||
intro_video_id = intro_video_id or INTRO_VIDEO_ID | ||
outro_video_id = outro_video_id or OUTRO_VIDEO_ID | ||
brand_image_id = brand_image_id or BRAND_IMAGE_ID | ||
if not any([intro_video_id, outro_video_id, brand_image_id]): | ||
message = ( | ||
"Branding elementes not provided, either you can provide provide IDs for intro video, outro video and branding image" | ||
" or you can set INTRO_VIDEO_ID, OUTRO_VIDEO_ID and BRAND_IMAGE_ID in .env of backend directory." | ||
) | ||
return AgentResponse(result=AgentResult.ERROR, message=message) | ||
video_content = VideoContent( | ||
agent_name=self.agent_name, | ||
status=MsgStatus.progress, | ||
status_message="Generating video with branding..", | ||
) | ||
self.output_message.content.append(video_content) | ||
self.output_message.push_update() | ||
videodb_tool = VideoDBTool(collection_id=collection_id) | ||
brandkit_stream = videodb_tool.add_brandkit( | ||
video_id, intro_video_id, outro_video_id, brand_image_id | ||
) | ||
video_content.video = {"stream_url": brandkit_stream} | ||
video_content.status = MsgStatus.success | ||
video_content.status_message = "Here is your brandkit stream" | ||
self.output_message.publish() | ||
except Exception: | ||
logger.exception(f"Error in {self.agent_name}") | ||
video_content.status = MsgStatus.error | ||
error_message = "Error in adding branding." | ||
video_content.status_message = error_message | ||
self.output_message.publish() | ||
return AgentResponse(result=AgentResult.ERROR, message=error_message) | ||
return AgentResponse( | ||
result=AgentResult.SUCCESS, | ||
message=f"Agent {self.name} completed successfully.", | ||
data={"stream_url": brandkit_stream}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import logging | ||
|
||
from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult | ||
from spielberg.core.session import Session, MsgStatus, ImageContent | ||
from spielberg.tools.replicate import flux_dev | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ImageGenerationAgent(BaseAgent): | ||
def __init__(self, session: Session, **kwargs): | ||
self.agent_name = "image_generation" | ||
self.description = "Agent for image generation using Gen AI models on given prompt and configurations." | ||
self.parameters = self.get_parameters() | ||
super().__init__(session=session, **kwargs) | ||
|
||
def __call__(self, prompt: str, *args, **kwargs) -> AgentResponse: | ||
""" | ||
Process the prompt to generate the image. | ||
:param str prompt: prompt for image generation. | ||
:param args: Additional positional arguments. | ||
:param kwargs: Additional keyword arguments. | ||
:return: The response containing information about generated image. | ||
:rtype: AgentResponse | ||
""" | ||
try: | ||
# TODO: Integrate other models and drive parameters from input as well | ||
self.output_message.actions.append("Processing prompt..") | ||
image_content = ImageContent( | ||
agent_name=self.agent_name, status=MsgStatus.progress | ||
) | ||
image_content.status_message = "Generating image.." | ||
self.output_message.content.append(image_content) | ||
self.output_message.push_update() | ||
flux_output = flux_dev(prompt) | ||
if not flux_output: | ||
image_content.status = MsgStatus.error | ||
image_content.status_message = "Error in generating image." | ||
self.output_message.publish() | ||
error_message = "Agent failed with error in replicate." | ||
return AgentResponse(result=AgentResult.ERROR, message=error_message) | ||
image_url = flux_output[0].url | ||
image_content.image = {"url": image_url} | ||
image_content.status = MsgStatus.success | ||
image_content.status_message = "Here is your generated image" | ||
self.output_message.publish() | ||
except Exception as e: | ||
logger.exception(f"Error in {self.agent_name}") | ||
image_content.status = MsgStatus.error | ||
image_content.status_message = "Error in generating image." | ||
self.output_message.publish() | ||
error_message = f"Agent failed with error {e}" | ||
return AgentResponse(result=AgentResult.ERROR, message=error_message) | ||
return AgentResponse( | ||
result=AgentResult.SUCCESS, | ||
message=f"Agent {self.name} completed successfully.", | ||
data={"image_url": image_url}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import json | ||
import logging | ||
import os | ||
|
||
from videodb.asset import VideoAsset, AudioAsset | ||
|
||
from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult | ||
from spielberg.core.session import ( | ||
Session, | ||
MsgStatus, | ||
VideoContent, | ||
ContextMessage, | ||
RoleTypes, | ||
) | ||
from spielberg.llm.openai import OpenAI | ||
from spielberg.tools.videodb_tool import VideoDBTool | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
BEEP_AUDIO_ID = os.getenv("BEEP_AUDIO_ID") | ||
PROFANITY_FINDER_PROMPT = """ | ||
Given the following transcript give the list of timestamps where profanity is there for censoring. | ||
Expected output format is json like {"timestamps": [(start, end), (start, end)]} where start and end are integer in seconds | ||
""" | ||
|
||
|
||
class ProfanityRemoverAgent(BaseAgent): | ||
def __init__(self, session: Session, **kwargs): | ||
self.agent_name = "profanity_remover" | ||
self.description = ( | ||
"Agent to beep the profanities in the given video and return the clean stream." | ||
"if user has not given those optional param of beep_audio_id always try with sending it as None so that defaults are picked from env" | ||
) | ||
self.parameters = self.get_parameters() | ||
self.llm = OpenAI() | ||
super().__init__(session=session, **kwargs) | ||
|
||
def add_beep(self, videodb_tool, video_id, beep_audio_id, timestamps): | ||
timeline = videodb_tool.get_and_set_timeline() | ||
video_asset = VideoAsset(asset_id=video_id) | ||
timeline.add_inline(video_asset) | ||
for start, _ in timestamps: | ||
beep = AudioAsset(asset_id=beep_audio_id) | ||
timeline.add_overlay(start=start, asset=beep) | ||
stream_url = timeline.generate_stream() | ||
return stream_url | ||
|
||
def __call__( | ||
self, | ||
collection_id: str, | ||
video_id: str, | ||
beep_audio_id: str = None, | ||
*args, | ||
**kwargs, | ||
) -> AgentResponse: | ||
""" | ||
Process the video to remove the profanities by overlaying beep. | ||
:param str collection_id: collection id in which the source video is present. | ||
:param str video_id: video_id on which profanity remover needs to run. | ||
:param str beep_audio_id: audio id of beep asset in videodb, defaults to BEEP_AUDIO_ID | ||
:param args: Additional positional arguments. | ||
:param kwargs: Additional keyword arguments. | ||
:return: The response containing information about the sample processing operation. | ||
:rtype: AgentResponse | ||
""" | ||
try: | ||
beep_audio_id = beep_audio_id or BEEP_AUDIO_ID | ||
if not beep_audio_id: | ||
return AgentResponse( | ||
result=AgentResult.failed, | ||
message="Please provide the beep_audio_id or setup BEEP_AUDIO_ID in .env of backend directory.", | ||
) | ||
self.output_message.actions.append("Started process to remove profanity..") | ||
video_content = VideoContent( | ||
agent_name=self.agent_name, status=MsgStatus.progress | ||
) | ||
video_content.status_message = "Generating clean stream.." | ||
self.output_message.push_update() | ||
videodb_tool = VideoDBTool(collection_id=collection_id) | ||
try: | ||
transcript = videodb_tool.get_transcript(video_id, text=False) | ||
except Exception: | ||
logger.error("Failed to get transcript, indexing") | ||
self.output_message.actions.append("Indexing the video..") | ||
self.output_message.push_update() | ||
videodb_tool.index_spoken_words(video_id) | ||
transcript = videodb_tool.get_transcript(video_id, text=False) | ||
profanity_prompt = f"{PROFANITY_FINDER_PROMPT}\n\ntranscript: {transcript}" | ||
profanity_llm_message = ContextMessage( | ||
content=profanity_prompt, | ||
role=RoleTypes.user, | ||
) | ||
llm_response = self.llm.chat_completions( | ||
[profanity_llm_message.to_llm_msg()], | ||
response_format={"type": "json_object"}, | ||
) | ||
profanity_timeline_response = json.loads(llm_response.content) | ||
profanity_timeline = profanity_timeline_response.get("timestamps") | ||
clean_stream = self.add_beep( | ||
videodb_tool, video_id, beep_audio_id, profanity_timeline | ||
) | ||
video_content.video = {"stream_url": clean_stream} | ||
video_content.status = MsgStatus.success | ||
video_content.status_message = "Here is the clean stream" | ||
self.output_message.content.append(video_content) | ||
self.output_message.publish() | ||
except Exception as e: | ||
logger.exception(f"Error in {self.agent_name}") | ||
video_content.status = MsgStatus.error | ||
video_content.status_message = "Failed to generate clean stream" | ||
self.output_message.publish() | ||
error_message = f"Error in generating the clean stream due to {e}." | ||
return AgentResponse(result=AgentResult.ERROR, message=error_message) | ||
return AgentResponse( | ||
result=AgentResult.SUCCESS, | ||
message=f"Agent {self.name} completed successfully.", | ||
data={"stream_url": clean_stream}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.