diff --git a/src/App/app.py b/src/App/app.py index 901494b2b..036e59d7c 100644 --- a/src/App/app.py +++ b/src/App/app.py @@ -81,16 +81,25 @@ async def startup(): app.wealth_advisor_agent = await AgentFactory.get_wealth_advisor_agent() logging.info("Wealth Advisor Agent initialized during application startup") app.search_agent = await AgentFactory.get_search_agent() - logging.info( - "Call Transcript Search Agent initialized during application startup" - ) + logging.info("Call Transcript Search Agent initialized during application startup") + app.sql_agent = await AgentFactory.get_sql_agent() + logging.info("SQL Agent initialized during application startup") @app.after_serving async def shutdown(): - await AgentFactory.delete_all_agent_instance() - app.wealth_advisor_agent = None - app.search_agent = None - logging.info("Agents cleaned up during application shutdown") + try: + logging.info("Application shutdown initiated...") + await AgentFactory.delete_all_agent_instance() + if hasattr(app, 'wealth_advisor_agent'): + app.wealth_advisor_agent = None + if hasattr(app, 'search_agent'): + app.search_agent = None + if hasattr(app, 'sql_agent'): + app.sql_agent = None + logging.info("Agents cleaned up successfully") + except Exception as e: + logging.error(f"Error during shutdown: {e}") + logging.exception("Detailed error during shutdown") # app.secret_key = secrets.token_hex(16) # app.session_interface = SecureCookieSessionInterface() diff --git a/src/App/backend/agents/agent_factory.py b/src/App/backend/agents/agent_factory.py index df81a2caf..634af8e13 100644 --- a/src/App/backend/agents/agent_factory.py +++ b/src/App/backend/agents/agent_factory.py @@ -7,6 +7,7 @@ """ import asyncio +import logging from typing import Optional from azure.ai.projects import AIProjectClient @@ -26,6 +27,7 @@ class AgentFactory: _lock = asyncio.Lock() _wealth_advisor_agent: Optional[AzureAIAgent] = None _search_agent: Optional[dict] = None + _sql_agent: Optional[dict] = None @classmethod async def get_wealth_advisor_agent(cls): @@ -94,18 +96,67 @@ async def delete_all_agent_instance(cls): Delete the singleton AzureAIAgent instances if it exists. """ async with cls._lock: - if cls._wealth_advisor_agent is not None: - await cls._wealth_advisor_agent.client.agents.delete_agent( - cls._wealth_advisor_agent.id - ) - cls._wealth_advisor_agent = None + logging.info("Starting agent deletion process...") + # Delete Wealth Advisor Agent + if cls._wealth_advisor_agent is not None: + try: + agent_id = cls._wealth_advisor_agent.id + logging.info(f"Deleting wealth advisor agent: {agent_id}") + if hasattr(cls._wealth_advisor_agent, 'client') and cls._wealth_advisor_agent.client: + await cls._wealth_advisor_agent.client.agents.delete_agent(agent_id) + logging.info("Wealth advisor agent deleted successfully") + else: + logging.warning("Wealth advisor agent client is None") + except Exception as e: + logging.error(f"Error deleting wealth advisor agent: {e}") + logging.exception("Detailed wealth advisor agent deletion error") + finally: + cls._wealth_advisor_agent = None + + # Delete Search Agent if cls._search_agent is not None: - cls._search_agent["client"].agents.delete_agent( - cls._search_agent["agent"].id - ) - cls._search_agent["client"].close() - cls._search_agent = None + try: + agent_id = cls._search_agent['agent'].id + logging.info(f"Deleting search agent: {agent_id}") + if cls._search_agent.get("client") and hasattr(cls._search_agent["client"], "agents"): + # AIProjectClient.agents.delete_agent is synchronous, don't await it + cls._search_agent["client"].agents.delete_agent(agent_id) + logging.info("Search agent deleted successfully") + + # Close the client if it has a close method + if hasattr(cls._search_agent["client"], "close"): + cls._search_agent["client"].close() + else: + logging.warning("Search agent client is None or invalid") + except Exception as e: + logging.error(f"Error deleting search agent: {e}") + logging.exception("Detailed search agent deletion error") + finally: + cls._search_agent = None + + # Delete SQL Agent + if cls._sql_agent is not None: + try: + agent_id = cls._sql_agent['agent'].id + logging.info(f"Deleting SQL agent: {agent_id}") + if cls._sql_agent.get("client") and hasattr(cls._sql_agent["client"], "agents"): + # AIProjectClient.agents.delete_agent is synchronous, don't await it + cls._sql_agent["client"].agents.delete_agent(agent_id) + logging.info("SQL agent deleted successfully") + + # Close the client if it has a close method + if hasattr(cls._sql_agent["client"], "close"): + cls._sql_agent["client"].close() + else: + logging.warning("SQL agent client is None or invalid") + except Exception as e: + logging.error(f"Error deleting SQL agent: {e}") + logging.exception("Detailed SQL agent deletion error") + finally: + cls._sql_agent = None + + logging.info("Agent deletion process completed") @classmethod async def get_sql_agent(cls) -> dict: @@ -114,7 +165,7 @@ async def get_sql_agent(cls) -> dict: This agent is used to generate T-SQL queries from natural language input. """ async with cls._lock: - if not hasattr(cls, "_sql_agent") or cls._sql_agent is None: + if cls._sql_agent is None: agent_instructions = config.SQL_SYSTEM_PROMPT or """ You are an expert assistant in generating T-SQL queries based on user questions. diff --git a/src/App/backend/plugins/chat_with_data_plugin.py b/src/App/backend/plugins/chat_with_data_plugin.py index f421af7ef..b8317a919 100644 --- a/src/App/backend/plugins/chat_with_data_plugin.py +++ b/src/App/backend/plugins/chat_with_data_plugin.py @@ -11,6 +11,7 @@ from azure.ai.projects import AIProjectClient from azure.identity import DefaultAzureCredential, get_bearer_token_provider from semantic_kernel.functions.kernel_function_decorator import kernel_function +from quart import current_app from backend.common.config import config from backend.services.sqldb_service import get_connection @@ -41,9 +42,14 @@ async def get_SQL_Response( if not input or not input.strip(): return "Error: Query input is required" + thread = None try: + # TEMPORARY: Use AgentFactory directly to debug the issue + logging.info(f"Using AgentFactory directly for SQL agent for ClientId: {ClientId}") from backend.agents.agent_factory import AgentFactory agent_info = await AgentFactory.get_sql_agent() + + logging.info(f"SQL agent retrieved: {agent_info is not None}") agent = agent_info["agent"] project_client = agent_info["client"] @@ -72,23 +78,27 @@ async def get_SQL_Response( role=MessageRole.AGENT ) sql_query = message.text.value.strip() if message else None + logging.info(f"Generated SQL query: {sql_query}") if not sql_query: return "No SQL query was generated." # Clean up triple backticks (if any) sql_query = sql_query.replace("```sql", "").replace("```", "") + logging.info(f"Cleaned SQL query: {sql_query}") # Execute the query conn = get_connection() cursor = conn.cursor() cursor.execute(sql_query) rows = cursor.fetchall() + logging.info(f"Query returned {len(rows)} rows") if not rows: result = "No data found for that client." else: result = "\n".join(str(row) for row in rows) + logging.info(f"Result preview: {result[:200]}...") conn.close() @@ -96,6 +106,14 @@ async def get_SQL_Response( except Exception as e: logging.exception("Error in get_SQL_Response") return f"Error retrieving SQL data: {str(e)}" + finally: + if thread: + try: + logging.info(f"Attempting to delete thread {thread.id}") + await project_client.agents.threads.delete(thread.id) + logging.info(f"Thread {thread.id} deleted successfully") + except Exception as e: + logging.error(f"Error deleting thread {thread.id}: {str(e)}") @kernel_function( name="ChatWithCallTranscripts", @@ -114,12 +132,17 @@ async def get_answers_from_calltranscripts( if not question or not question.strip(): return "Error: Question input is required" + thread = None try: response_text = "" - from backend.agents.agent_factory import AgentFactory - - agent_info: dict = await AgentFactory.get_search_agent() + # Use the singleton search agent from app context + if not hasattr(current_app, 'search_agent') or current_app.search_agent is None: + logging.error("Search agent not found in app context, falling back to AgentFactory") + from backend.agents.agent_factory import AgentFactory + agent_info = await AgentFactory.get_search_agent() + else: + agent_info = current_app.search_agent agent: Agent = agent_info["agent"] project_client: AIProjectClient = agent_info["client"] @@ -190,7 +213,11 @@ async def get_answers_from_calltranscripts( finally: if thread: - project_client.agents.threads.delete(thread.id) + try: + await project_client.agents.threads.delete(thread.id) + logging.info(f"Thread {thread.id} deleted successfully") + except Exception as e: + logging.error(f"Error deleting thread {thread.id}: {str(e)}") if not response_text.strip(): return "No data found for that client." diff --git a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py index 684c947ae..0494b22d6 100644 --- a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py +++ b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py @@ -180,12 +180,14 @@ async def test_get_sql_response_openai_error(self, mock_get_sql_agent, mock_conf assert "OpenAI API error" in result @pytest.mark.asyncio + @patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False) @patch("backend.agents.agent_factory.AgentFactory.get_search_agent") + @patch("backend.plugins.chat_with_data_plugin.config") async def test_get_answers_from_calltranscripts_success( - self, mock_get_search_agent + self, mock_config, mock_get_search_agent, mock_hasattr ): """Test successful retrieval of answers from call transcripts using AI Search Agent.""" - # Setup mocks for agent factory + # Setup mocks for agent factory (fallback case when current_app.search_agent is None) mock_agent = MagicMock() mock_agent.id = "test-agent-id" @@ -195,6 +197,10 @@ async def test_get_answers_from_calltranscripts_success( "client": mock_project_client, } + # Mock config values + mock_config.AZURE_SEARCH_INDEX = "test-index" + mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection" + # Mock project index creation mock_index = MagicMock() mock_index.name = "project-index-test" @@ -229,7 +235,7 @@ async def test_get_answers_from_calltranscripts_success( assert "Based on call transcripts" in result assert "investment options" in result - # Verify agent factory was called + # Verify agent factory was called (fallback case) mock_get_search_agent.assert_called_once() # Verify project index was created/updated @@ -249,9 +255,11 @@ async def test_get_answers_from_calltranscripts_success( mock_project_client.agents.runs.create_and_process.assert_called_once() @pytest.mark.asyncio + @patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False) @patch("backend.agents.agent_factory.AgentFactory.get_search_agent") + @patch("backend.plugins.chat_with_data_plugin.config") async def test_get_answers_from_calltranscripts_no_results( - self, mock_get_search_agent + self, mock_config, mock_get_search_agent, mock_hasattr ): """Test call transcripts search with no results.""" # Setup mocks for agent factory @@ -264,6 +272,10 @@ async def test_get_answers_from_calltranscripts_no_results( "client": mock_project_client, } + # Mock config values + mock_config.AZURE_SEARCH_INDEX = "test-index" + mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection" + # Mock project index creation mock_index = MagicMock() mock_index.name = "project-index-test" @@ -295,9 +307,11 @@ async def test_get_answers_from_calltranscripts_no_results( assert "No data found for that client." in result @pytest.mark.asyncio + @patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False) @patch("backend.agents.agent_factory.AgentFactory.get_search_agent") + @patch("backend.plugins.chat_with_data_plugin.config") async def test_get_answers_from_calltranscripts_openai_error( - self, mock_get_search_agent + self, mock_config, mock_get_search_agent, mock_hasattr ): """Test call transcripts with AI Search processing error.""" # Setup mocks for agent factory @@ -310,6 +324,10 @@ async def test_get_answers_from_calltranscripts_openai_error( "client": mock_project_client, } + # Mock config values + mock_config.AZURE_SEARCH_INDEX = "test-index" + mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection" + # Mock project index creation mock_index = MagicMock() mock_index.name = "project-index-test" @@ -336,9 +354,11 @@ async def test_get_answers_from_calltranscripts_openai_error( assert "Error retrieving data from call transcripts" in result @pytest.mark.asyncio + @patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False) @patch("backend.agents.agent_factory.AgentFactory.get_search_agent") + @patch("backend.plugins.chat_with_data_plugin.config") async def test_get_answers_from_calltranscripts_failed_run( - self, mock_get_search_agent + self, mock_config, mock_get_search_agent, mock_hasattr ): """Test call transcripts with failed AI Search run.""" # Setup mocks for agent factory @@ -351,6 +371,10 @@ async def test_get_answers_from_calltranscripts_failed_run( "client": mock_project_client, } + # Mock config values + mock_config.AZURE_SEARCH_INDEX = "test-index" + mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection" + # Mock project index creation mock_index = MagicMock() mock_index.name = "project-index-test" @@ -378,9 +402,11 @@ async def test_get_answers_from_calltranscripts_failed_run( assert "Error retrieving data from call transcripts" in result @pytest.mark.asyncio + @patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False) @patch("backend.agents.agent_factory.AgentFactory.get_search_agent") + @patch("backend.plugins.chat_with_data_plugin.config") async def test_get_answers_from_calltranscripts_empty_response( - self, mock_get_search_agent + self, mock_config, mock_get_search_agent, mock_hasattr ): """Test call transcripts with empty response text.""" # Setup mocks for agent factory @@ -393,6 +419,10 @@ async def test_get_answers_from_calltranscripts_empty_response( "client": mock_project_client, } + # Mock config values + mock_config.AZURE_SEARCH_INDEX = "test-index" + mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection" + # Mock project index creation mock_index = MagicMock() mock_index.name = "project-index-test"