Skip to content

Commit

Permalink
Merge pull request #7 from video-db/ankit/refactor
Browse files Browse the repository at this point in the history
Ankit/refactor
  • Loading branch information
ashish-spext authored Oct 23, 2024
2 parents d977dc1 + f93fec7 commit bf56081
Show file tree
Hide file tree
Showing 23 changed files with 213 additions and 127 deletions.
17 changes: 12 additions & 5 deletions backend/.env.sample
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# VideoDB Integration
VIDEO_DB_API_KEY=

# OpenAI config
# LLM Integrations
OPENAI_API_KEY=

# VideoDB config
VIDEO_DB_API_KEY=
# only for dev
VIDEO_DB_BASE_URL=
# Tools
REPLICATE_API_TOKEN=

# Brandkit Agent
INTRO_VIDEO_ID=
OUTRO_VIDEO_ID=
BRAND_IMAGE_ID=

# Profanity Remover Agent
BEEP_AUDIO_ID=
12 changes: 6 additions & 6 deletions backend/spielberg/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
logger = logging.getLogger(__name__)


class AgentResult:
class AgentStatus:
SUCCESS = "success"
ERROR = "error"


class AgentResponse(BaseModel):
result: str = AgentResult.SUCCESS
status: str = AgentStatus.SUCCESS
message: str = ""
data: dict = {}

Expand All @@ -30,7 +30,7 @@ def __init__(self, session: Session, **kwargs):

def get_parameters(self):
function_inferrer = FunctionInferrer.infer_from_function_reference(
self.__call__
self.run
)
function_json = function_inferrer.to_json_schema()
parameters = function_json.get("parameters")
Expand Down Expand Up @@ -58,12 +58,12 @@ def agent_description(self):

def safe_call(self, *args, **kwargs):
try:
return self.__call__(*args, **kwargs)
return self.run(*args, **kwargs)

except Exception as e:
logger.exception(f"error in {self.agent_name} agent: {e}")
return AgentResponse(result=AgentResult.ERROR, message=str(e))
return AgentResponse(status=AgentStatus.ERROR, message=str(e))

@abstractmethod
def __call__(*args, **kwargs) -> AgentResponse:
def run(*args, **kwargs) -> AgentResponse:
pass
14 changes: 7 additions & 7 deletions backend/spielberg/agents/brandkit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.core.session import Session, MsgStatus, VideoContent
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import Session, MsgStatus, VideoContent, VideoData
from spielberg.tools.videodb_tool import VideoDBTool

logger = logging.getLogger(__name__)
Expand All @@ -22,7 +22,7 @@ def __init__(self, session: Session, **kwargs):
self.parameters = self.get_parameters()
super().__init__(session=session, **kwargs)

