Skip to content

Commit

Permalink
Merge pull request #1 from video-db/ashish/agents
Browse files Browse the repository at this point in the history
Ashish/agents
  • Loading branch information
ankit-v2-3 authored Oct 16, 2024
2 parents 6eb2746 + b7a582c commit 6e6ba48
Show file tree
Hide file tree
Showing 17 changed files with 436 additions and 57 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ venv
.ruff_cache
.pytest_cache
*.egg-info
package-lock.json
*.mjs
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Spielberg is an open source platform for creating agents which interact with you

### Prerequisites

- Python 3.6 or higher
- Python 3.9 or higher
- Node.js 22.8.0 or higher
- npm

Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ openai-function-calling==2.2.0
pydantic==2.8.2
pydantic-settings==2.4.0
python-dotenv==1.0.1
replicate==1.0.1
git+https://github.com/video-db/videodb-python@ankit/add-download
85 changes: 85 additions & 0 deletions backend/spielberg/agents/brandkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.core.session import Session, MsgStatus, VideoContent
from spielberg.tools.videodb_tool import VideoDBTool

logger = logging.getLogger(__name__)

INTRO_VIDEO_ID = os.getenv("INTRO_VIDEO_ID")
OUTRO_VIDEO_ID = os.getenv("OUTRO_VIDEO_ID")
BRAND_IMAGE_ID = os.getenv("BRAND_IMAGE_ID")


class BrandkitAgent(BaseAgent):
def __init__(self, session: Session, **kwargs):
self.agent_name = "brandkit"
self.description = (
"Agent to add brand kit elements (intro video, outro video and brand image) to the given video in VideoDB,"
"if user has not given those optional param of intro video, outro video and brand image always try with sending them as None so that defaults are picked from env"
)
self.parameters = self.get_parameters()
super().__init__(session=session, **kwargs)

def __call__(
self,
collection_id: str,
video_id: str,
intro_video_id: str = None,
outro_video_id: str = None,
brand_image_id: str = None,
*args,
**kwargs,
) -> AgentResponse:
"""
Generate stream of video after adding branding elements.
:param str collection_id: collection id in which videos are available.
:param str video_id: video id on which branding is required.
:param str intro_video_id: VideoDB video id of intro video, defaults to INTRO_VIDEO_ID
:param str outro_video_id: video id of outro video, defaults to OUTRO_VIDEO_ID
:param str brand_image_id: image id of brand image for overlay over video, defaults to BRAND_IMAGE_ID
:param args: Additional positional arguments.
:param kwargs: Additional keyword arguments.
:return: The response containing information about the generated brand stream.
:rtype: AgentResponse
"""
try:
self.output_message.actions.append("Processing brandkit request..")
intro_video_id = intro_video_id or INTRO_VIDEO_ID
outro_video_id = outro_video_id or OUTRO_VIDEO_ID
brand_image_id = brand_image_id or BRAND_IMAGE_ID
if not any([intro_video_id, outro_video_id, brand_image_id]):
message = (
"Branding elementes not provided, either you can provide provide IDs for intro video, outro video and branding image"
" or you can set INTRO_VIDEO_ID, OUTRO_VIDEO_ID and BRAND_IMAGE_ID in .env of backend directory."
)
return AgentResponse(result=AgentResult.ERROR, message=message)
video_content = VideoContent(
agent_name=self.agent_name,
status=MsgStatus.progress,
status_message="Generating video with branding..",
)
self.output_message.content.append(video_content)
self.output_message.push_update()
videodb_tool = VideoDBTool(collection_id=collection_id)
brandkit_stream = videodb_tool.add_brandkit(
video_id, intro_video_id, outro_video_id, brand_image_id
)
video_content.video = {"stream_url": brandkit_stream}
video_content.status = MsgStatus.success
video_content.status_message = "Here is your brandkit stream"
self.output_message.publish()
except Exception:
logger.exception(f"Error in {self.agent_name}")
video_content.status = MsgStatus.error
error_message = "Error in adding branding."
video_content.status_message = error_message
self.output_message.publish()
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"stream_url": brandkit_stream},
)
59 changes: 59 additions & 0 deletions backend/spielberg/agents/image_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.core.session import Session, MsgStatus, ImageContent
from spielberg.tools.replicate import flux_dev

logger = logging.getLogger(__name__)


