Skip to content

Commit 8a6466d

Browse files
committed
Implement User ID from Auth Handling
This was kind of halfway done before but this should get it closer.
1 parent a3b530d commit 8a6466d

File tree

9 files changed

+27
-54
lines changed

9 files changed

+27
-54
lines changed

src/app/endpoints/feedback.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Handler for REST API call to provide info."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55
from pathlib import Path
66
import json
77
from datetime import datetime, UTC
8-
98
from fastapi import APIRouter, Request, HTTPException, Depends, status
109

1110
from auth import get_auth_dependency
11+
from auth.interface import AuthTuple
1212
from configuration import configuration
1313
from models.responses import (
1414
FeedbackResponse,
@@ -18,7 +18,6 @@
1818
)
1919
from models.requests import FeedbackRequest
2020
from utils.suid import get_suid
21-
from utils.common import retrieve_user_id
2221

2322
logger = logging.getLogger(__name__)
2423
router = APIRouter(prefix="/feedback", tags=["feedback"])
@@ -66,10 +65,9 @@ async def assert_feedback_enabled(_request: Request) -> None:
6665

6766
@router.post("", responses=feedback_response)
6867
def feedback_endpoint_handler(
69-
_request: Request,
7068
feedback_request: FeedbackRequest,
69+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
7170
_ensure_feedback_enabled: Any = Depends(assert_feedback_enabled),
72-
auth: Any = Depends(auth_dependency),
7371
) -> FeedbackResponse:
7472
"""Handle feedback requests.
7573
@@ -85,7 +83,7 @@ def feedback_endpoint_handler(
8583
"""
8684
logger.debug("Feedback received %s", str(feedback_request))
8785

88-
user_id = retrieve_user_id(auth)
86+
user_id, _, _ = auth
8987
try:
9088
store_feedback(user_id, feedback_request.model_dump(exclude={"model_config"}))
9189
except Exception as e:

src/app/endpoints/query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
from pathlib import Path
9-
from typing import Any
9+
from typing import Annotated, Any
1010

