Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Simplify LLM implementation by consolidating LLM and BaseLLM classes #2371

Closed
wants to merge 8 commits into from
681 changes: 681 additions & 0 deletions docs/custom_llm.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/crewai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM, DefaultLLM
from crewai.process import Process
from crewai.task import Task

Expand All @@ -21,6 +21,8 @@
"Process",
"Task",
"LLM",
"BaseLLM",
"DefaultLLM",
"Flow",
"Knowledge",
]
19 changes: 13 additions & 6 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
Expand Down Expand Up @@ -70,10 +70,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[LLM], Any] = Field(
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
Expand Down Expand Up @@ -116,9 +116,16 @@ class Agent(BaseAgent):
def post_init_setup(self):
self.agent_ops_agent_name = self.role

self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
try:
self.llm = create_llm(self.llm)
except Exception as e:
raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}")

if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
try:
self.function_calling_llm = create_llm(self.function_calling_llm)
except Exception as e:
raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}")

if not self.agent_executor:
self._setup_agent_executor()
Expand Down
6 changes: 3 additions & 3 deletions src/crewai/cli/crew_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm

Expand Down Expand Up @@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print()


def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
def initialize_chat_llm(crew: Crew) -> Optional[BaseLLM]:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_user_input() -> str:

def handle_user_input(
user_input: str,
chat_llm: LLM,
chat_llm: BaseLLM,
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],
Expand Down
104 changes: 61 additions & 43 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast

from langchain_core.tools import BaseTool as LangchainBaseTool
from pydantic import (
UUID4,
BaseModel,
Expand All @@ -26,7 +27,7 @@
from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm import LLM, BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
Expand All @@ -36,7 +37,7 @@
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
Expand Down Expand Up @@ -150,14 +151,14 @@ class Crew(BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Any] = Field(
manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
description="Language model that will run the agent.", default=None
description="Language model that will be used for function calling.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
Expand All @@ -184,7 +185,7 @@ class Crew(BaseModel):
default=None,
description="Maximum number of requests per minute for the crew execution to be respected.",
)
prompt_file: str = Field(
prompt_file: Optional[str] = Field(
default=None,
description="Path to the prompt json file to be used for the crew.",
)
Expand All @@ -196,7 +197,7 @@ class Crew(BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Any] = Field(
planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
Expand All @@ -212,7 +213,7 @@ class Crew(BaseModel):
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Any] = Field(
chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
Expand Down Expand Up @@ -798,7 +799,8 @@ def _execute_tasks(

# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
# Prepare tools and ensure they're compatible with task execution
tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))

self._log_task_start(task, agent_to_use.role)

Expand All @@ -817,7 +819,7 @@ def _execute_tasks(
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
futures.append((task, future, task_index))
else:
Expand All @@ -829,7 +831,7 @@ def _execute_tasks(
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
tools=cast(List[BaseTool], tools_for_task),
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
Expand Down Expand Up @@ -867,10 +869,10 @@ def _handle_conditional_task(
return None

def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: List[Tool]
) -> List[Tool]:
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
# Add delegation tools if agent allows delegation
if agent.allow_delegation:
if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
if self.process == Process.hierarchical:
if self.manager_agent:
tools = self._update_manager_tools(task, tools)
Expand All @@ -879,29 +881,30 @@ def _prepare_tools(
"Manager agent is required for hierarchical process."
)

elif agent and agent.allow_delegation:
elif agent:
tools = self._add_delegation_tools(task, tools)

# Add code execution tools if agent allows code execution
if agent.allow_code_execution:
if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
tools = self._add_code_execution_tools(agent, tools)

if agent and agent.multimodal:
if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
tools = self._add_multimodal_tools(agent, tools)

return tools
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)

def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
if self.process == Process.hierarchical:
return self.manager_agent
return task.agent

def _merge_tools(
self, existing_tools: List[Tool], new_tools: List[Tool]
) -> List[Tool]:
self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return existing_tools
return cast(List[BaseTool], existing_tools)

# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
Expand All @@ -912,47 +915,56 @@ def _merge_tools(
# Add all new tools
tools.extend(new_tools)

return tools
return cast(List[BaseTool], tools)

def _inject_delegation_tools(
self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
):
delegation_tools = task_agent.get_delegation_tools(agents)
return self._merge_tools(tools, delegation_tools)

def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
multimodal_tools = agent.get_multimodal_tools()
return self._merge_tools(tools, multimodal_tools)

def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
code_tools = agent.get_code_execution_tools()
return self._merge_tools(tools, code_tools)

def _add_delegation_tools(self, task: Task, tools: List[Tool]):
self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
) -> List[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)

def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)

def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)

def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
tools = []
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return tools
return cast(List[BaseTool], tools)

def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
self._file_handler.log(
task_name=task.name, task=task.description, agent=role, status="started"
)

def _update_manager_tools(self, task: Task, tools: List[Tool]):
def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
else:
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return tools
return cast(List[BaseTool], tools)

def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
context = (
Expand Down Expand Up @@ -1198,21 +1210,27 @@ def test(
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
try:
eval_llm = create_llm(eval_llm)
if not eval_llm:
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
llm_instance = create_llm(eval_llm)
if not llm_instance:
raise ValueError("Failed to create LLM instance.")

# Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
from crewai.llm import LLM
if not isinstance(llm_instance, LLM):
raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")

crewai_event_bus.emit(
self,
CrewTestStartedEvent(
crew_name=self.name or "crew",
n_iterations=n_iterations,
eval_llm=eval_llm,
eval_llm=llm_instance,
inputs=inputs,
),
)
test_crew = self.copy()
evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type]
evaluator = CrewEvaluator(test_crew, llm_instance)

for i in range(1, n_iterations + 1):
evaluator.set_iteration(i)
Expand Down
Loading