Skip to content

Commit 8680db1

Browse files
committed
Refactor RAI agent creation to use team and memory store
Updated RAI agent creation and compliance check functions to require explicit team configuration and memory store parameters. Refactored API endpoints to retrieve team and memory store before RAI checks, improving context handling and error reporting.
1 parent 4632bca commit 8680db1

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

src/backend/common/utils/utils_af.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import logging
44

55
# Converted import path (agent_framework version of FoundryAgentTemplate)
6+
from common.database.database_base import DatabaseBase
7+
from common.models.messages_af import TeamConfiguration
68
from v4.common.services.team_service import TeamService
79
from v4.magentic_agents.foundry_agent import FoundryAgentTemplate # formerly v4.magentic_agents.foundry_agent
810
from v4.config.agent_registry import agent_registry
@@ -34,7 +36,7 @@ async def find_first_available_team(team_service: TeamService, user_id: str) ->
3436
print("No teams found in priority order")
3537
return None
3638

37-
async def create_RAI_agent() -> FoundryAgentTemplate:
39+
async def create_RAI_agent(team: TeamConfiguration, memory_store: DatabaseBase) -> FoundryAgentTemplate:
3840
"""Create and initialize a FoundryAgentTemplate for Responsible AI (RAI) checks."""
3941
agent_name = "RAIAgent"
4042
agent_description = "A comprehensive research assistant for integration testing"
@@ -53,6 +55,9 @@ async def create_RAI_agent() -> FoundryAgentTemplate:
5355
)
5456

5557
model_deployment_name = config.AZURE_OPENAI_DEPLOYMENT_NAME
58+
team.team_id = "rai_team" # Use a fixed team ID for RAI agent
59+
team.name = "RAI Team"
60+
team.description = "Team responsible for Responsible AI checks"
5661
agent = FoundryAgentTemplate(
5762
agent_name=agent_name,
5863
agent_description=agent_description,
@@ -62,6 +67,8 @@ async def create_RAI_agent() -> FoundryAgentTemplate:
6267
project_endpoint=config.AZURE_AI_PROJECT_ENDPOINT,
6368
mcp_config=None,
6469
search_config=None,
70+
team_config=team,
71+
memory_store=memory_store,
6572
)
6673

6774
await agent.open()
@@ -104,14 +111,14 @@ async def _get_agent_response(agent: FoundryAgentTemplate, query: str) -> str:
104111
return "TRUE" # Default to blocking on error
105112

106113

107-
async def rai_success(description: str) -> bool:
114+
async def rai_success(description: str, team_config: TeamConfiguration, memory_store: DatabaseBase) -> bool:
108115
"""
109116
Run a RAI compliance check on the provided description using the RAIAgent.
110117
Returns True if content is safe (should proceed), False if it should be blocked.
111118
"""
112119
agent: FoundryAgentTemplate | None = None
113120
try:
114-
agent = await create_RAI_agent()
121+
agent = await create_RAI_agent(team_config, memory_store)
115122
if not agent:
116123
logging.error("Failed to instantiate RAIAgent.")
117124
return False

src/backend/v4/api/router.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,25 @@ async def process_request(
229229
type: string
230230
description: Error message
231231
"""
232-
233-
if not await rai_success(input_task.description):
232+
try:
233+
memory_store = await DatabaseFactory.get_database(user_id=user_id)
234+
user_current_team = await memory_store.get_current_team(user_id=user_id)
235+
team_id = None
236+
if user_current_team:
237+
team_id = user_current_team.team_id
238+
team = await memory_store.get_team_by_id(team_id=team_id)
239+
if not team:
240+
raise HTTPException(
241+
status_code=404,
242+
detail=f"Team configuration '{team_id}' not found or access denied",
243+
)
244+
except Exception as e:
245+
raise HTTPException(
246+
status_code=400,
247+
detail=f"Error retrieving team configuration: {e}",
248+
) from e
249+
250+
if not await rai_success(input_task.description, team, memory_store):
234251
track_event_if_configured(
235252
"RAI failed",
236253
{
@@ -264,17 +281,6 @@ async def process_request(
264281
try:
265282
plan_id = str(uuid.uuid4())
266283
# Initialize memory store and service
267-
memory_store = await DatabaseFactory.get_database(user_id=user_id)
268-
user_current_team = await memory_store.get_current_team(user_id=user_id)
269-
team_id = None
270-
if user_current_team:
271-
team_id = user_current_team.team_id
272-
team = await memory_store.get_team_by_id(team_id=team_id)
273-
if not team:
274-
raise HTTPException(
275-
status_code=404,
276-
detail=f"Team configuration '{team_id}' not found or access denied",
277-
)
278284
plan = Plan(
279285
id=plan_id,
280286
plan_id=plan_id,
@@ -507,11 +513,28 @@ async def user_clarification(
507513
raise HTTPException(
508514
status_code=401, detail="Missing or invalid user information"
509515
)
516+
try:
517+
memory_store = await DatabaseFactory.get_database(user_id=user_id)
518+
user_current_team = await memory_store.get_current_team(user_id=user_id)
519+
team_id = None
520+
if user_current_team:
521+
team_id = user_current_team.team_id
522+
team = await memory_store.get_team_by_id(team_id=team_id)
523+
if not team:
524+
raise HTTPException(
525+
status_code=404,
526+
detail=f"Team configuration '{team_id}' not found or access denied",
527+
)
528+
except Exception as e:
529+
raise HTTPException(
530+
status_code=400,
531+
detail=f"Error retrieving team configuration: {e}",
532+
) from e
510533
# Set the approval in the orchestration config
511534
if user_id and human_feedback.request_id:
512535
# validate rai
513536
if human_feedback.answer is not None or human_feedback.answer != "":
514-
if not await rai_success(human_feedback.answer):
537+
if not await rai_success(human_feedback.answer, team, memory_store):
515538
track_event_if_configured(
516539
"RAI failed",
517540
{

0 commit comments

Comments
 (0)