diff --git a/docs/custom_llm.md b/docs/custom_llm.md new file mode 100644 index 0000000000..0e44cca5d4 --- /dev/null +++ b/docs/custom_llm.md @@ -0,0 +1,681 @@ +# Custom LLM Implementations + +CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism. + +## Using Custom LLM Implementations + +To create a custom LLM implementation, you need to: + +1. Inherit from the `LLM` base class +2. Implement the required methods: + - `call()`: The main method to call the LLM with messages + - `supports_function_calling()`: Whether the LLM supports function calling + - `supports_stop_words()`: Whether the LLM supports stop words + - `get_context_window_size()`: The context window size of the LLM + +## Using the Default LLM Implementation + +If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI: + +```python +from crewai import LLM + +# Create a default LLM instance +llm = LLM.create(model="gpt-4") + +# Or with more parameters +llm = LLM.create( + model="gpt-4", + temperature=0.7, + max_tokens=1000, + api_key="your-api-key" +) +``` + +## Example: Basic Custom LLM + +```python +from crewai import LLM +from typing import Any, Dict, List, Optional, Union + +class CustomLLM(LLM): + def __init__(self, api_key: str, endpoint: str): + super().__init__() # Initialize the base class to set default attributes + if not api_key or not isinstance(api_key, str): + raise ValueError("Invalid API key: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.api_key = api_key + self.endpoint = endpoint + self.stop = [] # You can customize stop words if needed + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with the given messages. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + Either a text response from the LLM or the result of a tool function call. + + Raises: + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + ValueError: If the response format is invalid. + """ + # Implement your own logic to call the LLM + # For example, using requests: + import requests + + try: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 # Set a reasonable timeout + ) + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") + + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. + + Returns: + True if the LLM supports function calling, False otherwise. + """ + # Return True if your LLM supports function calling + return True + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + True if the LLM supports stop words, False otherwise. + """ + # Return True if your LLM supports stop words + return True + + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + Returns: + The context window size as an integer. + """ + # Return the context window size of your LLM + return 8192 +``` + +## Error Handling Best Practices + +When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices: + +### 1. Implement Try-Except Blocks for API Calls + +Always wrap API calls in try-except blocks to handle different types of errors: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + try: + # API call implementation + response = requests.post( + self.endpoint, + headers=self.headers, + json=self.prepare_payload(messages), + timeout=30 # Set a reasonable timeout + ) + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") +``` + +### 2. Implement Retry Logic for Transient Failures + +For transient failures like network issues or rate limiting, implement retry logic with exponential backoff: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + import time + + max_retries = 3 + retry_delay = 1 # seconds + + for attempt in range(max_retries): + try: + response = requests.post( + self.endpoint, + headers=self.headers, + json=self.prepare_payload(messages), + timeout=30 + ) + response.raise_for_status() + return response.json()["choices"][0]["message"]["content"] + except (requests.Timeout, requests.ConnectionError) as e: + if attempt < max_retries - 1: + time.sleep(retry_delay * (2 ** attempt)) # Exponential backoff + continue + raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") +``` + +### 3. Validate Input Parameters + +Always validate input parameters to prevent runtime errors: + +```python +def __init__(self, api_key: str, endpoint: str): + super().__init__() + if not api_key or not isinstance(api_key, str): + raise ValueError("Invalid API key: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.api_key = api_key + self.endpoint = endpoint +``` + +### 4. Handle Authentication Errors Gracefully + +Provide clear error messages for authentication failures: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + try: + response = requests.post(self.endpoint, headers=self.headers, json=data) + if response.status_code == 401: + raise ValueError("Authentication failed: Invalid API key or token") + elif response.status_code == 403: + raise ValueError("Authorization failed: Insufficient permissions") + response.raise_for_status() + # Process response + except Exception as e: + # Handle error + raise +``` + +## Example: JWT-based Authentication + +For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this: + +```python +from crewai import LLM, Agent, Task +from typing import Any, Dict, List, Optional, Union + +class JWTAuthLLM(LLM): + def __init__(self, jwt_token: str, endpoint: str): + super().__init__() # Initialize the base class to set default attributes + if not jwt_token or not isinstance(jwt_token, str): + raise ValueError("Invalid JWT token: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.jwt_token = jwt_token + self.endpoint = endpoint + self.stop = [] # You can customize stop words if needed + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with JWT authentication. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + Either a text response from the LLM or the result of a tool function call. + + Raises: + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + ValueError: If the response format is invalid. + """ + # Implement your own logic to call the LLM with JWT authentication + import requests + + try: + headers = { + "Authorization": f"Bearer {self.jwt_token}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 # Set a reasonable timeout + ) + + if response.status_code == 401: + raise ValueError("Authentication failed: Invalid JWT token") + elif response.status_code == 403: + raise ValueError("Authorization failed: Insufficient permissions") + + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") + + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. + + Returns: + True if the LLM supports function calling, False otherwise. + """ + return True + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + True if the LLM supports stop words, False otherwise. + """ + return True + + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + Returns: + The context window size as an integer. + """ + return 8192 +``` + +## Troubleshooting + +Here are some common issues you might encounter when implementing custom LLMs and how to resolve them: + +### 1. Authentication Failures + +**Symptoms**: 401 Unauthorized or 403 Forbidden errors + +**Solutions**: +- Verify that your API key or JWT token is valid and not expired +- Check that you're using the correct authentication header format +- Ensure that your token has the necessary permissions + +### 2. Timeout Issues + +**Symptoms**: Requests taking too long or timing out + +**Solutions**: +- Implement timeout handling as shown in the examples +- Use retry logic with exponential backoff +- Consider using a more reliable network connection + +### 3. Response Parsing Errors + +**Symptoms**: KeyError, IndexError, or ValueError when processing responses + +**Solutions**: +- Validate the response format before accessing nested fields +- Implement proper error handling for malformed responses +- Check the API documentation for the expected response format + +### 4. Rate Limiting + +**Symptoms**: 429 Too Many Requests errors + +**Solutions**: +- Implement rate limiting in your custom LLM +- Add exponential backoff for retries +- Consider using a token bucket algorithm for more precise rate control + +## Advanced Features + +### Logging + +Adding logging to your custom LLM can help with debugging and monitoring: + +```python +import logging +from typing import Any, Dict, List, Optional, Union + +class LoggingLLM(BaseLLM): + def __init__(self, api_key: str, endpoint: str): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.logger = logging.getLogger("crewai.llm.custom") + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages") + try: + # API call implementation + response = self._make_api_call(messages, tools) + self.logger.debug(f"LLM response received: {response[:100]}...") + return response + except Exception as e: + self.logger.error(f"LLM call failed: {str(e)}") + raise +``` + +### Rate Limiting + +Implementing rate limiting can help avoid overwhelming the LLM API: + +```python +import time +from typing import Any, Dict, List, Optional, Union + +class RateLimitedLLM(BaseLLM): + def __init__( + self, + api_key: str, + endpoint: str, + requests_per_minute: int = 60 + ): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.requests_per_minute = requests_per_minute + self.request_times: List[float] = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + self._enforce_rate_limit() + # Record this request time + self.request_times.append(time.time()) + # Make the actual API call + return self._make_api_call(messages, tools) + + def _enforce_rate_limit(self) -> None: + """Enforce the rate limit by waiting if necessary.""" + now = time.time() + # Remove request times older than 1 minute + self.request_times = [t for t in self.request_times if now - t < 60] + + if len(self.request_times) >= self.requests_per_minute: + # Calculate how long to wait + oldest_request = min(self.request_times) + wait_time = 60 - (now - oldest_request) + if wait_time > 0: + time.sleep(wait_time) +``` + +### Metrics Collection + +Collecting metrics can help you monitor your LLM usage: + +```python +import time +from typing import Any, Dict, List, Optional, Union + +class MetricsCollectingLLM(BaseLLM): + def __init__(self, api_key: str, endpoint: str): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.metrics: Dict[str, Any] = { + "total_calls": 0, + "total_tokens": 0, + "errors": 0, + "latency": [] + } + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + start_time = time.time() + self.metrics["total_calls"] += 1 + + try: + response = self._make_api_call(messages, tools) + # Estimate tokens (simplified) + if isinstance(messages, str): + token_estimate = len(messages) // 4 + else: + token_estimate = sum(len(m.get("content", "")) // 4 for m in messages) + self.metrics["total_tokens"] += token_estimate + return response + except Exception as e: + self.metrics["errors"] += 1 + raise + finally: + latency = time.time() - start_time + self.metrics["latency"].append(latency) + + def get_metrics(self) -> Dict[str, Any]: + """Return the collected metrics.""" + avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0 + return { + **self.metrics, + "avg_latency": avg_latency + } +``` + +## Advanced Usage: Function Calling + +If your LLM supports function calling, you can implement the function calling logic in your custom LLM: + +```python +import json +from typing import Any, Dict, List, Optional, Union + +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + import requests + + try: + headers = { + "Authorization": f"Bearer {self.jwt_token}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 + ) + response.raise_for_status() + response_data = response.json() + + # Check if the LLM wants to call a function + if response_data["choices"][0]["message"].get("tool_calls"): + tool_calls = response_data["choices"][0]["message"]["tool_calls"] + + # Process each tool call + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + if available_functions and function_name in available_functions: + function_to_call = available_functions[function_name] + function_response = function_to_call(**function_args) + + # Add the function response to the messages + messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "name": function_name, + "content": str(function_response) + }) + + # Call the LLM again with the updated messages + return self.call(messages, tools, callbacks, available_functions) + + # Return the text response if no function call + return response_data["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") +``` + +## Using Your Custom LLM with CrewAI + +Once you've implemented your custom LLM, you can use it with CrewAI agents and crews: + +```python +from crewai import Agent, Task, Crew +from typing import Dict, Any + +# Create your custom LLM instance +jwt_llm = JWTAuthLLM( + jwt_token="your.jwt.token", + endpoint="https://your-llm-endpoint.com/v1/chat/completions" +) + +# Use it with an agent +agent = Agent( + role="Research Assistant", + goal="Find information on a topic", + backstory="You are a research assistant tasked with finding information.", + llm=jwt_llm, +) + +# Create a task for the agent +task = Task( + description="Research the benefits of exercise", + agent=agent, + expected_output="A summary of the benefits of exercise", +) + +# Execute the task +result = agent.execute_task(task) +print(result) + +# Or use it with a crew +crew = Crew( + agents=[agent], + tasks=[task], + manager_llm=jwt_llm, # Use your custom LLM for the manager +) + +# Run the crew +result = crew.kickoff() +print(result) +``` + +## Implementing Your Own Authentication Mechanism + +The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use: + +- OAuth tokens +- Client certificates +- Custom headers +- Session-based authentication +- Any other authentication method required by your LLM provider + +Simply implement the appropriate authentication logic in your custom LLM class. + +## Migrating from BaseLLM to LLM + +If you were previously using `BaseLLM`, you can simply replace it with `LLM`: + +```python +# Old code +from crewai import BaseLLM + +class CustomLLM(BaseLLM): + # ... + +# New code +from crewai import LLM + +class CustomLLM(LLM): + # ... +``` + +The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated. diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py index 662af25635..98ad92ca3f 100644 --- a/src/crewai/__init__.py +++ b/src/crewai/__init__.py @@ -4,7 +4,7 @@ from crewai.crew import Crew from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM, DefaultLLM from crewai.process import Process from crewai.task import Task @@ -21,6 +21,8 @@ "Process", "Task", "LLM", + "BaseLLM", + "DefaultLLM", "Flow", "Knowledge", ] diff --git a/src/crewai/agent.py b/src/crewai/agent.py index cfebc18e5f..41d514ad64 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -11,7 +11,7 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.task import Task from crewai.tools import BaseTool @@ -70,10 +70,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: Union[str, InstanceOf[LLM], Any] = Field( + llm: Union[str, InstanceOf[BaseLLM], Any] = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( + function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) system_template: Optional[str] = Field( @@ -116,9 +116,16 @@ class Agent(BaseAgent): def post_init_setup(self): self.agent_ops_agent_name = self.role - self.llm = create_llm(self.llm) - if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): - self.function_calling_llm = create_llm(self.function_calling_llm) + try: + self.llm = create_llm(self.llm) + except Exception as e: + raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}") + + if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM): + try: + self.function_calling_llm = create_llm(self.function_calling_llm) + except Exception as e: + raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}") if not self.agent_executor: self._setup_agent_executor() diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index cd0da2bb8d..e730935f36 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -14,7 +14,7 @@ from crewai.cli.utils import read_toml from crewai.cli.version import get_crewai_version from crewai.crew import Crew -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm @@ -116,7 +116,7 @@ def show_loading(event: threading.Event): print() -def initialize_chat_llm(crew: Crew) -> Optional[LLM]: +def initialize_chat_llm(crew: Crew) -> Optional[BaseLLM]: """Initializes the chat LLM and handles exceptions.""" try: return create_llm(crew.chat_llm) @@ -220,7 +220,7 @@ def get_user_input() -> str: def handle_user_input( user_input: str, - chat_llm: LLM, + chat_llm: BaseLLM, messages: List[Dict[str, str]], crew_tool_schema: Dict[str, Any], available_functions: Dict[str, Any], diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 9cecfed3a2..e23f8d3ce0 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,8 +6,9 @@ from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast +from langchain_core.tools import BaseTool as LangchainBaseTool from pydantic import ( UUID4, BaseModel, @@ -26,7 +27,7 @@ from crewai.crews.crew_output import CrewOutput from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory @@ -36,7 +37,7 @@ from crewai.tasks.conditional_task import ConditionalTask from crewai.tasks.task_output import TaskOutput from crewai.tools.agent_tools.agent_tools import AgentTools -from crewai.tools.base_tool import Tool +from crewai.tools.base_tool import BaseTool, Tool from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities.constants import TRAINING_DATA_FILE @@ -150,14 +151,14 @@ class Crew(BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: Optional[Any] = Field( + manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( description="Language model that will run the agent.", default=None ) manager_agent: Optional[BaseAgent] = Field( description="Custom agent that will be used as manager.", default=None ) function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( - description="Language model that will run the agent.", default=None + description="Language model that will be used for function calling.", default=None ) config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) @@ -184,7 +185,7 @@ class Crew(BaseModel): default=None, description="Maximum number of requests per minute for the crew execution to be respected.", ) - prompt_file: str = Field( + prompt_file: Optional[str] = Field( default=None, description="Path to the prompt json file to be used for the crew.", ) @@ -196,7 +197,7 @@ class Crew(BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: Optional[Any] = Field( + planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( default=None, description="Language model that will run the AgentPlanner if planning is True.", ) @@ -212,7 +213,7 @@ class Crew(BaseModel): default=None, description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.", ) - chat_llm: Optional[Any] = Field( + chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( default=None, description="LLM used to handle chatting with the crew.", ) @@ -798,7 +799,8 @@ def _execute_tasks( # Determine which tools to use - task tools take precedence over agent tools tools_for_task = task.tools or agent_to_use.tools or [] - tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task) + # Prepare tools and ensure they're compatible with task execution + tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task)) self._log_task_start(task, agent_to_use.role) @@ -817,7 +819,7 @@ def _execute_tasks( future = task.execute_async( agent=agent_to_use, context=context, - tools=tools_for_task, + tools=cast(List[BaseTool], tools_for_task), ) futures.append((task, future, task_index)) else: @@ -829,7 +831,7 @@ def _execute_tasks( task_output = task.execute_sync( agent=agent_to_use, context=context, - tools=tools_for_task, + tools=cast(List[BaseTool], tools_for_task), ) task_outputs.append(task_output) self._process_task_result(task, task_output) @@ -867,10 +869,10 @@ def _handle_conditional_task( return None def _prepare_tools( - self, agent: BaseAgent, task: Task, tools: List[Tool] - ) -> List[Tool]: + self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: # Add delegation tools if agent allows delegation - if agent.allow_delegation: + if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False): if self.process == Process.hierarchical: if self.manager_agent: tools = self._update_manager_tools(task, tools) @@ -879,17 +881,18 @@ def _prepare_tools( "Manager agent is required for hierarchical process." ) - elif agent and agent.allow_delegation: + elif agent: tools = self._add_delegation_tools(task, tools) # Add code execution tools if agent allows code execution - if agent.allow_code_execution: + if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False): tools = self._add_code_execution_tools(agent, tools) - if agent and agent.multimodal: + if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False): tools = self._add_multimodal_tools(agent, tools) - return tools + # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async + return cast(List[BaseTool], tools) def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: if self.process == Process.hierarchical: @@ -897,11 +900,11 @@ def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: return task.agent def _merge_tools( - self, existing_tools: List[Tool], new_tools: List[Tool] - ) -> List[Tool]: + self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: """Merge new tools into existing tools list, avoiding duplicates by tool name.""" if not new_tools: - return existing_tools + return cast(List[BaseTool], existing_tools) # Create mapping of tool names to new tools new_tool_map = {tool.name: tool for tool in new_tools} @@ -912,23 +915,32 @@ def _merge_tools( # Add all new tools tools.extend(new_tools) - return tools + return cast(List[BaseTool], tools) def _inject_delegation_tools( - self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent] - ): - delegation_tools = task_agent.get_delegation_tools(agents) - return self._merge_tools(tools, delegation_tools) - - def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]): - multimodal_tools = agent.get_multimodal_tools() - return self._merge_tools(tools, multimodal_tools) - - def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]): - code_tools = agent.get_code_execution_tools() - return self._merge_tools(tools, code_tools) - - def _add_delegation_tools(self, task: Task, tools: List[Tool]): + self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent] + ) -> List[BaseTool]: + if hasattr(task_agent, "get_delegation_tools"): + delegation_tools = task_agent.get_delegation_tools(agents) + # Cast delegation_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) + return cast(List[BaseTool], tools) + + def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + if hasattr(agent, "get_multimodal_tools"): + multimodal_tools = agent.get_multimodal_tools() + # Cast multimodal_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) + return cast(List[BaseTool], tools) + + def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + if hasattr(agent, "get_code_execution_tools"): + code_tools = agent.get_code_execution_tools() + # Cast code_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], code_tools)) + return cast(List[BaseTool], tools) + + def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: agents_for_delegation = [agent for agent in self.agents if agent != task.agent] if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent: if not tools: @@ -936,7 +948,7 @@ def _add_delegation_tools(self, task: Task, tools: List[Tool]): tools = self._inject_delegation_tools( tools, task.agent, agents_for_delegation ) - return tools + return cast(List[BaseTool], tools) def _log_task_start(self, task: Task, role: str = "None"): if self.output_log_file: @@ -944,7 +956,7 @@ def _log_task_start(self, task: Task, role: str = "None"): task_name=task.name, task=task.description, agent=role, status="started" ) - def _update_manager_tools(self, task: Task, tools: List[Tool]): + def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: if self.manager_agent: if task.agent: tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) @@ -952,7 +964,7 @@ def _update_manager_tools(self, task: Task, tools: List[Tool]): tools = self._inject_delegation_tools( tools, self.manager_agent, self.agents ) - return tools + return cast(List[BaseTool], tools) def _get_context(self, task: Task, task_outputs: List[TaskOutput]): context = ( @@ -1198,21 +1210,27 @@ def test( ) -> None: """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" try: - eval_llm = create_llm(eval_llm) - if not eval_llm: + # Create LLM instance and ensure it's of type LLM for CrewEvaluator + llm_instance = create_llm(eval_llm) + if not llm_instance: raise ValueError("Failed to create LLM instance.") + + # Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator + from crewai.llm import LLM + if not isinstance(llm_instance, LLM): + raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.") crewai_event_bus.emit( self, CrewTestStartedEvent( crew_name=self.name or "crew", n_iterations=n_iterations, - eval_llm=eval_llm, + eval_llm=llm_instance, inputs=inputs, ), ) test_crew = self.copy() - evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type] + evaluator = CrewEvaluator(test_crew, llm_instance) for i in range(1, n_iterations + 1): evaluator.set_iteration(i) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 0c8a462140..d6a9977acd 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -4,6 +4,7 @@ import sys import threading import warnings +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Type, Union, cast @@ -34,6 +35,223 @@ load_dotenv() +class LLM(ABC): + """Base class for LLM implementations. + + This class defines the interface that all LLM implementations must follow. + Users can extend this class to create custom LLM implementations that don't + rely on litellm's authentication mechanism. + + Custom LLM implementations should handle error cases gracefully, including + timeouts, authentication failures, and malformed responses. They should also + implement proper validation for input parameters and provide clear error + messages when things go wrong. + + Attributes: + stop (list): A list of stop sequences that the LLM should use to stop generation. + This is used by the CrewAgentExecutor and other components. + """ + + def __new__(cls, *args, **kwargs): + """Create a new LLM instance. + + This method handles backward compatibility by creating a DefaultLLM instance + when the LLM class is instantiated directly with parameters. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + Either a new LLM instance or a DefaultLLM instance for backward compatibility. + """ + if cls is LLM and (args or kwargs.get('model') is not None): + # Import locally to avoid circular imports + # This is safe because DefaultLLM is defined later in this file + DefaultLLM = globals().get('DefaultLLM') + if DefaultLLM is None: + # If DefaultLLM is not yet defined, return a placeholder + # that will be replaced with a real DefaultLLM instance later + return object.__new__(cls) + return DefaultLLM(*args, **kwargs) + return super().__new__(cls) + + def __init__(self): + """Initialize the LLM with default attributes. + + This constructor sets default values for attributes that are expected + by the CrewAgentExecutor and other components. + + All custom LLM implementations should call super().__init__() to ensure + that these default attributes are properly initialized. + """ + self.stop = [] + + @classmethod + def create( + cls, + model: str, + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[int, float]] = None, + response_format: Optional[Type[BaseModel]] = None, + seed: Optional[int] = None, + logprobs: Optional[int] = None, + top_logprobs: Optional[int] = None, + base_url: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + callbacks: List[Any] = [], + reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + **kwargs, + ) -> 'DefaultLLM': + """Create a default LLM instance using litellm. + + This factory method creates a default LLM instance using litellm as the backend. + It's the recommended way to create LLM instances for most users. + + Args: + model: The model name (e.g., "gpt-4"). + timeout: Optional timeout for the LLM call. + temperature: Optional temperature for the LLM call. + top_p: Optional top_p for the LLM call. + n: Optional n for the LLM call. + stop: Optional stop sequences for the LLM call. + max_completion_tokens: Optional max_completion_tokens for the LLM call. + max_tokens: Optional max_tokens for the LLM call. + presence_penalty: Optional presence_penalty for the LLM call. + frequency_penalty: Optional frequency_penalty for the LLM call. + logit_bias: Optional logit_bias for the LLM call. + response_format: Optional response_format for the LLM call. + seed: Optional seed for the LLM call. + logprobs: Optional logprobs for the LLM call. + top_logprobs: Optional top_logprobs for the LLM call. + base_url: Optional base_url for the LLM call. + api_base: Optional api_base for the LLM call. + api_version: Optional api_version for the LLM call. + api_key: Optional api_key for the LLM call. + callbacks: Optional callbacks for the LLM call. + reasoning_effort: Optional reasoning_effort for the LLM call. + **kwargs: Additional keyword arguments for the LLM call. + + Returns: + A DefaultLLM instance configured with the provided parameters. + """ + from crewai.llm import DefaultLLM + + return DefaultLLM( + model=model, + timeout=timeout, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + response_format=response_format, + seed=seed, + logprobs=logprobs, + top_logprobs=top_logprobs, + base_url=base_url, + api_base=api_base, + api_version=api_version, + api_key=api_key, + callbacks=callbacks, + reasoning_effort=reasoning_effort, + **kwargs, + ) + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with the given messages. + + Args: + messages: Input messages for the LLM. + Can be a string or list of message dictionaries. + If string, it will be converted to a single user message. + If list, each dict must have 'role' and 'content' keys. + tools: Optional list of tool schemas for function calling. + Each tool should define its name, description, and parameters. + callbacks: Optional list of callback functions to be executed + during and after the LLM call. + available_functions: Optional dict mapping function names to callables + that can be invoked by the LLM. + + Returns: + Either a text response from the LLM (str) or + the result of a tool function call (Any). + + Raises: + ValueError: If the messages format is invalid. + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + NotImplementedError: If this method is not implemented by a subclass. + """ + raise NotImplementedError("Subclasses must implement call()") + + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. + + This method should return True if the LLM implementation supports + function calling (tools), and False otherwise. If this method returns + True, the LLM should be able to handle the 'tools' parameter in the + call() method. + + Returns: + True if the LLM supports function calling, False otherwise. + + Raises: + NotImplementedError: If this method is not implemented by a subclass. + """ + raise NotImplementedError("Subclasses must implement supports_function_calling()") + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + This method should return True if the LLM implementation supports + stop words, and False otherwise. If this method returns True, the + LLM should respect the 'stop' attribute when generating responses. + + Returns: + True if the LLM supports stop words, False otherwise. + + Raises: + NotImplementedError: If this method is not implemented by a subclass. + """ + raise NotImplementedError("Subclasses must implement supports_stop_words()") + + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + This method should return the maximum number of tokens that the LLM + can process in a single request. This is used by CrewAI to ensure + that messages don't exceed the LLM's context window. + + Returns: + The context window size as an integer. + + Raises: + NotImplementedError: If this method is not implemented by a subclass. + """ + raise NotImplementedError("Subclasses must implement get_context_window_size()") + + class FilteredStream: def __init__(self, original_stream): self._original_stream = original_stream @@ -126,7 +344,14 @@ def suppress_warnings(): sys.stderr = old_stderr -class LLM: +class DefaultLLM(LLM): + """Default LLM implementation using litellm. + + This class provides a concrete implementation of the LLM interface + using litellm as the backend. It's the default implementation used + by CrewAI when no custom LLM is provided. + """ + def __init__( self, model: str, @@ -152,6 +377,8 @@ def __init__( reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, **kwargs, ): + super().__init__() # Initialize the base class + self.model = model self.timeout = timeout self.temperature = temperature @@ -180,7 +407,7 @@ def __init__( # Normalize self.stop to always be a List[str] if stop is None: - self.stop: List[str] = [] + self.stop = [] # Already initialized in base class elif isinstance(stop, str): self.stop = [stop] else: @@ -564,3 +791,27 @@ def set_env_callbacks(self): litellm.success_callback = success_callbacks litellm.failure_callback = failure_callbacks + + +class BaseLLM(LLM): + """Deprecated: Use LLM instead. + + This class is kept for backward compatibility and will be removed in a future release. + It inherits from LLM and provides the same interface, but emits a deprecation warning + when instantiated. + """ + + def __init__(self): + """Initialize the BaseLLM with a deprecation warning. + + This constructor emits a deprecation warning and then calls the parent class's + constructor to initialize the LLM. + """ + import warnings + warnings.warn( + "BaseLLM is deprecated and will be removed in a future release. " + "Use LLM instead for custom implementations.", + DeprecationWarning, + stacklevel=2 + ) + super().__init__() diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 4d34d789ca..63c5a94414 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM def create_llm( @@ -19,17 +19,17 @@ def create_llm( - None: Use environment-based or fallback default model. Returns: - An LLM instance if successful, or None if something fails. + A LLM instance if successful, or None if something fails. """ - # 1) If llm_value is already an LLM object, return it directly + # 1) If llm_value is already a LLM object, return it directly if isinstance(llm_value, LLM): return llm_value # 2) If llm_value is a string (model name) if isinstance(llm_value, str): try: - created_llm = LLM(model=llm_value) + created_llm = LLM.create(model=llm_value) return created_llm except Exception as e: print(f"Failed to instantiate LLM with model='{llm_value}': {e}") @@ -56,7 +56,7 @@ def create_llm( base_url: Optional[str] = getattr(llm_value, "base_url", None) api_base: Optional[str] = getattr(llm_value, "api_base", None) - created_llm = LLM( + created_llm = LLM.create( model=model, temperature=temperature, max_tokens=max_tokens, @@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: # Try creating the LLM try: - new_llm = LLM(**llm_params) + new_llm = LLM.create(**llm_params) return new_llm except Exception as e: print( diff --git a/tests/cassettes/test_litellm_auth_error_handling.yaml b/tests/cassettes/test_litellm_auth_error_handling.yaml new file mode 100644 index 0000000000..bfba1bde36 --- /dev/null +++ b/tests/cassettes/test_litellm_auth_error_handling.yaml @@ -0,0 +1,89 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour + personal goal is: test goal\nTo give my best complete final answer to the task + respond using the exact following format:\n\nThought: I now can give a great + answer\nFinal Answer: Your final answer must be the great and the most complete + as possible, it must be outcome described.\n\nI MUST use these formats, my job + depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis + is the expected criteria for your final answer: Test output\nyou MUST return + the actual complete content as the final answer, not a summary.\n\nBegin! This + is VERY important to you, use the tools available and give your best Final Answer, + your job depends on it!\n\nThought:"}], "model": "gpt-4", "stop": ["\nObservation:"]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '805' + content-type: + - application/json + cookie: + - _cfuvid=xecEkmr_qTiKn7EKC7aeGN5bpsbPM9ofyIsipL4VCYM-1734033219265-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.61.0 + x-stainless-arch: + - x64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - Linux + x-stainless-package-version: + - 1.61.0 + x-stainless-raw-response: + - 'true' + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.7 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + content: "{\n \"error\": {\n \"message\": \"Incorrect API key provided: + sk-proj-********************************************************************************************************************************************************sLcA. + You can find your API key at https://platform.openai.com/account/api-keys.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": + \"invalid_api_key\"\n }\n}\n" + headers: + CF-RAY: + - 9201beec18a0762e-SEA + Connection: + - keep-alive + Content-Length: + - '414' + Content-Type: + - application/json; charset=utf-8 + Date: + - Fri, 14 Mar 2025 06:34:31 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=wF6OyTyATDK7A9tGqAdaSB3QZfmd34JWPicYlDC1hug-1741934071-1.0.1.1-nZThPWX_7A9FsU7Z14PyrVhl6mCD99iuk9ujCFkNCCdepMHEwK9EXoDrP4IBBCXxkXmKjrVTSaQ63zpcociXuMHR8JKhth2fRUV2H4hMldY; + path=/; expires=Fri, 14-Mar-25 07:04:31 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=rn5IWZdYMRmbyCa2_84MkWO46MIaP6soWc8npaLc9iQ-1741934071787-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + X-Content-Type-Options: + - nosniff + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + vary: + - Origin + x-request-id: + - req_f55471c8eb5755daaef3d63eab5a95de + http_version: HTTP/1.1 + status_code: 401 +version: 1 diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py new file mode 100644 index 0000000000..49e11970b8 --- /dev/null +++ b/tests/custom_llm_test.py @@ -0,0 +1,570 @@ +from collections import deque +from typing import Any, Dict, List, Optional, Union +import time + +import jwt +import pytest + +from crewai.llm import LLM +from crewai.utilities.llm_utils import create_llm + + +class CustomLLM(LLM): + """Custom LLM implementation for testing. + + This is a simple implementation of the LLM abstract base class + that returns a predefined response for testing purposes. + """ + + def __init__(self, response: str = "Custom LLM response"): + """Initialize the CustomLLM with a predefined response. + + Args: + response: The predefined response to return from call(). + """ + super().__init__() + self.response = response + self.calls = [] + self.stop = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Record the call and return the predefined response. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + The predefined response string. + """ + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + return self.response + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported. + + Returns: + True, indicating that this LLM supports function calling. + """ + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported. + + Returns: + True, indicating that this LLM supports stop words. + """ + return True + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 8192, a typical context window size for modern LLMs. + """ + return 8192 + + +def test_custom_llm_implementation(): + """Test that a custom LLM implementation works with create_llm.""" + custom_llm = CustomLLM(response="The answer is 42") + + # Test that create_llm returns the custom LLM instance directly + result_llm = create_llm(custom_llm) + + assert result_llm is custom_llm + + # Test calling the custom LLM + response = result_llm.call("What is the answer to life, the universe, and everything?") + + # Verify that the custom LLM was called + assert len(custom_llm.calls) > 0 + # Verify that the response from the custom LLM was used + assert response == "The answer is 42" + + +class JWTAuthLLM(LLM): + """Custom LLM implementation with JWT authentication. + + This class demonstrates how to implement a custom LLM that uses JWT + authentication instead of API key-based authentication. It validates + the JWT token before each call and checks for token expiration. + """ + + def __init__(self, jwt_token: str, expiration_buffer: int = 60): + """Initialize the JWTAuthLLM with a JWT token. + + Args: + jwt_token: The JWT token to use for authentication. + expiration_buffer: Buffer time in seconds to warn about token expiration. + Default is 60 seconds. + + Raises: + ValueError: If the JWT token is invalid or missing. + """ + super().__init__() + if not jwt_token or not isinstance(jwt_token, str): + raise ValueError("Invalid JWT token") + + self.jwt_token = jwt_token + self.expiration_buffer = expiration_buffer + self.calls = [] + self.stop = [] + + # Validate the token immediately + self._validate_token() + + def _validate_token(self) -> None: + """Validate the JWT token. + + Checks if the token is valid and not expired. Also warns if the token + is about to expire within the expiration_buffer time. + + Raises: + ValueError: If the token is invalid, expired, or malformed. + """ + try: + # Decode without verification to check expiration + # In a real implementation, you would verify the signature + decoded = jwt.decode(self.jwt_token, options={"verify_signature": False}) + + # Check if token is expired or about to expire + if 'exp' in decoded: + expiration_time = decoded['exp'] + current_time = time.time() + + if expiration_time < current_time: + raise ValueError("JWT token has expired") + + if expiration_time < current_time + self.expiration_buffer: + # Token will expire soon, log a warning + import logging + logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds") + except jwt.PyJWTError as e: + raise ValueError(f"Invalid JWT token format: {str(e)}") + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with JWT authentication. + + Validates the JWT token before making the call to ensure it's still valid. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + The LLM response. + + Raises: + ValueError: If the JWT token is invalid or expired. + TimeoutError: If the request times out. + ConnectionError: If there's a connection issue. + """ + # Validate token before making the call + self._validate_token() + + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + + # In a real implementation, this would use the JWT token to authenticate + # with an external service + return "Response from JWT-authenticated LLM" + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported.""" + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported.""" + return True + + def get_context_window_size(self) -> int: + """Return a default context window size.""" + return 8192 + + +def test_custom_llm_with_jwt_auth(): + """Test a custom LLM implementation with JWT authentication.""" + # Create a valid JWT token that expires 1 hour from now + valid_token = jwt.encode( + {"exp": int(time.time()) + 3600}, + "secret", + algorithm="HS256" + ) + + jwt_llm = JWTAuthLLM(jwt_token=valid_token) + + # Test that create_llm returns the JWT-authenticated LLM instance directly + result_llm = create_llm(jwt_llm) + + assert result_llm is jwt_llm + + # Test calling the JWT-authenticated LLM + response = result_llm.call("Test message") + + # Verify that the JWT-authenticated LLM was called + assert len(jwt_llm.calls) > 0 + # Verify that the response from the JWT-authenticated LLM was used + assert response == "Response from JWT-authenticated LLM" + + +def test_jwt_auth_llm_validation(): + """Test that JWT token validation works correctly.""" + # Test with invalid JWT token (empty string) + with pytest.raises(ValueError, match="Invalid JWT token"): + JWTAuthLLM(jwt_token="") + + # Test with invalid JWT token (non-string) + with pytest.raises(ValueError, match="Invalid JWT token"): + JWTAuthLLM(jwt_token=None) + + # Test with expired token + # Create a token that expired 1 hour ago + expired_token = jwt.encode( + {"exp": int(time.time()) - 3600}, + "secret", + algorithm="HS256" + ) + with pytest.raises(ValueError, match="JWT token has expired"): + JWTAuthLLM(jwt_token=expired_token) + + # Test with malformed token + with pytest.raises(ValueError, match="Invalid JWT token format"): + JWTAuthLLM(jwt_token="not.a.valid.jwt.token") + + # Test with valid token + # Create a token that expires 1 hour from now + valid_token = jwt.encode( + {"exp": int(time.time()) + 3600}, + "secret", + algorithm="HS256" + ) + # This should not raise an exception + jwt_llm = JWTAuthLLM(jwt_token=valid_token) + assert jwt_llm.jwt_token == valid_token + + +class TimeoutHandlingLLM(LLM): + """Custom LLM implementation with timeout handling and retry logic.""" + + def __init__(self, max_retries: int = 3, timeout: int = 30): + """Initialize the TimeoutHandlingLLM with retry and timeout settings. + + Args: + max_retries: Maximum number of retry attempts. + timeout: Timeout in seconds for each API call. + """ + super().__init__() + self.max_retries = max_retries + self.timeout = timeout + self.calls = [] + self.stop = [] + self.fail_count = 0 # Number of times to simulate failure + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Simulate API calls with timeout handling and retry logic. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + A response string based on whether this is the first attempt or a retry. + + Raises: + TimeoutError: If all retry attempts fail. + """ + # Record the initial call + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + "attempt": 0 + }) + + # Simulate retry logic + for attempt in range(self.max_retries): + # Skip the first attempt recording since we already did that above + if attempt == 0: + # Simulate a failure if fail_count > 0 + if self.fail_count > 0: + self.fail_count -= 1 + # If we've used all retries, raise an error + if attempt == self.max_retries - 1: + raise TimeoutError(f"LLM request failed after {self.max_retries} attempts") + # Otherwise, continue to the next attempt (simulating backoff) + continue + else: + # Success on first attempt + return "First attempt response" + else: + # This is a retry attempt (attempt > 0) + # Always record retry attempts + self.calls.append({ + "retry_attempt": attempt, + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + + # Simulate a failure if fail_count > 0 + if self.fail_count > 0: + self.fail_count -= 1 + # If we've used all retries, raise an error + if attempt == self.max_retries - 1: + raise TimeoutError(f"LLM request failed after {self.max_retries} attempts") + # Otherwise, continue to the next attempt (simulating backoff) + continue + else: + # Success on retry + return "Response after retry" + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported. + + Returns: + True, indicating that this LLM supports function calling. + """ + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported. + + Returns: + True, indicating that this LLM supports stop words. + """ + return True + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 8192, a typical context window size for modern LLMs. + """ + return 8192 + + +def test_timeout_handling_llm(): + """Test a custom LLM implementation with timeout handling and retry logic.""" + # Test successful first attempt + llm = TimeoutHandlingLLM() + response = llm.call("Test message") + assert response == "First attempt response" + assert len(llm.calls) == 1 + + # Test successful retry + llm = TimeoutHandlingLLM() + llm.fail_count = 1 # Fail once, then succeed + response = llm.call("Test message") + assert response == "Response after retry" + assert len(llm.calls) == 2 # Initial call + successful retry call + + # Test failure after all retries + llm = TimeoutHandlingLLM(max_retries=2) + llm.fail_count = 2 # Fail twice, which is all retries + with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"): + llm.call("Test message") + assert len(llm.calls) == 2 # Initial call + failed retry attempt + + +def test_rate_limited_llm(): + """Test that rate limiting works correctly.""" + # Create a rate limited LLM with a very low limit (2 requests per minute) + llm = RateLimitedLLM(requests_per_minute=2) + + # First request should succeed + response1 = llm.call("Test message 1") + assert response1 == "Rate limited response" + assert len(llm.calls) == 1 + + # Second request should succeed + response2 = llm.call("Test message 2") + assert response2 == "Rate limited response" + assert len(llm.calls) == 2 + + # Third request should fail due to rate limiting + with pytest.raises(ValueError, match="Rate limit exceeded"): + llm.call("Test message 3") + + # Test with invalid requests_per_minute + with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"): + RateLimitedLLM(requests_per_minute=0) + + with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"): + RateLimitedLLM(requests_per_minute=-1) + + +def test_rate_limit_reset(): + """Test that rate limits reset after the time window passes.""" + # Create a rate limited LLM with a very low limit (1 request per minute) + # and a short time window for testing (1 second instead of 60 seconds) + time_window = 1 # 1 second instead of 60 seconds + llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window) + + # First request should succeed + response1 = llm.call("Test message 1") + assert response1 == "Rate limited response" + + # Second request should fail due to rate limiting + with pytest.raises(ValueError, match="Rate limit exceeded"): + llm.call("Test message 2") + + # Wait for the rate limit to reset + import time + time.sleep(time_window + 0.1) # Add a small buffer + + # After waiting, we should be able to make another request + response3 = llm.call("Test message 3") + assert response3 == "Rate limited response" + assert len(llm.calls) == 2 # First and third requests + + +class RateLimitedLLM(LLM): + """Custom LLM implementation with rate limiting. + + This class demonstrates how to implement a custom LLM with rate limiting + capabilities. It uses a sliding window algorithm to ensure that no more + than a specified number of requests are made within a given time period. + """ + + def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60): + """Initialize the RateLimitedLLM with rate limiting parameters. + + Args: + requests_per_minute: Maximum number of requests allowed per minute. + base_response: Default response to return. + time_window: Time window in seconds for rate limiting (default: 60). + This is configurable for testing purposes. + + Raises: + ValueError: If requests_per_minute is not a positive integer. + """ + super().__init__() + if not isinstance(requests_per_minute, int) or requests_per_minute <= 0: + raise ValueError("requests_per_minute must be a positive integer") + + self.requests_per_minute = requests_per_minute + self.base_response = base_response + self.time_window = time_window + self.request_times = deque() + self.calls = [] + self.stop = [] + + def _check_rate_limit(self) -> None: + """Check if the current request exceeds the rate limit. + + This method implements a sliding window rate limiting algorithm. + It keeps track of request timestamps and ensures that no more than + `requests_per_minute` requests are made within the configured time window. + + Raises: + ValueError: If the rate limit is exceeded. + """ + current_time = time.time() + + # Remove requests older than the time window + while self.request_times and current_time - self.request_times[0] > self.time_window: + self.request_times.popleft() + + # Check if we've exceeded the rate limit + if len(self.request_times) >= self.requests_per_minute: + wait_time = self.time_window - (current_time - self.request_times[0]) + raise ValueError( + f"Rate limit exceeded. Maximum {self.requests_per_minute} " + f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds." + ) + + # Record this request + self.request_times.append(current_time) + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with rate limiting. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + The LLM response. + + Raises: + ValueError: If the rate limit is exceeded. + """ + # Check rate limit before making the call + self._check_rate_limit() + + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + + return self.base_response + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported. + + Returns: + True, indicating that this LLM supports function calling. + """ + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported. + + Returns: + True, indicating that this LLM supports stop words. + """ + return True + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 8192, a typical context window size for modern LLMs. + """ + return 8192