diff --git a/backend/.env.sample b/backend/.env.sample
index dd9fba80..b5629614 100644
--- a/backend/.env.sample
+++ b/backend/.env.sample
@@ -4,6 +4,7 @@ VIDEO_DB_API_KEY=
 # LLM Integrations
 OPENAI_API_KEY=
 ANTHROPIC_API_KEY=
+XAI_API_KEY=
 
 # Tools
 REPLICATE_API_TOKEN=
diff --git a/backend/director/agents/base.py b/backend/director/agents/base.py
index 3a37a42d..01b2e0f8 100644
--- a/backend/director/agents/base.py
+++ b/backend/director/agents/base.py
@@ -5,7 +5,7 @@
 from openai_function_calling import FunctionInferrer
 
-from director.core.session import Session, OutputMessage
+from director.core.session import Session, OutputMessage, TextContent, MsgStatus
 
 logger = logging.getLogger(__name__)
@@ -41,6 +41,20 @@ def get_parameters(self):
         )
         return parameters
 
+    def _check_supported_llm(self):
+        """Check if supported_llm is present and validate the configured LLM."""
+        if hasattr(self, "supported_llm") and hasattr(self, "llm"):
+            if self.llm.llm_type not in self.supported_llm:
+                error = f"`@{self.agent_name}` agent does not support the configured LLM `{self.llm.llm_type.upper()}`. To use this agent, please configure one of the following LLMs: {[llm.upper() for llm in self.supported_llm]}."
+                self.output_message.content.append(
+                    TextContent(
+                        status_message="LLM not supported.",
+                        text=error,
+                        status=MsgStatus.error,
+                    )
+                )
+                raise Exception(error)
+
     def to_llm_format(self):
         """Convert the agent to LLM tool format."""
         return {
@@ -59,6 +73,7 @@ def agent_description(self):
 
     def safe_call(self, *args, **kwargs):
         try:
+            self._check_supported_llm()
             return self.run(*args, **kwargs)
         except Exception as e:
diff --git a/backend/director/agents/meme_maker.py b/backend/director/agents/meme_maker.py
index a7b8f3e3..80caf748 100644
--- a/backend/director/agents/meme_maker.py
+++ b/backend/director/agents/meme_maker.py
@@ -11,6 +11,7 @@
     VideoContent,
     VideoData,
 )
+from director.constants import LLMType
 from director.tools.videodb_tool import VideoDBTool
 from director.llm import get_default_llm
 
@@ -42,6 +43,8 @@ def __init__(self, session: Session, **kwargs):
         self.description = "Generates meme clips and images based on user prompts. This agent uses an LLM to analyze the transcript and visual content of the video to generate memes."
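+        # The supported_llm allow-list below is enforced by BaseAgent._check_supported_llm inside safe_call, so the agent fails fast on an unsupported LLM.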
         self.parameters = MEMEMAKER_PARAMETERS
         self.llm = get_default_llm()
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)
 
     def _chunk_docs(self, docs, chunk_size):
diff --git a/backend/director/agents/profanity_remover.py b/backend/director/agents/profanity_remover.py
index 37b0922a..ee980cac 100644
--- a/backend/director/agents/profanity_remover.py
+++ b/backend/director/agents/profanity_remover.py
@@ -13,6 +13,7 @@
     ContextMessage,
     RoleTypes,
 )
+from director.constants import LLMType
 from director.llm import get_default_llm
 from director.tools.videodb_tool import VideoDBTool
 
@@ -34,6 +35,7 @@ def __init__(self, session: Session, **kwargs):
         )
         self.parameters = self.get_parameters()
         self.llm = get_default_llm()
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)
 
     def add_beep(self, videodb_tool, video_id, beep_audio_id, timestamps):
diff --git a/backend/director/agents/prompt_clip.py b/backend/director/agents/prompt_clip.py
index 9c9816b9..ee962ff5 100644
--- a/backend/director/agents/prompt_clip.py
+++ b/backend/director/agents/prompt_clip.py
@@ -11,6 +11,7 @@
     VideoContent,
     VideoData,
 )
+from director.constants import LLMType
 from director.tools.videodb_tool import VideoDBTool
 from director.llm import get_default_llm
 
@@ -47,6 +48,7 @@ def __init__(self, session: Session, **kwargs):
         self.description = "Generates video clips based on user prompts. This agent uses AI to analyze the text of a video transcript and identify sentences relevant to the user prompt for making clips. It then generates video clips based on the identified sentences. Use this tool to create clips based on specific themes or topics from a video."
         self.parameters = PROMPTCLIP_AGENT_PARAMETERS
         self.llm = get_default_llm()
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)
 
     def _chunk_docs(self, docs, chunk_size):
diff --git a/backend/director/agents/subtitle.py b/backend/director/agents/subtitle.py
index 7025e81c..4d490692 100644
--- a/backend/director/agents/subtitle.py
+++ b/backend/director/agents/subtitle.py
@@ -11,6 +11,7 @@
     VideoData,
     MsgStatus,
 )
+from director.constants import LLMType
 from director.tools.videodb_tool import VideoDBTool
 from director.llm import get_default_llm
 
@@ -111,6 +112,7 @@ def __init__(self, session: Session, **kwargs):
         self.description = "An agent designed to add subtitles in different languages to a specified video within VideoDB."
         self.llm = get_default_llm()
         self.parameters = SUBTITLE_AGENT_PARAMETERS
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)
 
     def wrap_text(self, text, video_width, max_width_percent=0.60, avg_char_width=20):
diff --git a/backend/director/constants.py b/backend/director/constants.py
index f370aac9..6b3f2c39 100644
--- a/backend/director/constants.py
+++ b/backend/director/constants.py
@@ -19,6 +19,7 @@ class LLMType(str, Enum):
 
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
+    XAI = "xai"
     VIDEODB_PROXY = "videodb_proxy"
 
 
@@ -27,5 +28,7 @@ class EnvPrefix(str, Enum):
 
     OPENAI_ = "OPENAI_"
     ANTHROPIC_ = "ANTHROPIC_"
+    XAI_ = "XAI_"
 
-DOWNLOADS_PATH="director/downloads"
+
+DOWNLOADS_PATH = "director/downloads"
diff --git a/backend/director/handler.py b/backend/director/handler.py
index 5a65de71..b3346284 100644
--- a/backend/director/handler.py
+++ b/backend/director/handler.py
@@ -20,7 +20,7 @@
 from director.agents.composio import ComposioAgent
 
-from director.core.session import Session, InputMessage, MsgStatus
+from director.core.session import Session, InputMessage, MsgStatus, TextContent
 from director.core.reasoning import ReasoningEngine
 from director.db.base import BaseDB
 from director.db import load_db
@@ -102,6 +102,9 @@ def chat(self, message):
             res_eng.run()
 
         except Exception as e:
+            session.output_message.content.append(
+                TextContent(text=f"{e}", status=MsgStatus.error)
+            )
             session.output_message.update_status(MsgStatus.error)
             logger.exception(f"Error in chat handler: {e}")
diff --git a/backend/director/llm/__init__.py b/backend/director/llm/__init__.py
index b909bb3b..658abe49 100644
--- a/backend/director/llm/__init__.py
+++ b/backend/director/llm/__init__.py
@@ -4,6 +4,7 @@
 from director.llm.openai import OpenAI
 from director.llm.anthropic import AnthropicAI
+from director.llm.xai import XAI
 from director.llm.videodb_proxy import VideoDBProxy
 
 
@@ -12,6 +13,7 @@ def get_default_llm():
 
     openai = True if os.getenv("OPENAI_API_KEY") else False
     anthropic = True if os.getenv("ANTHROPIC_API_KEY") else False
+    xai = True if os.getenv("XAI_API_KEY") else False
 
     default_llm = os.getenv("DEFAULT_LLM")
 
@@ -19,5 +21,7 @@ def get_default_llm():
         return OpenAI()
     elif anthropic or default_llm == LLMType.ANTHROPIC:
         return AnthropicAI()
+    elif xai or default_llm == LLMType.XAI:
+        return XAI()
     else:
         return VideoDBProxy()
diff --git a/backend/director/llm/openai.py b/backend/director/llm/openai.py
index 3090ff2b..8d21475b 100644
--- a/backend/director/llm/openai.py
+++ b/backend/director/llm/openai.py
@@ -156,7 +156,7 @@ def chat_completions(
             params["tools"] = self._format_tools(tools)
             params["tool_choice"] = "auto"
 
-        if response_format:
+        if response_format and self.config.api_base == "https://api.openai.com/v1":
             params["response_format"] = response_format
 
         try:
diff --git a/backend/director/llm/xai.py b/backend/director/llm/xai.py
new file mode 100644
index 00000000..60ca96fc
--- /dev/null
+++ b/backend/director/llm/xai.py
@@ -0,0 +1,184 @@
+import json
+from enum import Enum
+
+from pydantic import Field, field_validator, FieldValidationInfo
+from pydantic_settings import SettingsConfigDict
+
+
+from director.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus
+from director.constants import (
+    LLMType,
+    EnvPrefix,
+)
+
+
+class XAIModel(str, Enum):
+    """Enum for XAI Chat models"""
+
+    GROK_BETA = "grok-beta"
+
+
+class XAIConfig(BaseLLMConfig):
+    """XAI Config"""
+
+    model_config = SettingsConfigDict(
+        env_prefix=EnvPrefix.XAI_,
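+        # pydantic-settings maps XAI_-prefixed environment variables onto these fields (e.g. XAI_API_KEY -> api_key).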
+        extra="ignore",
+    )
+
+    llm_type: str = LLMType.XAI
+    api_key: str = ""
+    api_base: str = "https://api.x.ai/v1"
+    chat_model: str = Field(default=XAIModel.GROK_BETA)
+    max_tokens: int = 4096
+
+    @field_validator("api_key")
+    @classmethod
+    def validate_non_empty(cls, v, info: FieldValidationInfo):
+        if not v:
+            raise ValueError(
+                f"{info.field_name} must not be empty. Please set the {EnvPrefix.XAI_.value}{info.field_name.upper()} environment variable."
+            )
+        return v
+
+
+class XAI(BaseLLM):
+    def __init__(self, config: XAIConfig = None):
+        """
+        :param config: XAI Config
+        """
+        if config is None:
+            config = XAIConfig()
+        super().__init__(config=config)
+        try:
+            import openai
+        except ImportError:
+            raise ImportError("Please install the OpenAI Python library.")
+
+        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)
+
+    def init_langfuse(self):
+        from langfuse.decorators import observe
+
+        self.chat_completions = observe(name=type(self).__name__)(self.chat_completions)
+        self.text_completions = observe(name=type(self).__name__)(self.text_completions)
+
+    def _format_messages(self, messages: list):
+        """Format the messages to the format that OpenAI expects."""
+        formatted_messages = []
+        for message in messages:
+            if message["role"] == "assistant" and message.get("tool_calls"):
+                formatted_messages.append(
+                    {
+                        "role": message["role"],
+                        "content": message["content"],
+                        "tool_calls": [
+                            {
+                                "id": tool_call["id"],
+                                "function": {
+                                    "name": tool_call["tool"]["name"],
+                                    "arguments": json.dumps(
+                                        tool_call["tool"]["arguments"]
+                                    ),
+                                },
+                                "type": tool_call["type"],
+                            }
+                            for tool_call in message["tool_calls"]
+                        ],
+                    }
+                )
+            else:
+                formatted_messages.append(message)
+        return formatted_messages
+
+    def _format_tools(self, tools: list):
+        """Format the tools to the format that OpenAI expects.
+
+        **Example**::
+
+            [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "get_delivery_date",
+                        "description": "Get the delivery date for a customer's order.",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {
+                                "order_id": {
+                                    "type": "string",
+                                    "description": "The customer's order ID."
+                                }
+                            },
+                            "required": ["order_id"],
+                            "additionalProperties": False
+                        }
+                    }
+                }
+            ]
+        """
+        formatted_tools = []
+        for tool in tools:
+            formatted_tools.append(
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool["name"],
+                        "description": tool["description"],
+                        "parameters": tool["parameters"],
+                    },
+                    "strict": True,
+                }
+            )
+        return formatted_tools
+
+    def chat_completions(
+        self, messages: list, tools: list = [], stop=None, response_format=None
+    ):
+        """Get completions for chat.
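+
+        xAI's API is OpenAI-compatible, so these parameters map directly onto the OpenAI client's chat.completions.create call.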
+
+        docs: https://docs.x.ai/docs/guides/function-calling
+        """
+        params = {
+            "model": self.chat_model,
+            "messages": self._format_messages(messages),
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "stop": stop,
+            "timeout": self.timeout,
+        }
+        if tools:
+            params["tools"] = self._format_tools(tools)
+            params["tool_choice"] = "auto"
+
+        try:
+            response = self.client.chat.completions.create(**params)
+        except Exception as e:
+            print(f"Error: {e}")
+            return LLMResponse(content=f"Error: {e}")
+
+        return LLMResponse(
+            content=response.choices[0].message.content or "",
+            tool_calls=[
+                {
+                    "id": tool_call.id,
+                    "tool": {
+                        "name": tool_call.function.name,
+                        "arguments": json.loads(tool_call.function.arguments),
+                    },
+                    "type": tool_call.type,
+                }
+                for tool_call in response.choices[0].message.tool_calls
+            ]
+            if response.choices[0].message.tool_calls
+            else [],
+            finish_reason=response.choices[0].finish_reason,
+            send_tokens=response.usage.prompt_tokens,
+            recv_tokens=response.usage.completion_tokens,
+            total_tokens=response.usage.total_tokens,
+            status=LLMResponseStatus.SUCCESS,
+        )
diff --git a/docs/llm/xai.md b/docs/llm/xai.md
new file mode 100644
index 00000000..29b7da3c
--- /dev/null
+++ b/docs/llm/xai.md
@@ -0,0 +1,15 @@
+## XAI
+
+XAI extends the base LLM and implements the xAI API.
+
+### XAI Config
+
+XAI Config is the configuration object for XAI. It is used to configure XAI and is passed to XAI when it is created.
+
+::: director.llm.xai.XAIConfig
+
+### XAI Interface
+
+XAI is the LLM used by the agents and tools. It is used to generate responses to messages.
+
+::: director.llm.xai.XAI
diff --git a/mkdocs.yml b/mkdocs.yml
index 196ad5e9..e8d0ab77 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -74,6 +74,7 @@ nav:
     - Integrations:
       - 'OpenAI': 'llm/openai.md'
       - 'AnthropicAI': 'llm/anthropic.md'
+      - 'XAI': 'llm/xai.md'
   - 'Database':
     - 'Interface': 'database/interface.md'
   - Integrations:
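
A minimal usage sketch of the new `XAI` class (assumes `XAI_API_KEY` is exported; the prompt text is illustrative):

```python
from director.llm.xai import XAI

# XAIConfig is populated from the environment; its api_key validator
# raises a ValueError if XAI_API_KEY is not set.
llm = XAI()

# Plain chat call -- grok-beta is the default chat model.
response = llm.chat_completions(
    messages=[{"role": "user", "content": "Say hello in one sentence."}]
)
print(response.content)
```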