Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/App/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,25 @@ async def startup():
app.wealth_advisor_agent = await AgentFactory.get_wealth_advisor_agent()
logging.info("Wealth Advisor Agent initialized during application startup")
app.search_agent = await AgentFactory.get_search_agent()
logging.info(
"Call Transcript Search Agent initialized during application startup"
)
logging.info("Call Transcript Search Agent initialized during application startup")
app.sql_agent = await AgentFactory.get_sql_agent()
logging.info("SQL Agent initialized during application startup")

@app.after_serving
async def shutdown():
await AgentFactory.delete_all_agent_instance()
app.wealth_advisor_agent = None
app.search_agent = None
logging.info("Agents cleaned up during application shutdown")
try:
logging.info("Application shutdown initiated...")
await AgentFactory.delete_all_agent_instance()
if hasattr(app, 'wealth_advisor_agent'):
app.wealth_advisor_agent = None
if hasattr(app, 'search_agent'):
app.search_agent = None
if hasattr(app, 'sql_agent'):
app.sql_agent = None
logging.info("Agents cleaned up successfully")
except Exception as e:
logging.error(f"Error during shutdown: {e}")
logging.exception("Detailed error during shutdown")

# app.secret_key = secrets.token_hex(16)
# app.session_interface = SecureCookieSessionInterface()
Expand Down
73 changes: 62 additions & 11 deletions src/App/backend/agents/agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import asyncio
import logging
from typing import Optional

from azure.ai.projects import AIProjectClient
Expand All @@ -26,6 +27,7 @@ class AgentFactory:
_lock = asyncio.Lock()
_wealth_advisor_agent: Optional[AzureAIAgent] = None
_search_agent: Optional[dict] = None
_sql_agent: Optional[dict] = None

@classmethod
async def get_wealth_advisor_agent(cls):
Expand Down Expand Up @@ -94,18 +96,67 @@ async def delete_all_agent_instance(cls):
Delete the singleton AzureAIAgent instances if it exists.
"""
async with cls._lock:
if cls._wealth_advisor_agent is not None:
await cls._wealth_advisor_agent.client.agents.delete_agent(
cls._wealth_advisor_agent.id
)
cls._wealth_advisor_agent = None
logging.info("Starting agent deletion process...")

# Delete Wealth Advisor Agent
if cls._wealth_advisor_agent is not None:
try:
agent_id = cls._wealth_advisor_agent.id
logging.info(f"Deleting wealth advisor agent: {agent_id}")
if hasattr(cls._wealth_advisor_agent, 'client') and cls._wealth_advisor_agent.client:
await cls._wealth_advisor_agent.client.agents.delete_agent(agent_id)
logging.info("Wealth advisor agent deleted successfully")
else:
logging.warning("Wealth advisor agent client is None")
except Exception as e:
logging.error(f"Error deleting wealth advisor agent: {e}")
logging.exception("Detailed wealth advisor agent deletion error")
finally:
cls._wealth_advisor_agent = None

# Delete Search Agent
if cls._search_agent is not None:
cls._search_agent["client"].agents.delete_agent(
cls._search_agent["agent"].id
)
cls._search_agent["client"].close()
cls._search_agent = None
try:
agent_id = cls._search_agent['agent'].id
logging.info(f"Deleting search agent: {agent_id}")
if cls._search_agent.get("client") and hasattr(cls._search_agent["client"], "agents"):
# AIProjectClient.agents.delete_agent is synchronous, don't await it
cls._search_agent["client"].agents.delete_agent(agent_id)
logging.info("Search agent deleted successfully")

# Close the client if it has a close method
if hasattr(cls._search_agent["client"], "close"):
cls._search_agent["client"].close()
else:
logging.warning("Search agent client is None or invalid")
except Exception as e:
logging.error(f"Error deleting search agent: {e}")
logging.exception("Detailed search agent deletion error")
finally:
cls._search_agent = None

# Delete SQL Agent
if cls._sql_agent is not None:
try:
agent_id = cls._sql_agent['agent'].id
logging.info(f"Deleting SQL agent: {agent_id}")
if cls._sql_agent.get("client") and hasattr(cls._sql_agent["client"], "agents"):
# AIProjectClient.agents.delete_agent is synchronous, don't await it
cls._sql_agent["client"].agents.delete_agent(agent_id)
logging.info("SQL agent deleted successfully")

# Close the client if it has a close method
if hasattr(cls._sql_agent["client"], "close"):
cls._sql_agent["client"].close()
else:
logging.warning("SQL agent client is None or invalid")
except Exception as e:
logging.error(f"Error deleting SQL agent: {e}")
logging.exception("Detailed SQL agent deletion error")
finally:
cls._sql_agent = None

logging.info("Agent deletion process completed")

@classmethod
async def get_sql_agent(cls) -> dict:
Expand All @@ -114,7 +165,7 @@ async def get_sql_agent(cls) -> dict:
This agent is used to generate T-SQL queries from natural language input.
"""
async with cls._lock:
if not hasattr(cls, "_sql_agent") or cls._sql_agent is None:
if cls._sql_agent is None:

