diff --git a/src/agents/agent.py b/src/agents/agent.py
index 2723e678..28fcbaeb 100644
--- a/src/agents/agent.py
+++ b/src/agents/agent.py
@@ -93,7 +93,15 @@ class Agent(Generic[TContext]):
     modularity.
     """
 
-    model: str | Model | None = None
+    model: (
+        str
+        | Model
+        | Callable[
+            [RunContextWrapper[TContext], Agent[TContext]],
+            MaybeAwaitable[str | Model],
+        ]
+        | None
+    ) = None
     """The model implementation to use when invoking the LLM.
 
     By default, if not set, the agent will use the default model configured in
@@ -205,3 +213,17 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
             logger.error(f"Instructions must be a string or a function, got {self.instructions}")
 
         return None
+
+    async def get_model(self, run_context: RunContextWrapper[TContext]) -> str | Model | None:
+        """Get the model for the agent."""
+        if isinstance(self.model, (str, Model)):
+            return self.model
+        elif callable(self.model):
+            if inspect.iscoroutinefunction(self.model):
+                return await cast(Awaitable[str | Model], self.model(run_context, self))
+            else:
+                return cast(str | Model, self.model(run_context, self))
+        elif self.model is not None:
+            logger.error(f"Model must be a string, Model object, or a function, got {self.model}")
+
+        return None
diff --git a/src/agents/run.py b/src/agents/run.py
index 934400fe..53043eed 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -628,7 +628,7 @@ async def _run_single_turn_streamed(
 
         handoffs = cls._get_handoffs(agent)
 
-        model = cls._get_model(agent, run_config)
+        model = await cls._get_model(agent, run_config, context_wrapper)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
 
         final_response: ModelResponse | None = None
@@ -857,7 +857,7 @@ async def _get_new_response(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
     ) -> ModelResponse:
-        model = cls._get_model(agent, run_config)
+        model = await cls._get_model(agent, run_config, context_wrapper)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         new_response = await model.get_response(
             system_instructions=system_prompt,
@@ -893,12 +893,22 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
         return handoffs
 
     @classmethod
-    def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
+    async def _get_model(
+        cls,
+        agent: Agent[Any],
+        run_config: RunConfig,
+        context_wrapper: RunContextWrapper[TContext],
+    ) -> Model:
         if isinstance(run_config.model, Model):
             return run_config.model
         elif isinstance(run_config.model, str):
             return run_config.model_provider.get_model(run_config.model)
         elif isinstance(agent.model, Model):
             return agent.model
+        elif callable(agent.model):
+            model = await agent.get_model(context_wrapper)
+            if isinstance(model, Model):
+                return model
+            return run_config.model_provider.get_model(model)
 
         return run_config.model_provider.get_model(agent.model)
diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py
index 44339dad..8ed81c3c 100644
--- a/tests/test_agent_config.py
+++ b/tests/test_agent_config.py
@@ -27,6 +27,29 @@ async def async_instructions(agent: Agent[None], context: RunContextWrapper[None
     assert await agent.get_system_prompt(context) == "async_123"
 
 
+@pytest.mark.asyncio
+async def test_model():
+    agent = Agent[None](
+        name="test",
+        model="gpt-4",
+    )
+    context = RunContextWrapper(None)
+
+    assert await agent.get_model(context) == "gpt-4"
+
+    def sync_model(context: RunContextWrapper[None], agent: Agent[None]):
+        return "sync-model"
+
+    agent = agent.clone(model=sync_model)
+    assert await agent.get_model(context) == "sync-model"
+
+    async def async_model(context: RunContextWrapper[None], agent: Agent[None]):
+        return "async-model"
+
+    agent = agent.clone(model=async_model)
+    assert await agent.get_model(context) == "async-model"
+
+
 @pytest.mark.asyncio
 async def test_handoff_with_agents():
     agent_1 = Agent(
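
A minimal usage sketch of the dynamic model selection this diff adds, assuming the `Agent` and `RunContextWrapper` exports from the `agents` package; the function name and model names below are illustrative, not part of the change:

from agents import Agent, RunContextWrapper

# The model callable receives the run context and the agent itself. It may be
# sync or async, and may return either a model name (str) or a Model instance;
# string results are resolved through run_config.model_provider at run time.
def pick_model(context: RunContextWrapper[None], agent: Agent[None]) -> str:
    # Illustrative policy; a real callable could inspect context or agent state.
    return "gpt-4o-mini"

agent = Agent[None](
    name="dynamic_model_agent",
    instructions="You are a helpful assistant.",
    model=pick_model,  # resolved per turn via Agent.get_model()
)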