1111
from llama_stack_client.lib.agents.agent import Agent
1212
from llama_stack_client import APIConnectionError
@@ -27,7 +27,7 @@
2727
from models.requests import QueryRequest, Attachment
2828
import constants
2929
from auth import get_auth_dependency
30-
from utils.common import retrieve_user_id
30+
from auth.interface import AuthTuple
3131
from utils.endpoints import check_configuration_loaded, get_system_prompt
3232
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3333
from utils.suid import get_suid
@@ -113,7 +113,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
113113
@router.post("/query", responses=query_response)
114114
def query_endpoint_handler(
115115
query_request: QueryRequest,
116-
auth: Any = Depends(auth_dependency),
116+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
117117
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
118118
) -> QueryResponse:
119119
"""Handle request to the /query endpoint."""
@@ -122,7 +122,7 @@ def query_endpoint_handler(
122122
llama_stack_config = configuration.llama_stack_configuration
123123
logger.info("LLama stack config: %s", llama_stack_config)
124124

125-
_user_id, _user_name, token = auth
125+
user_id, _, token = auth
126126

127127
try:
128128
# try to get Llama Stack client
@@ -144,7 +144,7 @@ def query_endpoint_handler(
144144
logger.debug("Transcript collection is disabled in the configuration")
145145
else:
146146
store_transcript(
147-
user_id=retrieve_user_id(auth),
147+
user_id=user_id,
148148
conversation_id=conversation_id,
149149
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
150150
query=query_request.query,

src/app/endpoints/streaming_query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import re
77
import logging
8-
from typing import Any, AsyncIterator, Iterator
8+
from typing import Annotated, Any, AsyncIterator, Iterator
99

1010
from llama_stack_client import APIConnectionError
1111
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
@@ -20,12 +20,12 @@
2020
from fastapi.responses import StreamingResponse
2121

2222
from auth import get_auth_dependency
23+
from auth.interface import AuthTuple
2324
from client import AsyncLlamaStackClientHolder
2425
from configuration import configuration
2526
import metrics
2627
from models.requests import QueryRequest
2728
from utils.endpoints import check_configuration_loaded, get_system_prompt
28-
from utils.common import retrieve_user_id
2929
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3030
from utils.suid import get_suid
3131
from utils.types import GraniteToolParser
@@ -415,7 +415,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
415415
async def streaming_query_endpoint_handler(
416416
_request: Request,
417417
query_request: QueryRequest,
418-
auth: Any = Depends(auth_dependency),
418+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
419419
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
420420
) -> StreamingResponse:
421421
"""Handle request to the /streaming_query endpoint."""
@@ -424,7 +424,7 @@ async def streaming_query_endpoint_handler(
424424
llama_stack_config = configuration.llama_stack_configuration
425425
logger.info("LLama stack config: %s", llama_stack_config)
426426

427-
_user_id, _user_name, token = auth
427+
user_id, _user_name, token = auth
428428

429429
try:
430430
# try to get Llama Stack client
@@ -463,7 +463,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
463463
logger.debug("Transcript collection is disabled in the configuration")
464464
else:
465465
store_transcript(
466-
user_id=retrieve_user_id(auth),
466+
user_id=user_id,
467467
conversation_id=conversation_id,
468468
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
469469
query=query_request.query,

src/auth/interface.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44

55
from fastapi import Request
66

7+
UserID = str
8+
UserName = str
9+
Token = str
10+
11+
AuthTuple = tuple[UserID, UserName, Token]
12+
713

814
class AuthInterface(ABC): # pylint: disable=too-few-public-methods
915
"""Base class for all authentication method implementations."""
1016

1117
@abstractmethod
12-
async def __call__(self, request: Request) -> tuple[str, str, str]:
18+
async def __call__(self, request: Request) -> AuthTuple:
1319
"""Validate FastAPI Requests for authentication and authorization."""

src/utils/common.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import asyncio
44
from functools import wraps
5+
from typing import Any, Callable, List, cast
56
from logging import Logger
6-
from typing import Any, List, cast, Callable
77

88
from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
99
from llama_stack.distribution.library_client import (
@@ -14,20 +14,6 @@
1414
from models.config import Configuration, ModelContextProtocolServer
1515

1616

17-
# TODO(lucasagomes): implement this function to retrieve user ID from auth
18-
def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument
19-
"""Retrieve the user ID from the authentication handler.
20-
21-
Args:
22-
auth: The Authentication handler (FastAPI Depends) that will
23-
handle authentication Logic.
24-
25-
Returns:
26-
str: The user ID.
27-
"""
28-
return "user_id_placeholder"
29-
30-
3117
async def register_mcp_servers_async(
3218
logger: Logger, configuration: Configuration
3319
) -> None:

tests/unit/app/endpoints/test_feedback.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
6767

6868
# Mock the dependencies
6969
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
70-
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
7170
mocker.patch("app.endpoints.feedback.store_feedback", return_value=None)
7271

7372
# Prepare the feedback request mock
@@ -76,8 +75,8 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
7675

7776
# Call the endpoint handler
7877
result = feedback_endpoint_handler(
79-
_request=mocker.Mock(),
8078
feedback_request=feedback_request,
79+
auth=["test-user", "", ""],
8180
_ensure_feedback_enabled=assert_feedback_enabled,
8281
)
8382

@@ -89,7 +88,6 @@ def test_feedback_endpoint_handler_error(mocker):
8988
"""Test that feedback_endpoint_handler raises an HTTPException on error."""
9089
# Mock the dependencies
9190
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
92-
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
9391
mocker.patch(
9492
"app.endpoints.feedback.store_feedback",
9593
side_effect=Exception("Error storing feedback"),
@@ -101,8 +99,8 @@ def test_feedback_endpoint_handler_error(mocker):
10199
# Call the endpoint handler and assert it raises an exception
102100
with pytest.raises(HTTPException) as exc_info:
103101
feedback_endpoint_handler(
104-
_request=mocker.Mock(),
105102
feedback_request=feedback_request,
103+
auth=["test-user", "", ""],
106104
_ensure_feedback_enabled=assert_feedback_enabled,
107105
)
108106

tests/unit/app/endpoints/test_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_query_endpoint_handler_configuration_not_loaded(mocker):
7777

7878
request = None
7979
with pytest.raises(HTTPException) as e:
80-
query_endpoint_handler(request)
80+
query_endpoint_handler(request, auth=["test-user", "", "token"])
8181
assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
8282
assert e.detail["response"] == "Configuration is not loaded"
8383

@@ -152,7 +152,7 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
152152
# Assert the store_transcript function is called if transcripts are enabled
153153
if store_transcript_to_file:
154154
mock_transcript.assert_called_once_with(
155-
user_id="user_id_placeholder",
155+
user_id="mock_user_id",
156156
conversation_id=conversation_id,
157157
query_is_valid=True,
158158
query=query,

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,6 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
230230
"app.endpoints.streaming_query.is_transcripts_enabled",
231231
return_value=store_transcript,
232232
)
233-
mocker.patch(
234-
"app.endpoints.streaming_query.retrieve_user_id",
235-
return_value="user_id_placeholder",
236-
)
237233
mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript")
238234

239235
query_request = QueryRequest(query=query)
@@ -272,7 +268,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
272268
# Assert the store_transcript function is called if transcripts are enabled
273269
if store_transcript:
274270
mock_transcript.assert_called_once_with(
275-
user_id="user_id_placeholder",
271+
user_id="mock_user_id",
276272
conversation_id="test_conversation_id",
277273
query_is_valid=True,
278274
query=query,
@@ -1553,9 +1549,6 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
15531549
mocker.patch(
15541550
"app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
15551551
)
1556-
mocker.patch(
1557-
"app.endpoints.streaming_query.retrieve_user_id", return_value="user123"
1558-
)
15591552

15601553
await streaming_query_endpoint_handler(
15611554
None,

tests/unit/utils/test_common.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77

88
from utils.common import (
9-
retrieve_user_id,
109
register_mcp_servers_async,
1110
)
1211
from models.config import (
@@ -18,13 +17,6 @@
1817
)
1918

2019

21-
# TODO(lucasagomes): Implement this test when the retrieve_user_id function is implemented
22-
def test_retrieve_user_id():
23-
"""Test that retrieve_user_id returns a user ID."""
24-
user_id = retrieve_user_id(None)
25-
assert user_id == "user_id_placeholder"
26-
27-
2820
@pytest.mark.asyncio
2921
async def test_register_mcp_servers_empty_list(mocker):
3022
"""Test register_mcp_servers with empty MCP servers list."""

0 commit comments

Comments
 (0)