Skip to content

Commit bf07451

Browse files
authored
Merge pull request #343 from keitwb/user-id-from-auth
Implement User ID from Auth Handling
2 parents e7c5cfc + 8a6466d commit bf07451

File tree

9 files changed

+27
-54
lines changed

9 files changed

+27
-54
lines changed

src/app/endpoints/feedback.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Handler for REST API call to provide info."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55
from pathlib import Path
66
import json
77
from datetime import datetime, UTC
8-
98
from fastapi import APIRouter, Request, HTTPException, Depends, status
109

1110
from auth import get_auth_dependency
11+
from auth.interface import AuthTuple
1212
from configuration import configuration
1313
from models.responses import (
1414
FeedbackResponse,
@@ -18,7 +18,6 @@
1818
)
1919
from models.requests import FeedbackRequest
2020
from utils.suid import get_suid
21-
from utils.common import retrieve_user_id
2221

2322
logger = logging.getLogger(__name__)
2423
router = APIRouter(prefix="/feedback", tags=["feedback"])
@@ -66,10 +65,9 @@ async def assert_feedback_enabled(_request: Request) -> None:
6665

6766
@router.post("", responses=feedback_response)
6867
def feedback_endpoint_handler(
69-
_request: Request,
7068
feedback_request: FeedbackRequest,
69+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
7170
_ensure_feedback_enabled: Any = Depends(assert_feedback_enabled),
72-
auth: Any = Depends(auth_dependency),
7371
) -> FeedbackResponse:
7472
"""Handle feedback requests.
7573
@@ -85,7 +83,7 @@ def feedback_endpoint_handler(
8583
"""
8684
logger.debug("Feedback received %s", str(feedback_request))
8785

88-
user_id = retrieve_user_id(auth)
86+
user_id, _, _ = auth
8987
try:
9088
store_feedback(user_id, feedback_request.model_dump(exclude={"model_config"}))
9189
except Exception as e:

src/app/endpoints/query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
from pathlib import Path
9-
from typing import Any
9+
from typing import Annotated, Any
1010

