Ankit/add xai #83
base: main
Changes from all commits
3a193d3
1de457a
82cac1d
86983b2
d0da3fa
@@ -13,6 +13,7 @@
     ContextMessage,
     RoleTypes,
 )
+from director.constants import LLMType
 from director.llm import get_default_llm
 from director.tools.videodb_tool import VideoDBTool

@@ -34,6 +35,7 @@ def __init__(self, session: Session, **kwargs):
         )
         self.parameters = self.get_parameters()
         self.llm = get_default_llm()
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)
Comment on lines 37 to 39

🛠️ Refactor suggestion: Add robust LLM validation and error handling

The current implementation could be enhanced with validation that the configured default LLM is actually supported, and with explicit handling of malformed JSON responses.
Consider adding this validation in the constructor:

 def __init__(self, session: Session, **kwargs):
     self.agent_name = "profanity_remover"
     self.description = (
         "Agent to beep the profanities in the given video and return the clean stream."
         "if user has not given those optional param of beep_audio_id always try with sending it as None so that defaults are picked from env"
     )
     self.parameters = self.get_parameters()
+    self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
     self.llm = get_default_llm()
-    self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
+    if not self._check_supported_llm():
+        raise ValueError(f"Default LLM is not supported. Supported LLMs: {self.supported_llm}")
     super().__init__(session=session, **kwargs)

And add explicit error handling for JSON responses:

 llm_response = self.llm.chat_completions(
     [profanity_llm_message.to_llm_msg()],
     response_format={"type": "json_object"},
 )
-profanity_timeline_response = json.loads(llm_response.content)
+try:
+    profanity_timeline_response = json.loads(llm_response.content)
+    if not isinstance(profanity_timeline_response, dict) or "timestamps" not in profanity_timeline_response:
+        raise ValueError("Invalid response format from LLM")
+except json.JSONDecodeError as e:
+    logger.error(f"Failed to parse LLM response: {e}")
+    raise ValueError(f"LLM response is not valid JSON: {llm_response.content[:100]}...")
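The validation above calls a `_check_supported_llm()` helper that is not part of this diff; a minimal sketch of what such a method could look like, assuming the base agent does not already provide one and that the LLM object exposes its `llm_type`:

```python
def _check_supported_llm(self) -> bool:
    """Return True if the configured default LLM is one this agent supports (sketch)."""
    return getattr(self.llm, "llm_type", None) in self.supported_llm
```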
 def add_beep(self, videodb_tool, video_id, beep_audio_id, timestamps):
@@ -11,6 +11,7 @@
     VideoContent,
     VideoData,
 )
+from director.constants import LLMType
🛠️ Refactor suggestion: Consider adding error handling for LLM-specific JSON parsing

The code assumes all supported LLMs will return properly formatted JSON. Add LLM-specific error handling:

 def _prompt_runner(self, prompts):
     matches = []
     with concurrent.futures.ThreadPoolExecutor() as executor:
         future_to_index = {
             executor.submit(
                 self.llm.chat_completions,
                 [ContextMessage(content=prompt, role=RoleTypes.user).to_llm_msg()],
                 response_format={"type": "json_object"},
             ): i
             for i, prompt in enumerate(prompts)
         }
         for future in concurrent.futures.as_completed(future_to_index):
             try:
                 llm_response = future.result()
                 if not llm_response.status:
                     logger.error(f"LLM failed with {llm_response.content}")
                     continue
+                try:
                     output = json.loads(llm_response.content)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Failed to parse JSON from {self.llm.__class__.__name__}: {e}")
+                    logger.debug(f"Raw response: {llm_response.content}")
+                    continue
                 matches.extend(output["sentences"])
             except Exception as e:
                 logger.exception(f"Error in getting matches: {e}")
                 continue

Also applies to: 96-117
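Since the same parse-and-validate pattern shows up in both the profanity remover and prompt clip suggestions, a small shared helper could cut the duplication. A minimal sketch (the function name and placement are assumptions, not part of this PR):

```python
import json
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def parse_llm_json(content: str, required_key: Optional[str] = None) -> Optional[dict]:
    """Parse an LLM response as JSON, returning None (and logging) on failure."""
    try:
        data = json.loads(content)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse LLM response as JSON: {e}")
        return None
    if not isinstance(data, dict) or (required_key and required_key not in data):
        logger.error(f"Unexpected LLM response shape: {content[:100]}...")
        return None
    return data
```

Callers would then do `output = parse_llm_json(llm_response.content, "sentences")` and skip or raise when it returns `None`.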
 from director.tools.videodb_tool import VideoDBTool
 from director.llm import get_default_llm
@@ -47,6 +48,7 @@ def __init__(self, session: Session, **kwargs):
         self.description = "Generates video clips based on user prompts. This agent uses AI to analyze the text of a video transcript and identify sentences relevant to the user prompt for making clips. It then generates video clips based on the identified sentences. Use this tool to create clips based on specific themes or topics from a video."
         self.parameters = PROMPTCLIP_AGENT_PARAMETERS
         self.llm = get_default_llm()
+        self.supported_llm = [LLMType.OPENAI, LLMType.VIDEODB_PROXY]
         super().__init__(session=session, **kwargs)

     def _chunk_docs(self, docs, chunk_size):
@@ -20,7 +20,7 @@
 from director.agents.composio import ComposioAgent

-from director.core.session import Session, InputMessage, MsgStatus
+from director.core.session import Session, InputMessage, MsgStatus, TextContent
 from director.core.reasoning import ReasoningEngine
 from director.db.base import BaseDB
 from director.db import load_db

@@ -102,6 +102,9 @@ def chat(self, message):
             res_eng.run()

         except Exception as e:
+            session.output_message.content.append(
+                TextContent(text=f"{e}", status=MsgStatus.error)
+            )
Comment on lines +105 to +107

🛠️ Refactor suggestion: Enhance error handling for LLM-specific scenarios

While the structured error reporting is good, consider distinguishing LLM-specific failures and sanitizing error text before surfacing it to users.
Consider this implementation:

-            session.output_message.content.append(
-                TextContent(text=f"{e}", status=MsgStatus.error)
-            )
+            error_msg = str(e)
+            if isinstance(e, LLMError):  # Add this class for LLM-specific errors
+                error_msg = f"LLM Error: {error_msg}. Please verify LLM configuration and API keys."
+            session.output_message.content.append(
+                TextContent(
+                    text=sanitize_error_message(error_msg),  # Add this helper function
+                    status=MsgStatus.error,
+                    error_type=type(e).__name__,
+                )
+            )
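The suggestion above references an `LLMError` class and a `sanitize_error_message` helper that do not exist yet in the codebase; a minimal sketch of what they might look like (the names and the redaction pattern are assumptions):

```python
import re


class LLMError(Exception):
    """Hypothetical base class for provider-agnostic LLM failures."""


def sanitize_error_message(message: str) -> str:
    """Strip likely secrets (e.g. API keys) from an error message before showing it to users."""
    # Redact long token-like strings such as "sk-..." or "xai-..." keys (pattern is an assumption).
    return re.sub(r"\b(?:sk|xai)-[A-Za-z0-9_-]{8,}\b", "[REDACTED]", message)
```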
             session.output_message.update_status(MsgStatus.error)
             logger.exception(f"Error in chat handler: {e}")
@@ -4,6 +4,7 @@

 from director.llm.openai import OpenAI
 from director.llm.anthropic import AnthropicAI
+from director.llm.xai import XAI
 from director.llm.videodb_proxy import VideoDBProxy


@@ -12,12 +13,15 @@ def get_default_llm():

     openai = True if os.getenv("OPENAI_API_KEY") else False
     anthropic = True if os.getenv("ANTHROPIC_API_KEY") else False
+    xai = True if os.getenv("XAI_API_KEY") else False
🛠️ Refactor suggestion: Consider adding API key validation

The current implementation doesn't validate the XAI_API_KEY format or value. Consider adding validation to fail fast if the API key is invalid.

-    xai = True if os.getenv("XAI_API_KEY") else False
+    xai_key = os.getenv("XAI_API_KEY")
+    if xai_key is not None and not xai_key.strip():
+        raise ValueError("XAI_API_KEY is empty or contains only whitespace")
+    xai = bool(xai_key)

🧰 Tools
🪛 Ruff (0.8.0): 16-16: Use `bool(...)` instead of `True if ... else False`; replace with `bool(...)` (SIM210)
     default_llm = os.getenv("DEFAULT_LLM")

     if openai or default_llm == LLMType.OPENAI:
         return OpenAI()
     elif anthropic or default_llm == LLMType.ANTHROPIC:
         return AnthropicAI()
+    elif xai or default_llm == LLMType.XAI:
+        return XAI()
Comment on lines +24 to +25

💡 Codebase verification: Consider implementing JSON response format support in XAI class

The verification revealed that several agents (subtitle.py, profanity_remover.py, prompt_clip.py, meme_maker.py) require JSON response format, but the XAI implementation doesn't handle the `response_format` parameter it accepts in `chat_completions`.
🔗 Analysis chain: Consider implementing fallback mechanism for non-optimized agents

As mentioned in the PR comments, some agents might not be optimized for xAI Grok, particularly in JSON mode. Consider implementing a fallback mechanism to use OpenAI or VideoDB Proxy in such cases.

Let's verify if there are any JSON mode agents that might need this fallback:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for agent configurations that specify JSON mode
rg -l "json_mode|output_format.*json" --type python
# Search for existing fallback implementations we can reference
rg -l "fallback.*llm|backup.*llm" --type python
Length of output: 165

Script:

#!/bin/bash
# Search for agent configurations that specify JSON mode without file type filter
rg -l "json_mode|output_format.*json"
# Search for existing fallback implementations we can reference
rg -l "fallback.*llm|backup.*llm"
# Look for agent configurations and LLM settings
fd -e py -e yaml -e yml . | xargs rg -l "output_format|json_mode|response_format"
Length of output: 484

Script:

#!/bin/bash
# Check the content of agent files to understand their JSON requirements
rg -A 5 "output_format|json_mode|response_format" backend/director/agents/
# Check XAI implementation for JSON handling
cat backend/director/llm/xai.py
Length of output: 8229
     else:
         return VideoDBProxy()
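Relatedly, the `chat_completions` method added in `xai.py` below accepts a `response_format` argument but never forwards it to the API call. A minimal sketch of a helper that would forward it — the helper name is an assumption, and it presumes the xAI endpoint accepts the OpenAI-style `response_format` field:

```python
from typing import Optional


def add_response_format(params: dict, response_format: Optional[dict]) -> dict:
    """Return a copy of params with response_format included when provided (sketch)."""
    if response_format:
        # e.g. {"type": "json_object"}, as the JSON-mode agents pass
        return {**params, "response_format": response_format}
    return params


# Inside XAI.chat_completions, just before the API call:
# params = add_response_format(params, response_format)
```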
@@ -0,0 +1,181 @@
import json
from enum import Enum

from pydantic import Field, field_validator, FieldValidationInfo
from pydantic_settings import SettingsConfigDict


from director.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus
from director.constants import (
    LLMType,
    EnvPrefix,
)


class XAIModel(str, Enum):
    """Enum for XAI Chat models"""

    GROK_BETA = "grok-beta"


class XAIConfig(BaseLLMConfig):
    """XAI Config"""

    model_config = SettingsConfigDict(
        env_prefix=EnvPrefix.XAI_,
        extra="ignore",
    )

    llm_type: str = LLMType.XAI
    api_key: str = ""
    api_base: str = "https://api.x.ai/v1"
    chat_model: str = Field(default=XAIModel.GROK_BETA)
    max_tokens: int = 4096

    @field_validator("api_key")
    @classmethod
    def validate_non_empty(cls, v, info: FieldValidationInfo):
        if not v:
            raise ValueError(
                f"{info.field_name} must not be empty. please set {EnvPrefix.XAI_.value}{info.field_name.upper()} environment variable."
            )
        return v


class XAI(BaseLLM):
    def __init__(self, config: XAIConfig = None):
        """
        :param config: XAI Config
        """
        if config is None:
            config = XAIConfig()
        super().__init__(config=config)
        try:
            import openai
        except ImportError:
            raise ImportError("Please install OpenAI python library.")
Comment on lines +53 to +56

Use 'raise ... from e' when re-raising exceptions in except blocks

When raising an exception within an `except` clause, chain it to the original exception with `raise ... from e` so the traceback is preserved.

Apply this diff to fix the issue:

 try:
     import openai
-except ImportError:
+except ImportError as e:
-    raise ImportError("Please install OpenAI python library.")
+    raise ImportError("Please install OpenAI python library.") from e

🧰 Tools
🪛 Ruff (0.8.0): 56-56: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling (B904)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)
Comment on lines +54 to +58

Correct the usage of the OpenAI API client

The `openai.OpenAI` client class is not available in pre-1.0 releases of the OpenAI Python library; if an older version is pinned, configure the module-level client instead.

Apply this diff to fix the issue:

 try:
     import openai
 except ImportError as e:
     raise ImportError("Please install OpenAI python library.") from e
-self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)
+openai.api_key = self.api_key
+openai.api_base = self.api_base

Update the `chat_completions` call accordingly:

 try:
-    response = self.client.chat.completions.create(**params)
+    response = openai.ChatCompletion.create(**params)
 except Exception as e:
     logging.error(f"Error in chat_completions: {e}")
     return LLMResponse(content=f"Error: {e}")

Also applies to: 156-157
    def init_langfuse(self):
        from langfuse.decorators import observe

        self.chat_completions = observe(name=type(self).__name__)(self.chat_completions)
        self.text_completions = observe(name=type(self).__name__)(self.text_completions)
    def _format_messages(self, messages: list):
        """Format the messages to the format that OpenAI expects."""
        formatted_messages = []
        for message in messages:
            if message["role"] == "assistant" and message.get("tool_calls"):
                formatted_messages.append(
                    {
                        "role": message["role"],
                        "content": message["content"],
                        "tool_calls": [
                            {
                                "id": tool_call["id"],
                                "function": {
                                    "name": tool_call["tool"]["name"],
                                    "arguments": json.dumps(
                                        tool_call["tool"]["arguments"]
                                    ),
                                },
                                "type": tool_call["type"],
                            }
                            for tool_call in message["tool_calls"]
                        ],
                    }
                )
            else:
                formatted_messages.append(message)
        return formatted_messages
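For illustration, a hypothetical assistant message with one tool call and the shape `_format_messages` converts it to (all field values are made up):

```python
# Internal message shape used by the session (values are hypothetical):
internal = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_1",
            "tool": {"name": "get_video", "arguments": {"video_id": "v-123"}},
            "type": "function",
        }
    ],
}

# After _format_messages, tool arguments become a JSON-encoded string,
# matching the OpenAI-style wire format:
formatted = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_1",
            "function": {"name": "get_video", "arguments": '{"video_id": "v-123"}'},
            "type": "function",
        }
    ],
}
```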
    def _format_tools(self, tools: list):
        """Format the tools to the format that OpenAI expects.

        **Example**::

            [
                {
                    "type": "function",
                    "function": {
                        "name": "get_delivery_date",
                        "description": "Get the delivery date for a customer's order.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "order_id": {
                                    "type": "string",
                                    "description": "The customer's order ID."
                                }
                            },
                            "required": ["order_id"],
                            "additionalProperties": False
                        }
                    }
                }
            ]
        """
        formatted_tools = []
        for tool in tools:
            formatted_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool["description"],
                        "parameters": tool["parameters"],
                    },
                    "strict": True,
                }
            )
        return formatted_tools
|
||||||||||||||||||
def chat_completions( | ||||||||||||||||||
self, messages: list, tools: list = [], stop=None, response_format=None | ||||||||||||||||||
Avoid mutable default arguments

Using a mutable default argument like `tools: list = []` can lead to unexpected behavior, because the same list object is shared across calls.

Apply this diff to fix the issue:

 def chat_completions(
-    self, messages: list, tools: list = [], stop=None, response_format=None
+    self, messages: list, tools: list = None, stop=None, response_format=None
 ):
+    if tools is None:
+        tools = []

🧰 Tools
🪛 Ruff (0.8.0): 136-136: Do not use mutable data structures for argument defaults; replace with `None` and initialize within the function (B006)
    ):
"""Get completions for chat. | ||||||||||||||||||
|
||||||||||||||||||
docs: https://docs.x.ai/docs/guides/function-calling | ||||||||||||||||||
""" | ||||||||||||||||||
params = { | ||||||||||||||||||
"model": self.chat_model, | ||||||||||||||||||
"messages": self._format_messages(messages), | ||||||||||||||||||
"temperature": self.temperature, | ||||||||||||||||||
"max_tokens": self.max_tokens, | ||||||||||||||||||
"top_p": self.top_p, | ||||||||||||||||||
"stop": stop, | ||||||||||||||||||
"timeout": self.timeout, | ||||||||||||||||||
} | ||||||||||||||||||
if tools: | ||||||||||||||||||
params["tools"] = self._format_tools(tools) | ||||||||||||||||||
params["tool_choice"] = "auto" | ||||||||||||||||||
|
||||||||||||||||||
try: | ||||||||||||||||||
response = self.client.chat.completions.create(**params) | ||||||||||||||||||
except Exception as e: | ||||||||||||||||||
print(f"Error: {e}") | ||||||||||||||||||
return LLMResponse(content=f"Error: {e}") | ||||||||||||||||||
Comment on lines +158 to +159

🛠️ Refactor suggestion: Use logging instead of print statements for error handling

Using the `print` function for errors bypasses the application's logging configuration; use the `logging` module instead.

Apply this diff to fix the issue:

+import logging
 # ... rest of the code ...
 try:
     response = openai.ChatCompletion.create(**params)
 except Exception as e:
-    print(f"Error: {e}")
+    logging.error(f"Error in chat_completions: {e}")
     return LLMResponse(content=f"Error: {e}")
        return LLMResponse(
            content=response.choices[0].message.content or "",
            tool_calls=[
                {
                    "id": tool_call.id,
                    "tool": {
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments),
                    },
                    "type": tool_call.type,
                }
                for tool_call in response.choices[0].message.tool_calls
            ]
            if response.choices[0].message.tool_calls
            else [],
            finish_reason=response.choices[0].finish_reason,
            send_tokens=response.usage.prompt_tokens,
            recv_tokens=response.usage.completion_tokens,
            total_tokens=response.usage.total_tokens,
            status=LLMResponseStatus.SUCCESS,
        )
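A minimal sketch of how the new class might be exercised; the API key value is a placeholder, and the example assumes `XAIConfig` picks it up from the `XAI_`-prefixed environment variable as configured above:

```python
import os

from director.llm.xai import XAI, XAIConfig

# Placeholder key; XAIConfig reads XAI_-prefixed environment variables.
os.environ.setdefault("XAI_API_KEY", "xai-your-key-here")

llm = XAI(config=XAIConfig())
response = llm.chat_completions(
    [{"role": "user", "content": "Summarise this transcript in one sentence."}]
)
if response.status:
    print(response.content)
```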
@@ -0,0 +1,15 @@
## XAI

XAI extends the base LLM and implements the XAI API.

### XAI Config

XAI Config is the configuration object for XAI. It is used to configure XAI and is passed to XAI when it is created.

::: director.llm.xai.XAIConfig

### XAI Interface

XAI is the LLM used by the agents and tools. It is used to generate responses to messages.

::: director.llm.xai.XAI
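A hedged example of selecting XAI through the environment, based on the `get_default_llm` logic shown earlier; it assumes `LLMType.XAI` has the string value "xai" and that no higher-priority provider key is set:

```python
import os

from director.llm import get_default_llm

# Placeholder values; in practice these come from the deployment environment.
os.environ["XAI_API_KEY"] = "xai-your-key-here"
os.environ["DEFAULT_LLM"] = "xai"

llm = get_default_llm()
print(type(llm).__name__)  # XAI, unless e.g. OPENAI_API_KEY is also set and takes precedence
```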
🛠️ Refactor suggestion: Simplify conditional checks and enhance error handling

The method implementation can be improved for better readability and maintainability by simplifying the nested conditions and improving error message construction.

🧰 Tools
🪛 Ruff (0.8.0): 46-47: Use a single `if` statement instead of nested `if` statements (SIM102)
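As a generic illustration of the SIM102 pattern Ruff flags here (the function and conditions below are hypothetical, not taken from this PR):

```python
# Nested form flagged by SIM102:
def needs_api_key(config) -> bool:
    if config is not None:
        if not config.api_key:
            return True
    return False


# Equivalent single condition:
def needs_api_key_flat(config) -> bool:
    return config is not None and not config.api_key
```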