Custom LLM Implementations

CrewAI supports custom LLM implementations through the LLM base class, which lets you build your own implementation that doesn't rely on litellm's authentication mechanism.

Using Custom LLM Implementations

To create a custom LLM implementation, you need to:

  1. Inherit from the LLM base class
  2. Implement the required methods (a minimal skeleton follows this list):
    • call(): The main method to call the LLM with messages
    • supports_function_calling(): Whether the LLM supports function calling
    • supports_stop_words(): Whether the LLM supports stop words
    • get_context_window_size(): The context window size of the LLM
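
A minimal skeleton of that interface might look like the following; the class name MyLLM is just a placeholder, and the full example further below adds a real constructor and request logic.

from crewai import LLM
from typing import Any, Dict, List, Optional, Union

class MyLLM(LLM):
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Send `messages` to your provider and return the text response
        raise NotImplementedError

    def supports_function_calling(self) -> bool:
        return False  # True if your provider supports tool/function calls

    def supports_stop_words(self) -> bool:
        return False  # True if your provider honors stop sequences

    def get_context_window_size(self) -> int:
        return 8192  # Report your model's actual context window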

Using the Default LLM Implementation

If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:

from crewai import LLM

# Create a default LLM instance
llm = LLM.create(model="gpt-4")

# Or with more parameters
llm = LLM.create(
    model="gpt-4",
    temperature=0.7,
    max_tokens=1000,
    api_key="your-api-key"
)

Example: Basic Custom LLM

from crewai import LLM
from typing import Any, Dict, List, Optional, Union

class CustomLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint, 
                headers=headers, 
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192
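
You can instantiate the class and call it directly to smoke-test it before wiring it into an agent; the endpoint URL and API key below are placeholders:

llm = CustomLLM(
    api_key="your-api-key",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Quick smoke test outside of any agent
print(llm.call("Summarize the benefits of exercise in one sentence."))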

Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time
    
    max_retries = 3
    retry_delay = 1  # seconds
    
    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")

3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint

4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30
        )
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")

Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

from crewai import LLM
from typing import Any, Dict, List, Optional, Union

class JWTAuthLLM(LLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            
            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")
                
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        return 8192

Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

1. Authentication Failures

Symptoms: 401 Unauthorized or 403 Forbidden errors

Solutions:

  • Verify that your API key or JWT token is valid and not expired
  • Check that you're using the correct authentication header format
  • Ensure that your token has the necessary permissions

2. Timeout Issues

Symptoms: Requests taking too long or timing out

Solutions:

  • Implement timeout handling as shown in the examples
  • Use retry logic with exponential backoff
  • Consider using a more reliable network connection

3. Response Parsing Errors

Symptoms: KeyError, IndexError, or ValueError when processing responses

Solutions:

  • Validate the response format before accessing nested fields (see the sketch after this list)
  • Implement proper error handling for malformed responses
  • Check the API documentation for the expected response format
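
As a sketch of the first point, a small helper (the name _extract_content is illustrative) that validates an OpenAI-style payload before indexing into it:

def _extract_content(self, response_data: dict) -> str:
    """Defensively pull the message text out of an OpenAI-style response."""
    choices = response_data.get("choices")
    if not choices:
        raise ValueError(f"Response contains no choices: {response_data}")
    message = choices[0].get("message")
    if not message or "content" not in message:
        raise ValueError(f"Malformed choice in response: {choices[0]}")
    return message["content"]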

4. Rate Limiting

Symptoms: 429 Too Many Requests errors

Solutions:

  • Implement rate limiting in your custom LLM
  • Add exponential backoff for retries
  • Consider using a token bucket algorithm for more precise rate control (sketched below)
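
A minimal token bucket sketch that could back such a rate limiter; the class and its parameters are illustrative and not part of CrewAI:

import time

class TokenBucket:
    """Allow up to `rate` requests per second, with bursts up to `capacity`."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last_refill = time.time()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            now = time.time()
            # Refill proportionally to the elapsed time, capped at capacity
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            time.sleep((1 - self.tokens) / self.rate)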

Advanced Features

Logging

Adding logging to your custom LLM can help with debugging and monitoring:

import logging
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class LoggingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise
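
To see these messages during development, configure the standard library logger; nothing here is CrewAI-specific:

import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("crewai.llm.custom").setLevel(logging.DEBUG)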

Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class RateLimitedLLM(LLM):
    def __init__(
        self, 
        api_key: str, 
        endpoint: str, 
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call
        return self._make_api_call(messages, tools)
        
    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]
        
        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)

Metrics Collection

Collecting metrics can help you monitor your LLM usage:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class MetricsCollectingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1
        
        try:
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception as e:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)
            
    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }
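
Usage might look like this, assuming _make_api_call is implemented in your subclass as in the earlier examples:

llm = MetricsCollectingLLM(
    api_key="your-api-key",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)
llm.call("List three benefits of exercise.")  # assumes _make_api_call is implemented
print(llm.get_metrics())  # includes total_calls, total_tokens, errors, latency, avg_latency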

Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM:

import json
from typing import Any, Dict, List, Optional, Union

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests
    
    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }
        
        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        
        data = {
            "messages": messages,
            "tools": tools
        }
        
        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()
        
        # Check if the LLM wants to call a function
        message = response_data["choices"][0]["message"]
        if message.get("tool_calls"):
            tool_calls = message["tool_calls"]
            
            # Append the assistant message that contains the tool calls so the
            # follow-up request carries the full conversation history
            messages.append(message)
            
            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])
                
                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)
                    
                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })
            
            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)
        
        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token", 
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)

Implementing Your Own Authentication Mechanism

The LLM class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:

  • OAuth tokens
  • Client certificates
  • Custom headers
  • Session-based authentication
  • Any other authentication method required by your LLM provider

Simply implement the appropriate authentication logic in your custom LLM class.
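
For example, here is a hedged sketch of an OAuth2 client-credentials flow in which an access token is fetched from a token endpoint and refreshed before it expires; the class name, URLs, and response fields are placeholders for whatever your provider actually requires:

import time
import requests
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class OAuthLLM(LLM):
    def __init__(self, client_id: str, client_secret: str, token_url: str, endpoint: str):
        super().__init__()
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.endpoint = endpoint
        self._access_token: Optional[str] = None
        self._token_expiry: float = 0.0

    def _get_token(self) -> str:
        """Fetch a new access token if the cached one is missing or expired."""
        if self._access_token is None or time.time() >= self._token_expiry:
            response = requests.post(
                self.token_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
                timeout=30,
            )
            response.raise_for_status()
            payload = response.json()
            self._access_token = payload["access_token"]
            # Refresh slightly before the reported expiry
            self._token_expiry = time.time() + payload.get("expires_in", 3600) - 60
        return self._access_token

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        headers = {
            "Authorization": f"Bearer {self._get_token()}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.endpoint,
            headers=headers,
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return True

    def get_context_window_size(self) -> int:
        return 8192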

Migrating from BaseLLM to LLM

If you were previously using BaseLLM, you can simply replace it with LLM:

# Old code
from crewai import BaseLLM

class CustomLLM(BaseLLM):
    # ...

# New code
from crewai import LLM

class CustomLLM(LLM):
    # ...

The BaseLLM class is still available for backward compatibility but will be removed in a future release. It now inherits from LLM and emits a deprecation warning when instantiated.