1111
from llama_stack_client.lib.agents.agent import Agent
1212
from llama_stack_client import APIConnectionError
@@ -27,7 +27,7 @@
2727
from models.requests import QueryRequest, Attachment
2828
import constants
2929
from auth import get_auth_dependency
30-
from utils.common import retrieve_user_id
30+
from auth.interface import AuthTuple
3131
from utils.endpoints import check_configuration_loaded, get_system_prompt
3232
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3333
from utils.suid import get_suid
@@ -116,7 +116,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
116116
@router.post("/query", responses=query_response)
117117
def query_endpoint_handler(
118118
query_request: QueryRequest,
119-
auth: Any = Depends(auth_dependency),
119+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
120120
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
121121
) -> QueryResponse:
122122
"""Handle request to the /query endpoint."""
@@ -125,7 +125,7 @@ def query_endpoint_handler(
125125
llama_stack_config = configuration.llama_stack_configuration
126126
logger.info("LLama stack config: %s", llama_stack_config)
127127

128-
_user_id, _user_name, token = auth
128+
user_id, _, token = auth
129129

130130
try:
131131
# try to get Llama Stack client
@@ -147,7 +147,7 @@ def query_endpoint_handler(
147147
logger.debug("Transcript collection is disabled in the configuration")
148148
else:
149149
store_transcript(
150-
user_id=retrieve_user_id(auth),
150+
user_id=user_id,
151151
conversation_id=conversation_id,
152152
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
153153
query=query_request.query,

src/app/endpoints/streaming_query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import re
77
import logging
8-
from typing import Any, AsyncIterator, Iterator
8+
from typing import Annotated, Any, AsyncIterator, Iterator
99

1010
from llama_stack_client import APIConnectionError
1111
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
@@ -20,12 +20,12 @@
2020
from fastapi.responses import StreamingResponse
2121

2222
from auth import get_auth_dependency
23+
from auth.interface import AuthTuple
2324
from client import AsyncLlamaStackClientHolder
2425
from configuration import configuration
2526
import metrics
2627
from models.requests import QueryRequest
2728
from utils.endpoints import check_configuration_loaded, get_system_prompt
28-
from utils.common import retrieve_user_id
2929
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3030
from utils.suid import get_suid
3131
from utils.types import GraniteToolParser
@@ -431,7 +431,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
431431
async def streaming_query_endpoint_handler(
432432
_request: Request,
433433
query_request: QueryRequest,
434-
auth: Any = Depends(auth_dependency),
434+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
435435
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
436436
) -> StreamingResponse:
437437
"""Handle request to the /streaming_query endpoint."""
@@ -440,7 +440,7 @@ async def streaming_query_endpoint_handler(
440440
llama_stack_config = configuration.llama_stack_configuration
441441
logger.info("LLama stack config: %s", llama_stack_config)
442442

443-
_user_id, _user_name, token = auth
443+
user_id, _user_name, token = auth
444444

445445
try:
446446
# try to get Llama Stack client
@@ -483,7 +483,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
483483
logger.debug("Transcript collection is disabled in the configuration")
484484
else:
485485
store_transcript(
486-
user_id=retrieve_user_id(auth),
486+
user_id=user_id,
487487
conversation_id=conversation_id,
488488
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
489489
query=query_request.query,

src/auth/interface.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44

55
from fastapi import Request
66

7+
UserID = str
8+
UserName = str
9+
Token = str
10+
11+
AuthTuple = tuple[UserID, UserName, Token]
12+
713

814
class AuthInterface(ABC): # pylint: disable=too-few-public-methods
915
"""Base class for all authentication method implementations."""
1016

1117
@abstractmethod
12-
async def __call__(self, request: Request) -> tuple[str, str, str]:
18+
async def __call__(self, request: Request) -> AuthTuple:
1319
"""Validate FastAPI Requests for authentication and authorization."""

src/utils/common.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import asyncio
44
from functools import wraps
5+
from typing import Any, Callable, List, cast
56
from logging import Logger
6-
from typing import Any, List, cast, Callable
77

88
from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
99
from llama_stack.distribution.library_client import (
@@ -14,20 +14,6 @@
1414
from models.config import Configuration, ModelContextProtocolServer
1515

1616

17-
# TODO(lucasagomes): implement this function to retrieve user ID from auth
18-
def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument
19-
"""Retrieve the user ID from the authentication handler.
20-
21-
Args:
22-
auth: The Authentication handler (FastAPI Depends) that will
23-
handle authentication Logic.
24-
25-
Returns:
26-
str: The user ID.
27-
"""
28-
return "user_id_placeholder"
29-
30-
3117
async def register_mcp_servers_async(
3218
logger: Logger, configuration: Configuration
3319
) -> None:

tests/unit/app/endpoints/test_feedback.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
6767

6868
# Mock the dependencies
6969
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
70-
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
7170
mocker.patch("app.endpoints.feedback.store_feedback", return_value=None)
7271

7372
# Prepare the feedback request mock
@@ -76,8 +75,8 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
7675

7776
# Call the endpoint handler
7877
result = feedback_endpoint_handler(
79-
_request=mocker.Mock(),
8078
feedback_request=feedback_request,
79+
auth=["test-user", "", ""],
8180
_ensure_feedback_enabled=assert_feedback_enabled,
8281
)
8382

@@ -89,7 +88,6 @@ def test_feedback_endpoint_handler_error(mocker):
8988
"""Test that feedback_endpoint_handler raises an HTTPException on error."""
9089
# Mock the dependencies
9190
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
92-
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
9391
mocker.patch(
9492
"app.endpoints.feedback.store_feedback",
9593
side_effect=Exception("Error storing feedback"),
@@ -101,8 +99,8 @@ def test_feedback_endpoint_handler_error(mocker):
10199
# Call the endpoint handler and assert it raises an exception
102100
with pytest.raises(HTTPException) as exc_info:
103101
feedback_endpoint_handler(
104-
_request=mocker.Mock(),
105102
feedback_request=feedback_request,
103+
auth=["test-user", "", ""],
106104
_ensure_feedback_enabled=assert_feedback_enabled,
107105
)
108106

tests/unit/app/endpoints/test_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_query_endpoint_handler_configuration_not_loaded(mocker):
7777

7878
request = None
7979
with pytest.raises(HTTPException) as e:
80-
query_endpoint_handler(request)
80+
query_endpoint_handler(request, auth=["test-user", "", "token"])
8181
assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
8282
assert e.detail["response"] == "Configuration is not loaded"
8383

@@ -152,7 +152,7 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
152152
# Assert the store_transcript function is called if transcripts are enabled
153153
if store_transcript_to_file:
154154
mock_transcript.assert_called_once_with(
155-
user_id="user_id_placeholder",
155+
user_id="mock_user_id",
156156
conversation_id=conversation_id,
157157
query_is_valid=True,
158158
query=query,

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,6 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
261261
"app.endpoints.streaming_query.is_transcripts_enabled",
262262
return_value=store_transcript,
263263
)
264-
mocker.patch(
265-
"app.endpoints.streaming_query.retrieve_user_id",
266-
return_value="user_id_placeholder",
267-
)
268264
mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript")
269265

270266
query_request = QueryRequest(query=query)
@@ -303,7 +299,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
303299
# Assert the store_transcript function is called if transcripts are enabled
304300
if store_transcript:
305301
mock_transcript.assert_called_once_with(
306-
user_id="user_id_placeholder",
302+
user_id="mock_user_id",
307303
conversation_id="test_conversation_id",
308304
query_is_valid=True,
309305
query=query,
@@ -1583,9 +1579,6 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
15831579
mocker.patch(
15841580
"app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
15851581
)
1586-
mocker.patch(
1587-
"app.endpoints.streaming_query.retrieve_user_id", return_value="user123"
1588-
)
15891582

15901583
await streaming_query_endpoint_handler(
15911584
None,

tests/unit/utils/test_common.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77

88
from utils.common import (
9-
retrieve_user_id,
109
register_mcp_servers_async,
1110
)
1211
from models.config import (
@@ -18,13 +17,6 @@
1817
)
1918

2019

21-
# TODO(lucasagomes): Implement this test when the retrieve_user_id function is implemented
22-
def test_retrieve_user_id():
23-
"""Test that retrieve_user_id returns a user ID."""
24-
user_id = retrieve_user_id(None)
25-
assert user_id == "user_id_placeholder"
26-
27-
2820
@pytest.mark.asyncio
2921
async def test_register_mcp_servers_empty_list(mocker):
3022
"""Test register_mcp_servers with empty MCP servers list."""

0 commit comments

Comments
 (0)