Merged
29 changes: 2 additions & 27 deletions python/packages/azure-ai/agent_framework_azure_ai/_client.py
@@ -26,11 +26,7 @@
)
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import ResourceNotFoundError
from openai.types.responses.parsed_response import (
ParsedResponse,
)
from openai.types.responses.response import Response as OpenAIResponse
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError

from ._shared import AzureAISettings

@@ -249,18 +245,6 @@ async def _get_agent_reference_or_create(

return {"name": agent_name, "version": self.agent_version, "type": "agent_reference"}

async def _get_conversation_id_or_create(self, run_options: dict[str, Any]) -> str:
# Since "conversation" property is used, remove "previous_response_id" from options
# Use global conversation_id as fallback
conversation_id = run_options.pop("previous_response_id", self.conversation_id)

if conversation_id:
return conversation_id

# Create a new conversation with messages
created_conversation = await self.client.conversations.create()
return created_conversation.id

async def _close_client_if_needed(self) -> None:
"""Close project_client session if we created it."""
if self._should_close_client:
@@ -288,16 +272,11 @@ def _prepare_input(self, messages: MutableSequence[ChatMessage]) -> tuple[list[C
async def prepare_options(
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> dict[str, Any]:
chat_options.store = bool(chat_options.store or chat_options.store is None)
prepared_messages, instructions = self._prepare_input(messages)
run_options = await super().prepare_options(prepared_messages, chat_options)
agent_reference = await self._get_agent_reference_or_create(run_options, instructions)

store = run_options.get("store", False)

if store:
conversation_id = await self._get_conversation_id_or_create(run_options)
run_options["conversation"] = conversation_id

run_options["extra_body"] = {"agent": agent_reference}

# Remove properties that are not supported on request level
@@ -313,10 +292,6 @@ async def initialize_client(self):
"""Initialize OpenAI client asynchronously."""
self.client = await self.project_client.get_openai_client() # type: ignore

def get_conversation_id(self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool) -> str | None:
"""Get the conversation ID from the response if store is True."""
return response.conversation.id if response.conversation and store else None

def _update_agent_name(self, agent_name: str | None) -> None:
"""Update the agent name in the chat client.

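With the conversation helpers removed, this client no longer creates conversations or injects a "conversation" key into the request; conversation handling leaves prepare_options and only the agent reference is attached via extra_body. A minimal sketch of the resulting options shape, with the model name and agent values invented for illustration:

run_options = {
    "model": "my-model",  # invented value for illustration
    "extra_body": {
        # agent reference shape as built by _get_agent_reference_or_create above
        "agent": {"name": "my-agent", "version": "1.0", "type": "agent_reference"},
    },
}
assert "conversation" not in run_options  # no longer added by this client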
77 changes: 0 additions & 77 deletions python/packages/azure-ai/tests/test_azure_ai_client.py
@@ -193,34 +193,6 @@ async def test_azure_ai_client_get_agent_reference_missing_model(
await client._get_agent_reference_or_create({}, None) # type: ignore


async def test_azure_ai_client_get_conversation_id_or_create_existing(
mock_project_client: MagicMock,
) -> None:
"""Test _get_conversation_id_or_create when conversation_id is already provided."""
client = create_test_azure_ai_client(mock_project_client, conversation_id="existing-conversation")

conversation_id = await client._get_conversation_id_or_create({}) # type: ignore

assert conversation_id == "existing-conversation"


async def test_azure_ai_client_get_conversation_id_or_create_new(
mock_project_client: MagicMock,
) -> None:
"""Test _get_conversation_id_or_create when creating a new conversation."""
client = create_test_azure_ai_client(mock_project_client)

# Mock conversation creation response
mock_conversation = MagicMock()
mock_conversation.id = "new-conversation-123"
client.client.conversations.create = AsyncMock(return_value=mock_conversation)

conversation_id = await client._get_conversation_id_or_create({}) # type: ignore

assert conversation_id == "new-conversation-123"
client.client.conversations.create.assert_called_once()


async def test_azure_ai_client_prepare_input_with_system_messages(
mock_project_client: MagicMock,
) -> None:
@@ -279,34 +251,6 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM
assert run_options["extra_body"]["agent"]["name"] == "test-agent"


async def test_azure_ai_client_prepare_options_with_store(mock_project_client: MagicMock) -> None:
"""Test prepare_options with store=True creates conversation."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")

# Mock conversation creation
mock_conversation = MagicMock()
mock_conversation.id = "new-conversation-456"
client.client.conversations.create = AsyncMock(return_value=mock_conversation)

messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
chat_options = ChatOptions(store=True)

with (
patch.object(
client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model", "store": True}
),
patch.object(
client,
"_get_agent_reference_or_create",
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
),
):
run_options = await client.prepare_options(messages, chat_options)

assert "conversation" in run_options
assert run_options["conversation"] == "new-conversation-456"


async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) -> None:
"""Test initialize_client method."""
client = create_test_azure_ai_client(mock_project_client)
@@ -320,27 +264,6 @@ async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock)
mock_project_client.get_openai_client.assert_called_once()


def test_azure_ai_client_get_conversation_id_from_response(mock_project_client: MagicMock) -> None:
"""Test get_conversation_id method."""
client = create_test_azure_ai_client(mock_project_client)

# Test with conversation and store=True
mock_response = MagicMock()
mock_response.conversation.id = "test-conversation-123"

conversation_id = client.get_conversation_id(mock_response, store=True)
assert conversation_id == "test-conversation-123"

# Test with store=False
conversation_id = client.get_conversation_id(mock_response, store=False)
assert conversation_id is None

# Test with no conversation
mock_response.conversation = None
conversation_id = client.get_conversation_id(mock_response, store=True)
assert conversation_id is None


def test_azure_ai_client_update_agent_name(mock_project_client: MagicMock) -> None:
"""Test _update_agent_name method."""
client = create_test_azure_ai_client(mock_project_client)
8 changes: 0 additions & 8 deletions python/packages/core/agent_framework/_clients.py
@@ -564,10 +564,6 @@ async def get_response(

# Validate that store is True when conversation_id is set
if chat_options.conversation_id is not None and chat_options.store is not True:
logger.warning(
"When conversation_id is set, store must be True for service-managed threads. "
"Automatically setting store=True."
)
chat_options.store = True

if chat_options.instructions:
@@ -663,10 +659,6 @@ async def get_streaming_response(

# Validate that store is True when conversation_id is set
if chat_options.conversation_id is not None and chat_options.store is not True:
logger.warning(
"When conversation_id is set, store must be True for service-managed threads. "
"Automatically setting store=True."
)
chat_options.store = True

if chat_options.instructions:
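The coercion itself stays in both get_response and get_streaming_response; only the warning is removed, so store is now forced on silently whenever a conversation id is present. A small standalone illustration of that behaviour, using a stand-in options class with assumed fields rather than the framework's own ChatOptions:

from dataclasses import dataclass

@dataclass
class Options:  # stand-in for ChatOptions; fields assumed for illustration
    conversation_id: str | None = None
    store: bool | None = None

options = Options(conversation_id="conv_123", store=False)  # invented values
# Mirrors the check kept by this change: a service-managed conversation implies
# server-side storage, so store is coerced to True without logging a warning.
if options.conversation_id is not None and options.store is not True:
    options.store = True
assert options.store is True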
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_tools.py
@@ -1630,7 +1630,7 @@ async def function_invocation_wrapper(
# this runs in every but the first run
# we need to keep track of all function call messages
fcc_messages.extend(response.messages)
if getattr(kwargs.get("chat_options"), "store", False):
if response.conversation_id is not None:
prepped_messages.clear()
prepped_messages.append(result_message)
else:
@@ -1833,7 +1833,7 @@ async def streaming_function_invocation_wrapper(
# this runs in every but the first run
# we need to keep track of all function call messages
fcc_messages.extend(response.messages)
if getattr(kwargs.get("chat_options"), "store", False):
if response.conversation_id is not None:
prepped_messages.clear()
prepped_messages.append(result_message)
else:
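Both the non-streaming and streaming wrappers get the same substitution: whether the accumulated history is resent now depends on the response actually carrying a conversation id, not on the caller's store flag. A small self-contained illustration of that rule, with all names invented rather than taken from the framework:

def messages_to_resend(response_conversation_id, history, tool_result):
    # Hypothetical helper mirroring the new condition in the function-invocation loop.
    if response_conversation_id is not None:
        # The service holds the prior turns, so only the new tool result goes back.
        return [tool_result]
    # No service-side conversation: the accumulated history is resent with the result.
    return [*history, tool_result]

assert messages_to_resend("conv_1", ["turn 1", "turn 2"], "tool result") == ["tool result"]
assert messages_to_resend(None, ["turn 1"], "tool result") == ["turn 1", "tool result"]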
@@ -302,6 +302,8 @@ def get_mcp_tool(self, tool: HostedMCPTool) -> MutableMapping[str, Any]:
if never_require_approvals := tool.approval_mode.get("never_require_approval"):
mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}}

return mcp

async def prepare_options(
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> dict[str, Any]:
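For the hunk above, a small standalone run of the approval-mode mapping (tool names invented) shows the payload shape the returned dict ends up with:

approval_mode = {"never_require_approval": ["search", "fetch"]}  # invented example input
mcp: dict = {}
if never_require_approvals := approval_mode.get("never_require_approval"):
    mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}}
print(mcp)  # {'require_approval': {'never': {'tool_names': ['search', 'fetch']}}}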

This file was deleted.

@@ -71,27 +71,27 @@ async def example_with_thread_persistence_in_memory() -> None:
# First conversation
query1 = "What's the weather like in Tokyo?"
print(f"User: {query1}")
result1 = await agent.run(query1, thread=thread)
result1 = await agent.run(query1, thread=thread, store=False)
print(f"Agent: {result1.text}")

# Second conversation using the same thread - maintains context
query2 = "How about London?"
print(f"\nUser: {query2}")
result2 = await agent.run(query2, thread=thread)
result2 = await agent.run(query2, thread=thread, store=False)
print(f"Agent: {result2.text}")

# Third conversation - agent should remember both previous cities
query3 = "Which of the cities I asked about has better weather?"
print(f"\nUser: {query3}")
result3 = await agent.run(query3, thread=thread)
result3 = await agent.run(query3, thread=thread, store=False)
print(f"Agent: {result3.text}")
print("Note: The agent remembers context from previous messages in the same thread.\n")


async def example_with_existing_thread_id() -> None:
"""
Example showing how to work with an existing thread ID from the service.
In this example, messages are stored on the server using Azure AI conversation state.
In this example, messages are stored on the server.
"""
print("=== Existing Thread ID Example ===")

@@ -111,8 +111,7 @@ async def example_with_existing_thread_id() -> None:

query1 = "What's the weather in Paris?"
print(f"User: {query1}")
# Enable Azure AI conversation state by setting `store` parameter to True
result1 = await agent.run(query1, thread=thread, store=True)
result1 = await agent.run(query1, thread=thread)
print(f"Agent: {result1.text}")

# The thread ID is set after the first response
Expand All @@ -134,7 +133,7 @@ async def example_with_existing_thread_id() -> None:

query2 = "What was the last city I asked about?"
print(f"User: {query2}")
result2 = await agent.run(query2, thread=thread, store=True)
result2 = await agent.run(query2, thread=thread)
print(f"Agent: {result2.text}")
print("Note: The agent continues the conversation from the previous thread by using thread ID.\n")
