diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 7c9696f6d4..3ca6ff9f72 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -720,8 +720,13 @@ async def run_crew(crew, input_data): def _handle_crew_planning(self): """Handles the Crew planning.""" self._logger.log("info", "Planning the crew execution") + + agent_llm = self.agents[0].llm if self.agents and hasattr(self.agents[0], 'llm') else None + result = CrewPlanner( - tasks=self.tasks, planning_agent_llm=self.planning_llm + tasks=self.tasks, + planning_agent_llm=self.planning_llm, + agent_llm=agent_llm )._handle_crew_planning() for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index 1bd14a0c88..3bcfc8c81a 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -28,11 +28,23 @@ class PlannerTaskPydanticOutput(BaseModel): class CrewPlanner: """Plans and coordinates the execution of crew tasks.""" - def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None): + def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None, agent_llm: Optional[Any] = None): self.tasks = tasks if planning_agent_llm is None: - self.planning_agent_llm = "gpt-4o-mini" + if agent_llm is not None and hasattr(agent_llm, "base_url") and agent_llm.base_url is not None: + from crewai.llm import LLM + self.planning_agent_llm = LLM( + model="gpt-4o-mini", + base_url=agent_llm.base_url, + api_key=getattr(agent_llm, "api_key", None), + organization=getattr(agent_llm, "organization", None), + api_version=getattr(agent_llm, "api_version", None), + extra_headers=getattr(agent_llm, "extra_headers", None) + ) + else: + from crewai.llm import LLM + self.planning_agent_llm = LLM(model="gpt-4o-mini") else: self.planning_agent_llm = planning_agent_llm diff --git a/tests/utilities/test_planning_auth/__init__.py b/tests/utilities/test_planning_auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py b/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py new file mode 100644 index 0000000000..c667e619b9 --- /dev/null +++ b/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import MagicMock + +from crewai import Task +from crewai.utilities.planning_handler import CrewPlanner + + +def test_planning_llm_inherits_auth_params(): + """Test that planning LLM inherits authentication parameters from agent LLM.""" + mock_llm = MagicMock() + mock_llm.base_url = "https://api.custom-provider.com/v1" + mock_llm.api_version = "2023-05-15" + + task = Task( + description="Test Task", + expected_output="Test Output" + ) + + planner = CrewPlanner( + tasks=[task], + planning_agent_llm=None, # This should trigger the inheritance logic + agent_llm=mock_llm + ) + + assert hasattr(planner, 'planning_agent_llm') + assert hasattr(planner.planning_agent_llm, 'base_url') + assert planner.planning_agent_llm.base_url == "https://api.custom-provider.com/v1" + assert planner.planning_agent_llm.api_version == "2023-05-15"