def __call__(
def run(
self,
collection_id: str,
video_id: str,
Expand Down Expand Up @@ -55,7 +55,7 @@ def __call__(
"Branding elementes not provided, either you can provide provide IDs for intro video, outro video and branding image"
" or you can set INTRO_VIDEO_ID, OUTRO_VIDEO_ID and BRAND_IMAGE_ID in .env of backend directory."
)
return AgentResponse(result=AgentResult.ERROR, message=message)
return AgentResponse(status=AgentStatus.ERROR, message=message)
video_content = VideoContent(
agent_name=self.agent_name,
status=MsgStatus.progress,
Expand All @@ -67,7 +67,7 @@ def __call__(
brandkit_stream = videodb_tool.add_brandkit(
video_id, intro_video_id, outro_video_id, brand_image_id
)
video_content.video = {"stream_url": brandkit_stream}
video_content.video = VideoData(stream_url=brandkit_stream)
video_content.status = MsgStatus.success
video_content.status_message = "Here is your brandkit stream"
self.output_message.publish()
Expand All @@ -77,9 +77,9 @@ def __call__(
error_message = "Error in adding branding."
video_content.status_message = error_message
self.output_message.publish()
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(status=AgentStatus.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"stream_url": brandkit_stream},
)
8 changes: 4 additions & 4 deletions backend/spielberg/agents/download.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import Session
from spielberg.tools.videodb_tool import VideoDBTool

Expand All @@ -14,7 +14,7 @@ def __init__(self, session: Session, **kwargs):
self.parameters = self.get_parameters()
super().__init__(session=session, **kwargs)

def __call__(
def run(
self, stream_link: str, name: str = None, *args, **kwargs
) -> AgentResponse:
"""
Expand All @@ -34,9 +34,9 @@ def __call__(
download_response = videodb_tool.download(stream_link, name)
except Exception as e:
logger.exception(f"error in {self.agent_name} agent: {e}")
return AgentResponse(result=AgentResult.ERROR, message=str(e))
return AgentResponse(status=AgentStatus.ERROR, message=str(e))
return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message="Download successful",
data=download_response,
)
16 changes: 8 additions & 8 deletions backend/spielberg/agents/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.core.session import Session, MsgStatus, ImageContent
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import Session, MsgStatus, ImageContent, ImageData
from spielberg.tools.replicate import flux_dev

logger = logging.getLogger(__name__)
Expand All @@ -14,7 +14,7 @@ def __init__(self, session: Session, **kwargs):
self.parameters = self.get_parameters()
super().__init__(session=session, **kwargs)

def __call__(self, prompt: str, *args, **kwargs) -> AgentResponse:
def run(self, prompt: str, *args, **kwargs) -> AgentResponse:
"""
Process the prompt to generate the image.
Expand All @@ -39,9 +39,9 @@ def __call__(self, prompt: str, *args, **kwargs) -> AgentResponse:
image_content.status_message = "Error in generating image."
self.output_message.publish()
error_message = "Agent failed with error in replicate."
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(status=AgentStatus.ERROR, message=error_message)
image_url = flux_output[0].url
image_content.image = {"url": image_url}
image_content.image = ImageData(url=image_url)
image_content.status = MsgStatus.success
image_content.status_message = "Here is your generated image"
self.output_message.publish()
Expand All @@ -51,9 +51,9 @@ def __call__(self, prompt: str, *args, **kwargs) -> AgentResponse:
image_content.status_message = "Error in generating image."
self.output_message.publish()
error_message = f"Agent failed with error {e}"
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(status=AgentStatus.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"image_url": image_url},
data={},
)
8 changes: 4 additions & 4 deletions backend/spielberg/agents/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus

from spielberg.core.session import Session
from spielberg.tools.videodb_tool import VideoDBTool
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(self, session: Session, **kwargs):
self.parameters = INDEX_AGENT_PARAMETERS
super().__init__(session=session, **kwargs)

def __call__(
def run(
self, video_id: str, index_type: str, collection_id=None, *args, **kwargs
) -> AgentResponse:
"""
Expand Down Expand Up @@ -65,10 +65,10 @@ def __call__(

except Exception as e:
logger.exception(f"error in {self.agent_name} agent: {e}")
return AgentResponse(result=AgentResult.ERROR, message=str(e))
return AgentResponse(status=AgentStatus.ERROR, message=str(e))

return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message=f"{index_type} indexing successful",
data=scene_data,
)
10 changes: 5 additions & 5 deletions backend/spielberg/agents/pricing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import (
Session,
MsgStatus,
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, session: Session, **kwargs):
self.llm = OpenAI()
super().__init__(session=session, **kwargs)

def __call__(self, query: str, *args, **kwargs) -> AgentResponse:
def run(self, query: str, *args, **kwargs) -> AgentResponse:
"""
Get the answer to the query of agent
Expand Down Expand Up @@ -125,7 +125,7 @@ def __call__(self, query: str, *args, **kwargs) -> AgentResponse:
text_content.status_message = "Failed to generate the response."
self.output_message.publish()
return AgentResponse(
result=AgentResult.ERROR,
status=AgentStatus.ERROR,
message="Pricing failed due to LLM error.",
)
text_content.text = llm_response.content
Expand All @@ -138,10 +138,10 @@ def __call__(self, query: str, *args, **kwargs) -> AgentResponse:
text_content.status = MsgStatus.error
text_content.status_message = error_message
self.output_message.publish()
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(status=AgentStatus.ERROR, message=error_message)

return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message="Fetch successful and output displayed above.",
data={},
)
13 changes: 7 additions & 6 deletions backend/spielberg/agents/profanity_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from videodb.asset import VideoAsset, AudioAsset

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import (
Session,
MsgStatus,
VideoContent,
VideoData,
ContextMessage,
RoleTypes,
)
Expand Down Expand Up @@ -45,7 +46,7 @@ def add_beep(self, videodb_tool, video_id, beep_audio_id, timestamps):
stream_url = timeline.generate_stream()
return stream_url

def __call__(
def run(
self,
collection_id: str,
video_id: str,
Expand All @@ -68,7 +69,7 @@ def __call__(
beep_audio_id = beep_audio_id or BEEP_AUDIO_ID
if not beep_audio_id:
return AgentResponse(
result=AgentResult.failed,
status=AgentStatus.failed,
message="Please provide the beep_audio_id or setup BEEP_AUDIO_ID in .env of backend directory.",
)
self.output_message.actions.append("Started process to remove profanity..")
Expand Down Expand Up @@ -100,7 +101,7 @@ def __call__(
clean_stream = self.add_beep(
videodb_tool, video_id, beep_audio_id, profanity_timeline
)
video_content.video = {"stream_url": clean_stream}
video_content.video = VideoData(stream_url=clean_stream)
video_content.status = MsgStatus.success
video_content.status_message = "Here is the clean stream"
self.output_message.content.append(video_content)
Expand All @@ -111,9 +112,9 @@ def __call__(
video_content.status_message = "Failed to generate clean stream"
self.output_message.publish()
error_message = f"Error in generating the clean stream due to {e}."
return AgentResponse(result=AgentResult.ERROR, message=error_message)
return AgentResponse(status=AgentStatus.ERROR, message=error_message)
return AgentResponse(
result=AgentResult.SUCCESS,
status=AgentStatus.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"stream_url": clean_stream},
)
27 changes: 13 additions & 14 deletions backend/spielberg/agents/prompt_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import json
import concurrent.futures

from spielberg.agents.base import BaseAgent, AgentResponse, AgentResult
from spielberg.agents.base import BaseAgent, AgentResponse, AgentStatus
from spielberg.core.session import (
Session,
ContextMessage,
RoleTypes,
MsgStatus,
VideoContent,
VideoData,
)
from spielberg.tools.videodb_tool import VideoDBTool
from spielberg.llm.openai import OpenAI
Expand Down Expand Up @@ -120,7 +121,7 @@ def _text_prompter(self, transcript_text, prompt):
continue
return matches

def __call__(
def run(
self, prompt: str, video_id: str, collection_id: str, *args, **kwargs
) -> AgentResponse:
try:
Expand Down Expand Up @@ -178,27 +179,25 @@ def __call__(
video_id=video_id, timeline=timeline
)
video_content.status_message = "Clip generated successfully."
video_content.video = {
"stream_url": stream_url,
}
video_content.video = VideoData(stream_url=stream_url)
video_content.status = MsgStatus.success
self.output_message.publish()

except Exception as e:
logger.exception(f"Error in creating video content: {e}")
return AgentResponse(result=AgentResult.ERROR, message=str(e))
return AgentResponse(status=AgentStatus.ERROR, message=str(e))

return AgentResponse(
status=AgentStatus.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={"stream_url": stream_url},
)
else:
return AgentResponse(
result=AgentResult.ERROR,
status=AgentStatus.ERROR,
message="No relevant moments found.",
)

except Exception as e:
logger.exception(f"error in {self.agent_name}")
return AgentResponse(result=AgentResult.ERROR, message=str(e))

return AgentResponse(
result=AgentResult.SUCCESS,
message=f"Agent {self.name} completed successfully.",
data={},
)
return AgentResponse(status=AgentStatus.ERROR, message=str(e))
Loading

0 comments on commit bf56081

Please sign in to comment.