agent_instructions = config.SQL_SYSTEM_PROMPT or """
You are an expert assistant in generating T-SQL queries based on user questions.
Expand Down
35 changes: 31 additions & 4 deletions src/App/backend/plugins/chat_with_data_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.ai.projects import AIProjectClient
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from quart import current_app

from backend.common.config import config
from backend.services.sqldb_service import get_connection
Expand Down Expand Up @@ -41,9 +42,14 @@ async def get_SQL_Response(
if not input or not input.strip():
return "Error: Query input is required"

thread = None
try:
# TEMPORARY: Use AgentFactory directly to debug the issue
logging.info(f"Using AgentFactory directly for SQL agent for ClientId: {ClientId}")
from backend.agents.agent_factory import AgentFactory
agent_info = await AgentFactory.get_sql_agent()

logging.info(f"SQL agent retrieved: {agent_info is not None}")
agent = agent_info["agent"]
project_client = agent_info["client"]

Expand Down Expand Up @@ -72,30 +78,42 @@ async def get_SQL_Response(
role=MessageRole.AGENT
)
sql_query = message.text.value.strip() if message else None
logging.info(f"Generated SQL query: {sql_query}")

if not sql_query:
return "No SQL query was generated."

# Clean up triple backticks (if any)
sql_query = sql_query.replace("```sql", "").replace("```", "")
logging.info(f"Cleaned SQL query: {sql_query}")

# Execute the query
conn = get_connection()
cursor = conn.cursor()
cursor.execute(sql_query)
rows = cursor.fetchall()
logging.info(f"Query returned {len(rows)} rows")

if not rows:
result = "No data found for that client."
else:
result = "\n".join(str(row) for row in rows)
logging.info(f"Result preview: {result[:200]}...")

conn.close()

return result[:20000] if len(result) > 20000 else result
except Exception as e:
logging.exception("Error in get_SQL_Response")
return f"Error retrieving SQL data: {str(e)}"
finally:
if thread:
try:
logging.info(f"Attempting to delete thread {thread.id}")
await project_client.agents.threads.delete(thread.id)
logging.info(f"Thread {thread.id} deleted successfully")
except Exception as e:
logging.error(f"Error deleting thread {thread.id}: {str(e)}")

@kernel_function(
name="ChatWithCallTranscripts",
Expand All @@ -114,12 +132,17 @@ async def get_answers_from_calltranscripts(
if not question or not question.strip():
return "Error: Question input is required"

thread = None
try:
response_text = ""

from backend.agents.agent_factory import AgentFactory

agent_info: dict = await AgentFactory.get_search_agent()
# Use the singleton search agent from app context
if not hasattr(current_app, 'search_agent') or current_app.search_agent is None:
logging.error("Search agent not found in app context, falling back to AgentFactory")
from backend.agents.agent_factory import AgentFactory
agent_info = await AgentFactory.get_search_agent()
else:
agent_info = current_app.search_agent

agent: Agent = agent_info["agent"]
project_client: AIProjectClient = agent_info["client"]
Expand Down Expand Up @@ -190,7 +213,11 @@ async def get_answers_from_calltranscripts(

finally:
if thread:
project_client.agents.threads.delete(thread.id)
try:
await project_client.agents.threads.delete(thread.id)
logging.info(f"Thread {thread.id} deleted successfully")
except Exception as e:
logging.error(f"Error deleting thread {thread.id}: {str(e)}")

if not response_text.strip():
return "No data found for that client."
Expand Down
44 changes: 37 additions & 7 deletions src/App/tests/backend/plugins/test_chat_with_data_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,14 @@ async def test_get_sql_response_openai_error(self, mock_get_sql_agent, mock_conf
assert "OpenAI API error" in result

@pytest.mark.asyncio
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
@patch("backend.plugins.chat_with_data_plugin.config")
async def test_get_answers_from_calltranscripts_success(
self, mock_get_search_agent
self, mock_config, mock_get_search_agent, mock_hasattr
):
"""Test successful retrieval of answers from call transcripts using AI Search Agent."""
# Setup mocks for agent factory
# Setup mocks for agent factory (fallback case when current_app.search_agent is None)
mock_agent = MagicMock()
mock_agent.id = "test-agent-id"

Expand All @@ -195,6 +197,10 @@ async def test_get_answers_from_calltranscripts_success(
"client": mock_project_client,
}

# Mock config values
mock_config.AZURE_SEARCH_INDEX = "test-index"
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"

# Mock project index creation
mock_index = MagicMock()
mock_index.name = "project-index-test"
Expand Down Expand Up @@ -229,7 +235,7 @@ async def test_get_answers_from_calltranscripts_success(
assert "Based on call transcripts" in result
assert "investment options" in result

# Verify agent factory was called
# Verify agent factory was called (fallback case)
mock_get_search_agent.assert_called_once()

# Verify project index was created/updated
Expand All @@ -249,9 +255,11 @@ async def test_get_answers_from_calltranscripts_success(
mock_project_client.agents.runs.create_and_process.assert_called_once()

@pytest.mark.asyncio
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
@patch("backend.plugins.chat_with_data_plugin.config")
async def test_get_answers_from_calltranscripts_no_results(
self, mock_get_search_agent
self, mock_config, mock_get_search_agent, mock_hasattr
):
"""Test call transcripts search with no results."""
# Setup mocks for agent factory
Expand All @@ -264,6 +272,10 @@ async def test_get_answers_from_calltranscripts_no_results(
"client": mock_project_client,
}

# Mock config values
mock_config.AZURE_SEARCH_INDEX = "test-index"
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"

# Mock project index creation
mock_index = MagicMock()
mock_index.name = "project-index-test"
Expand Down Expand Up @@ -295,9 +307,11 @@ async def test_get_answers_from_calltranscripts_no_results(
assert "No data found for that client." in result

@pytest.mark.asyncio
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
@patch("backend.plugins.chat_with_data_plugin.config")
async def test_get_answers_from_calltranscripts_openai_error(
self, mock_get_search_agent
self, mock_config, mock_get_search_agent, mock_hasattr
):
"""Test call transcripts with AI Search processing error."""
# Setup mocks for agent factory
Expand All @@ -310,6 +324,10 @@ async def test_get_answers_from_calltranscripts_openai_error(
"client": mock_project_client,
}

# Mock config values
mock_config.AZURE_SEARCH_INDEX = "test-index"
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"

# Mock project index creation
mock_index = MagicMock()
mock_index.name = "project-index-test"
Expand All @@ -336,9 +354,11 @@ async def test_get_answers_from_calltranscripts_openai_error(
assert "Error retrieving data from call transcripts" in result

@pytest.mark.asyncio
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
@patch("backend.plugins.chat_with_data_plugin.config")
async def test_get_answers_from_calltranscripts_failed_run(
self, mock_get_search_agent
self, mock_config, mock_get_search_agent, mock_hasattr
):
"""Test call transcripts with failed AI Search run."""
# Setup mocks for agent factory
Expand All @@ -351,6 +371,10 @@ async def test_get_answers_from_calltranscripts_failed_run(
"client": mock_project_client,
}

# Mock config values
mock_config.AZURE_SEARCH_INDEX = "test-index"
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"

# Mock project index creation
mock_index = MagicMock()
mock_index.name = "project-index-test"
Expand Down Expand Up @@ -378,9 +402,11 @@ async def test_get_answers_from_calltranscripts_failed_run(
assert "Error retrieving data from call transcripts" in result

@pytest.mark.asyncio
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
@patch("backend.plugins.chat_with_data_plugin.config")
async def test_get_answers_from_calltranscripts_empty_response(
self, mock_get_search_agent
self, mock_config, mock_get_search_agent, mock_hasattr
):
"""Test call transcripts with empty response text."""
# Setup mocks for agent factory
Expand All @@ -393,6 +419,10 @@ async def test_get_answers_from_calltranscripts_empty_response(
"client": mock_project_client,
}

# Mock config values
mock_config.AZURE_SEARCH_INDEX = "test-index"
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"

# Mock project index creation
mock_index = MagicMock()
mock_index.name = "project-index-test"
Expand Down