
Simplify LLM implementation by consolidating LLM and BaseLLM classes #2371

Closed
wants to merge 8 commits
Enhance custom LLM implementation with better error handling, documentation, and test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
devin-ai-integration[bot] and Joe Moura committed Mar 4, 2025
commit 963ed23b63cc23b927b59ff983c31ad8815c403c
587 changes: 502 additions & 85 deletions docs/custom_llm.md
@@ -22,6 +22,10 @@ from typing import Any, Dict, List, Optional, Union
class CustomLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.api_key = api_key
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
@@ -33,40 +37,194 @@ class CustomLLM(BaseLLM):
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
Either a text response from the LLM or the result of a tool function call.
Raises:
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
ValueError: If the response format is invalid.
"""
# Implement your own logic to call the LLM
# For example, using requests:
import requests

try:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

data = {
"messages": messages,
"tools": tools
}

response = requests.post(
self.endpoint,
headers=headers,
json=data,
timeout=30 # Set a reasonable timeout
)
response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")

def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
# Return True if your LLM supports function calling
return True

def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
# Return True if your LLM supports stop words
return True

def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
Returns:
The context window size as an integer.
"""
# Return the context window size of your LLM
return 8192
```

## Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

### 1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
try:
# API call implementation
response = requests.post(
self.endpoint,
headers=self.headers,
json=self.prepare_payload(messages),
timeout=30 # Set a reasonable timeout
)
response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
except (KeyError, IndexError, ValueError) as e:
raise ValueError(f"Invalid response format: {str(e)}")
```

### 2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
import time

max_retries = 3
retry_delay = 1 # seconds

for attempt in range(max_retries):
try:
response = requests.post(
self.endpoint,
headers=self.headers,
json=self.prepare_payload(messages),
timeout=30
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
except (requests.Timeout, requests.ConnectionError) as e:
if attempt < max_retries - 1:
time.sleep(retry_delay * (2 ** attempt)) # Exponential backoff
continue
raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
```

### 3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

```python
def __init__(self, api_key: str, endpoint: str):
super().__init__()
if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.api_key = api_key
self.endpoint = endpoint
```

### 4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30
        )
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process and return the successful response
        return response.json()["choices"][0]["message"]["content"]
    except Exception:
        # Log or clean up here as needed, then surface the error to the caller
        raise
```

## Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
@@ -78,6 +236,10 @@ from typing import Any, Dict, List, Optional, Union
class JWTAuthLLM(BaseLLM):
def __init__(self, jwt_token: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str):
raise ValueError("Invalid endpoint URL: must be a non-empty string")
self.jwt_token = jwt_token
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
@@ -89,9 +251,282 @@ class JWTAuthLLM(BaseLLM):
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with JWT authentication.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
Either a text response from the LLM or the result of a tool function call.
Raises:
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
ValueError: If the response format is invalid.
"""
# Implement your own logic to call the LLM with JWT authentication
import requests

try:
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
}

# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

data = {
"messages": messages,
"tools": tools
}

response = requests.post(
self.endpoint,
headers=headers,
json=data,
timeout=30 # Set a reasonable timeout
)

if response.status_code == 401:
raise ValueError("Authentication failed: Invalid JWT token")
elif response.status_code == 403:
raise ValueError("Authorization failed: Insufficient permissions")

response.raise_for_status() # Raise an exception for HTTP errors
return response.json()["choices"][0]["message"]["content"]
except requests.Timeout:
raise TimeoutError("LLM request timed out")
except requests.RequestException as e:
raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError) as e:
            # ValueError is deliberately not caught here, so the authentication
            # errors raised above propagate with their original messages
            raise ValueError(f"Invalid response format: {str(e)}")

def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
return True

def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
return True

def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
Returns:
The context window size as an integer.
"""
return 8192
```

## Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

### 1. Authentication Failures

**Symptoms**: 401 Unauthorized or 403 Forbidden errors

**Solutions**:
- Verify that your API key or JWT token is valid and not expired (for JWTs, see the sketch after this list)
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions
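
For JWTs specifically, you can often catch an expired token before making a request by decoding the payload locally. Below is a minimal sketch, assuming a standard three-part JWT carrying an `exp` claim; the helper name is illustrative, and this does not verify the signature:

```python
import base64
import json
import time

def is_jwt_expired(jwt_token: str) -> bool:
    """Best-effort check of a JWT's exp claim, without signature verification."""
    try:
        payload_b64 = jwt_token.split(".")[1]
        # Restore the padding stripped by base64url encoding
        payload_b64 += "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64))
        exp = payload.get("exp") if isinstance(payload, dict) else None
        return exp is not None and exp < time.time()
    except (IndexError, ValueError):
        return True  # Treat malformed tokens as expired
```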

### 2. Timeout Issues

**Symptoms**: Requests taking too long or timing out

**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection

### 3. Response Parsing Errors

**Symptoms**: KeyError, IndexError, or ValueError when processing responses

**Solutions**:
- Validate the response format before accessing nested fields (see the sketch after this list)
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format
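
A defensive helper can make the first solution concrete. This is a minimal sketch, assuming the OpenAI-style `choices` layout used in the examples above; the function name is illustrative:

```python
from typing import Any, Dict

def extract_content(response_data: Dict[str, Any]) -> str:
    """Safely pull the message content out of an OpenAI-style response."""
    choices = response_data.get("choices")
    if not isinstance(choices, list) or not choices:
        raise ValueError("Response is missing a non-empty 'choices' list")
    message = choices[0].get("message") if isinstance(choices[0], dict) else None
    if not isinstance(message, dict) or not isinstance(message.get("content"), str):
        raise ValueError("Response is missing 'choices[0].message.content'")
    return message["content"]
```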

### 4. Rate Limiting

**Symptoms**: 429 Too Many Requests errors

**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control (sketched below)
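
The sliding-window limiter shown under Advanced Features below is often sufficient; a token bucket refills capacity continuously and tolerates short bursts. A minimal sketch, with illustrative names:

```python
import time

class TokenBucket:
    """Refills at `rate` tokens per second, up to `capacity` tokens."""

    def __init__(self, rate: float, capacity: float):
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        self.last_refill = time.time()

    def acquire(self, tokens: float = 1.0) -> None:
        """Block until `tokens` are available, then consume them."""
        while True:
            now = time.time()
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= tokens:
                self.tokens -= tokens
                return
            time.sleep((tokens - self.tokens) / self.rate)
```

A custom LLM could call `bucket.acquire()` at the top of `call()`, before issuing the HTTP request.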

## Advanced Features

### Logging

Adding logging to your custom LLM can help with debugging and monitoring:

```python
import logging
from typing import Any, Dict, List, Optional, Union

class LoggingLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.logger = logging.getLogger("crewai.llm.custom")

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
try:
# API call implementation
response = self._make_api_call(messages, tools)
self.logger.debug(f"LLM response received: {response[:100]}...")
return response
except Exception as e:
self.logger.error(f"LLM call failed: {str(e)}")
raise
```

### Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

```python
import time
from typing import Any, Dict, List, Optional, Union

class RateLimitedLLM(BaseLLM):
def __init__(
self,
api_key: str,
endpoint: str,
requests_per_minute: int = 60
):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.requests_per_minute = requests_per_minute
self.request_times: List[float] = []

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
self._enforce_rate_limit()
# Record this request time
self.request_times.append(time.time())
# Make the actual API call
return self._make_api_call(messages, tools)

def _enforce_rate_limit(self) -> None:
"""Enforce the rate limit by waiting if necessary."""
now = time.time()
# Remove request times older than 1 minute
self.request_times = [t for t in self.request_times if now - t < 60]

if len(self.request_times) >= self.requests_per_minute:
# Calculate how long to wait
oldest_request = min(self.request_times)
wait_time = 60 - (now - oldest_request)
if wait_time > 0:
time.sleep(wait_time)
```

### Metrics Collection

Collecting metrics can help you monitor your LLM usage:

```python
import time
from typing import Any, Dict, List, Optional, Union

class MetricsCollectingLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__()
self.api_key = api_key
self.endpoint = endpoint
self.metrics: Dict[str, Any] = {
"total_calls": 0,
"total_tokens": 0,
"errors": 0,
"latency": []
}

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
start_time = time.time()
self.metrics["total_calls"] += 1

try:
response = self._make_api_call(messages, tools)
# Estimate tokens (simplified)
if isinstance(messages, str):
token_estimate = len(messages) // 4
else:
token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
self.metrics["total_tokens"] += token_estimate
return response
except Exception as e:
self.metrics["errors"] += 1
raise
finally:
latency = time.time() - start_time
self.metrics["latency"].append(latency)

def get_metrics(self) -> Dict[str, Any]:
"""Return the collected metrics."""
avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
return {
**self.metrics,
"avg_latency": avg_latency
}
```

## Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM:

```python
import json
from typing import Any, Dict, List, Optional, Union

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
import requests

try:
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
        }

        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        data = {
            "messages": messages,
            "tools": tools
        }

        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]

            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])

                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)

                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })

            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)

        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```
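
To exercise this path, pass OpenAI-style tool schemas through `tools` and a matching `available_functions` mapping. A hypothetical usage sketch (the `get_weather` tool and its schema are illustrative, and `jwt_llm` is an instance like the one constructed in the next section):

```python
def get_weather(location: str) -> str:
    """Hypothetical tool the LLM may choose to call."""
    return f"Sunny in {location}"

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

response = jwt_llm.call(
    "What's the weather in Paris?",
    tools=tools,
    available_functions={"get_weather": get_weather},
)
print(response)
```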

## Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

```python
from crewai import Agent, Task, Crew
from typing import Dict, Any

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
jwt_token="your.jwt.token",
@@ -151,65 +616,17 @@ task = Task(
# Execute the task
result = agent.execute_task(task)
print(result)
# Or use it with a crew
crew = Crew(
agents=[agent],
tasks=[task],
manager_llm=jwt_llm, # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)
```

## Implementing Your Own Authentication Mechanism
2 changes: 1 addition & 1 deletion src/crewai/__init__.py
@@ -4,7 +4,7 @@
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import BaseLLM, LLM
from crewai.llm import LLM, BaseLLM
from crewai.process import Process
from crewai.task import Task

5 changes: 2 additions & 3 deletions src/crewai/crew.py
@@ -6,10 +6,9 @@
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast, TypeVar
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast

from langchain_core.tools import BaseTool as LangchainBaseTool
from crewai.tools.base_tool import BaseTool, Tool
from pydantic import (
UUID4,
BaseModel,
@@ -38,7 +37,7 @@
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
186 changes: 181 additions & 5 deletions tests/custom_llm_test.py
@@ -7,9 +7,19 @@


class CustomLLM(BaseLLM):
"""Custom LLM implementation for testing."""
"""Custom LLM implementation for testing.
This is a simple implementation of the BaseLLM abstract base class
that returns a predefined response for testing purposes.
"""

def __init__(self, response: str = "Custom LLM response"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__()
self.response = response
self.calls = []
self.stop = []
@@ -21,7 +31,17 @@ def call(
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return the predefined response."""
"""Record the call and return the predefined response.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
The predefined response string.
"""
self.calls.append({
"messages": messages,
"tools": tools,
@@ -31,15 +51,27 @@ def call(
return self.response

def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True

def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True

def get_context_window_size(self) -> int:
"""Return a default context window size."""
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192


@@ -119,3 +151,147 @@ def test_custom_llm_with_jwt_auth():
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
assert response == "Response from JWT-authenticated LLM"


def test_jwt_auth_llm_validation():
"""Test that JWT token validation works correctly."""
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token="")

# Test with invalid JWT token (non-string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token=None)


class TimeoutHandlingLLM(BaseLLM):
"""Custom LLM implementation with timeout handling and retry logic."""

def __init__(self, max_retries: int = 3, timeout: int = 30):
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
"""
super().__init__()
self.max_retries = max_retries
self.timeout = timeout
self.calls = []
self.stop = []
self.fail_count = 0 # Number of times to simulate failure

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Simulate API calls with timeout handling and retry logic.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
A response string based on whether this is the first attempt or a retry.
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0
})

# Simulate retry logic
for attempt in range(self.max_retries):
# Skip the first attempt recording since we already did that above
if attempt == 0:
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append({
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})

# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"

def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True

def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True

def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
return 8192


def test_timeout_handling_llm():
"""Test a custom LLM implementation with timeout handling and retry logic."""
# Test successful first attempt
llm = TimeoutHandlingLLM()
response = llm.call("Test message")
assert response == "First attempt response"
assert len(llm.calls) == 1

# Test successful retry
llm = TimeoutHandlingLLM()
llm.fail_count = 1 # Fail once, then succeed
response = llm.call("Test message")
assert response == "Response after retry"
assert len(llm.calls) == 2 # Initial call + successful retry call

# Test failure after all retries
llm = TimeoutHandlingLLM(max_retries=2)
llm.fail_count = 2 # Fail twice, which is all retries
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt