Skip to content

Commit 4b1d309

Browse files
committed
Refactor team config validation and user checks
Updated `rai_validate_team_config` to accept additional parameters for improved validation. Refactored user authentication and error handling in API routes to ensure user and team existence before proceeding, and improved error reporting for missing or invalid user/team information.
1 parent 48b3f26 commit 4b1d309

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

src/backend/common/utils/utils_af.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ async def rai_success(description: str, team_config: TeamConfiguration, memory_
145145
pass
146146

147147

148-
async def rai_validate_team_config(team_config_json: dict) -> tuple[bool, str]:
148+
async def rai_validate_team_config(team_config_json: dict, team_config: TeamConfiguration, memory_store: DatabaseBase) -> tuple[bool, str]:
149149
"""
150150
Validate a team configuration for RAI compliance.
151151
@@ -187,7 +187,7 @@ async def rai_validate_team_config(team_config_json: dict) -> tuple[bool, str]:
187187
if not combined:
188188
return False, "Team configuration contains no readable text content."
189189

190-
if not await rai_success(combined):
190+
if not await rai_success(combined, team_config, memory_store):
191191
return (
192192
False,
193193
"Team configuration contains inappropriate content and cannot be uploaded.",

src/backend/v4/api/router.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ async def process_request(
229229
type: string
230230
description: Error message
231231
"""
232+
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
233+
user_id = authenticated_user["user_principal_id"]
234+
if not user_id:
235+
track_event_if_configured(
236+
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
237+
)
238+
raise HTTPException(status_code=400, detail="no user found")
232239
try:
233240
memory_store = await DatabaseFactory.get_database(user_id=user_id)
234241
user_current_team = await memory_store.get_current_team(user_id=user_id)
@@ -261,20 +268,6 @@ async def process_request(
261268
detail="Request contains content that doesn't meet our safety guidelines, try again.",
262269
)
263270

264-
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
265-
user_id = authenticated_user["user_principal_id"]
266-
267-
if not user_id:
268-
track_event_if_configured(
269-
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
270-
)
271-
raise HTTPException(status_code=400, detail="no user found")
272-
273-
# if not input_task.team_id:
274-
# track_event_if_configured(
275-
# "TeamIDNofound", {"status_code": 400, "detail": "no team id"}
276-
# )
277-
# raise HTTPException(status_code=400, detail="no team id")
278271

279272
if not input_task.session_id:
280273
input_task.session_id = str(uuid.uuid4())
@@ -315,12 +308,9 @@ async def process_request(
315308
"error": str(e),
316309
},
317310
)
318-
raise HTTPException(status_code=500, detail="Failed to create plan")
311+
raise HTTPException(status_code=500, detail="Failed to create plan") from e
319312

320313
try:
321-
# background_tasks.add_task(
322-
# lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task)
323-
# )
324314

325315
async def run_orchestration_task():
326316
await OrchestrationManager().run_orchestration(user_id, input_task)
@@ -714,10 +704,27 @@ async def upload_team_config(
714704
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
715705
user_id = authenticated_user["user_principal_id"]
716706
if not user_id:
707+
track_event_if_configured(
708+
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
709+
)
710+
raise HTTPException(status_code=400, detail="no user found")
711+
try:
712+
memory_store = await DatabaseFactory.get_database(user_id=user_id)
713+
user_current_team = await memory_store.get_current_team(user_id=user_id)
714+
team_id = None
715+
if user_current_team:
716+
team_id = user_current_team.team_id
717+
team = await memory_store.get_team_by_id(team_id=team_id)
718+
if not team:
719+
raise HTTPException(
720+
status_code=404,
721+
detail=f"Team configuration '{team_id}' not found or access denied",
722+
)
723+
except Exception as e:
717724
raise HTTPException(
718-
status_code=401, detail="Missing or invalid user information"
719-
)
720-
725+
status_code=400,
726+
detail=f"Error retrieving team configuration: {e}",
727+
) from e
721728
# Validate file is provided and is JSON
722729
if not file:
723730
raise HTTPException(status_code=400, detail="No file provided")
@@ -737,7 +744,7 @@ async def upload_team_config(
737744

738745
# Validate content with RAI before processing
739746
if not team_id:
740-
rai_valid, rai_error = await rai_validate_team_config(json_data)
747+
rai_valid, rai_error = await rai_validate_team_config(json_data, team, memory_store)
741748
if not rai_valid:
742749
track_event_if_configured(
743750
"Team configuration RAI validation failed",
@@ -754,8 +761,6 @@ async def upload_team_config(
754761
"Team configuration RAI validation passed",
755762
{"status": "passed", "user_id": user_id, "filename": file.filename},
756763
)
757-
# Initialize memory store and service
758-
memory_store = await DatabaseFactory.get_database(user_id=user_id)
759764
team_service = TeamService(memory_store)
760765

761766
# Validate model deployments

0 commit comments

Comments
 (0)