class ImageGenerationAgent(BaseAgent):
def __init__(self, session: Session, **kwargs):
self.agent_name = "image_generation"
self.description = "Agent for image generation using Gen AI models on given prompt and configurations."
self.parameters = self.get_parameters()
super().__init__(session=session, **kwargs)

def __call__(self, prompt: str, *args, **kwargs) -> AgentResponse:
"""
Process the prompt to generate the image.
:param str prompt: prompt for image generation.
:param args: Additional positional arguments.
:param kwargs: Additional keyword arguments.
:return: The response containing information about generated image.
:rtype: AgentResponse
"""
try:
# TODO: Integrate other models and drive parameters from input as well
self.output_message.actions.append("Processing prompt..")
image_content = ImageContent(
agent_name=self.agent_name, status=MsgStatus.progress
)
image_content.status_message = "Generating image.."
self.output_message.content.append(image_content)
self.output_message.push_update()
flux_output = flux_dev(prompt)
if not flux_output:
image_content.status = MsgStatus.error
image_content.status_message = "Error in generating image."
self.output_message.publish()
error_message = "Agent failed with error in replicate."
return AgentResponse(result=AgentResult.ERROR, message=error_message)
image_url = flux_output[0].url
image_content.image = {"url": image_url}
image_content.status = MsgStatus.success
image_content.status_message = "Here is your generated image"
self.output_message.publish()
except Exception as e:
logger.exception(f"Error in {self.agent_name}")
image_content.status = MsgStatus.error
image_content.status_message = "Error in generating image."
self.output_message.publish()
error_message = f"Agent failed with error {e}"
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"image_url": image_url},
)
13 changes: 7 additions & 6 deletions backend/spielberg/agents/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
RoleTypes,
TextContent,
)
from spielberg.llm.openai import OpenaiConfig, OpenAI
from spielberg.llm.openai import OpenAI

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,16 +108,16 @@ def __call__(self, query: str, *args, **kwargs) -> AgentResponse:
:rtype: AgentResponse
"""
try:
self.output_message.actions.append("Calculating pricing")
text_content = TextContent(agent_name=self.agent_name)
text_content.status_message = "Calculating pricing.."
self.output_message.content.append(text_content)
self.output_message.push_update()

pricing_llm_message = f"{PRICING_AGENT_PROMPT} user query: {query}"
pricing_llm_context = ContextMessage(
content=pricing_llm_message, role=RoleTypes.user
pricing_llm_prompt = f"{PRICING_AGENT_PROMPT} user query: {query}"
pricing_llm_message = ContextMessage(
content=pricing_llm_prompt, role=RoleTypes.user
)
llm_response = self.llm.chat_completions([pricing_llm_context.to_llm_msg()])
llm_response = self.llm.chat_completions([pricing_llm_message.to_llm_msg()])

if not llm_response.status:
logger.error(f"LLM failed with {llm_response}")
Expand All @@ -130,6 +130,7 @@ def __call__(self, query: str, *args, **kwargs) -> AgentResponse:
)
text_content.text = llm_response.content
text_content.status = MsgStatus.success
text_content.status_message = "Pricing estimation is ready."
self.output_message.publish()
except Exception:
logger.exception(f"Error in {self.agent_name}")
Expand Down
119 changes: 119 additions & 0 deletions backend/spielberg/agents/profanity_remover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import json
import logging
import os

from videodb.asset import VideoAsset, AudioAsset

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.core.session import (
Session,
MsgStatus,
VideoContent,
ContextMessage,
RoleTypes,
)
from spielberg.llm.openai import OpenAI
from spielberg.tools.videodb_tool import VideoDBTool

logger = logging.getLogger(__name__)

BEEP_AUDIO_ID = os.getenv("BEEP_AUDIO_ID")
PROFANITY_FINDER_PROMPT = """
Given the following transcript give the list of timestamps where profanity is there for censoring.
Expected output format is json like {"timestamps": [(start, end), (start, end)]} where start and end are integer in seconds
"""


class ProfanityRemoverAgent(BaseAgent):
def __init__(self, session: Session, **kwargs):
self.agent_name = "profanity_remover"
self.description = (
"Agent to beep the profanities in the given video and return the clean stream."
"if user has not given those optional param of beep_audio_id always try with sending it as None so that defaults are picked from env"
)
self.parameters = self.get_parameters()
self.llm = OpenAI()
super().__init__(session=session, **kwargs)

def add_beep(self, videodb_tool, video_id, beep_audio_id, timestamps):
timeline = videodb_tool.get_and_set_timeline()
video_asset = VideoAsset(asset_id=video_id)
timeline.add_inline(video_asset)
for start, _ in timestamps:
beep = AudioAsset(asset_id=beep_audio_id)
timeline.add_overlay(start=start, asset=beep)
stream_url = timeline.generate_stream()
return stream_url

def __call__(
self,
collection_id: str,
video_id: str,
beep_audio_id: str = None,
*args,
**kwargs,
) -> AgentResponse:
"""
Process the video to remove the profanities by overlaying beep.
:param str collection_id: collection id in which the source video is present.
:param str video_id: video_id on which profanity remover needs to run.
:param str beep_audio_id: audio id of beep asset in videodb, defaults to BEEP_AUDIO_ID
:param args: Additional positional arguments.
:param kwargs: Additional keyword arguments.
:return: The response containing information about the sample processing operation.
:rtype: AgentResponse
"""
try:
beep_audio_id = beep_audio_id or BEEP_AUDIO_ID
if not beep_audio_id:
return AgentResponse(
result=AgentResult.failed,
message="Please provide the beep_audio_id or setup BEEP_AUDIO_ID in .env of backend directory.",
)
self.output_message.actions.append("Started process to remove profanity..")
video_content = VideoContent(
agent_name=self.agent_name, status=MsgStatus.progress
)
video_content.status_message = "Generating clean stream.."
self.output_message.push_update()
videodb_tool = VideoDBTool(collection_id=collection_id)
try:
transcript = videodb_tool.get_transcript(video_id, text=False)
except Exception:
logger.error("Failed to get transcript, indexing")
self.output_message.actions.append("Indexing the video..")
self.output_message.push_update()
videodb_tool.index_spoken_words(video_id)
transcript = videodb_tool.get_transcript(video_id, text=False)
profanity_prompt = f"{PROFANITY_FINDER_PROMPT}\n\ntranscript: {transcript}"
profanity_llm_message = ContextMessage(
content=profanity_prompt,
role=RoleTypes.user,
)
llm_response = self.llm.chat_completions(
[profanity_llm_message.to_llm_msg()],
response_format={"type": "json_object"},
)
profanity_timeline_response = json.loads(llm_response.content)
profanity_timeline = profanity_timeline_response.get("timestamps")
clean_stream = self.add_beep(
videodb_tool, video_id, beep_audio_id, profanity_timeline
)
video_content.video = {"stream_url": clean_stream}
video_content.status = MsgStatus.success
video_content.status_message = "Here is the clean stream"
self.output_message.content.append(video_content)
self.output_message.publish()
except Exception as e:
logger.exception(f"Error in {self.agent_name}")
video_content.status = MsgStatus.error
video_content.status_message = "Failed to generate clean stream"
self.output_message.publish()
error_message = f"Error in generating the clean stream due to {e}."
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"stream_url": clean_stream},
)
13 changes: 6 additions & 7 deletions backend/spielberg/agents/sample.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult

from spielberg.core.session import Session, MsgStatus, TextContent

logger = logging.getLogger(__name__)
Expand All @@ -18,8 +17,7 @@ def __call__(self, sample_id: str, *args, **kwargs) -> AgentResponse:
"""
Process the sample based on the given sample ID.
:param sample_id: The ID of the sample to process.
:type sample_id: str
:param str sample_id: The ID of the sample to process.
:param args: Additional positional arguments.
:param kwargs: Additional keyword arguments.
:return: The response containing information about the sample processing operation.
Expand All @@ -34,13 +32,14 @@ def __call__(self, sample_id: str, *args, **kwargs) -> AgentResponse:
self.output_message.push_update()
text_content.text = "This is the text result of Agent."
text_content.status = MsgStatus.success
text_content.status_message = "Here is your response"
self.output_message.publish()
except Exception:
logger.exception(f"error in {self.agent_name}")
except Exception as e:
logger.exception(f"Error in {self.agent_name}")
text_content.status = MsgStatus.error
error_message = "Error in calculating pricing."
text_content.status_message = error_message
text_content.status_message = "Error in calculating pricing."
self.output_message.publish()
error_message = f"Agent failed with error {e}"
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
Expand Down
Loading

0 comments on commit 6e6ba48

Please sign in to comment.