117 changes: 109 additions & 8 deletions src/app/endpoints/rlsapi_v1.py
@@ -5,22 +5,31 @@
"""

import logging
from typing import Annotated, Any
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from llama_stack_client import APIConnectionError # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
from llama_stack_client.types.alpha.agents.turn import Turn

import constants
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.config import Action
from models.responses import (
    ForbiddenResponse,
    ServiceUnavailableResponse,
    UnauthorizedResponse,
    UnprocessableEntityResponse,
)
from models.rlsapi.requests import RlsapiV1InferRequest
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from utils.endpoints import get_temp_agent
from utils.suid import get_suid
from utils.types import content_to_str

logger = logging.getLogger(__name__)
router = APIRouter(tags=["rlsapi-v1"])
@@ -33,9 +42,84 @@
    ),
    403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
    422: UnprocessableEntityResponse.openapi_response(),
    503: ServiceUnavailableResponse.openapi_response(),
}


def _get_default_model_id() -> str:
    """Get the default model ID from configuration.

    Returns the model identifier in Llama Stack format (provider/model).

    Returns:
        The model identifier string.

    Raises:
        HTTPException: If no model can be determined from configuration.
    """
    if configuration.inference is None:
        msg = "No inference configuration available"
        logger.error(msg)
        raise HTTPException(
            status_code=503,
            detail={"response": "Service configuration error", "cause": msg},
        )

    model_id = configuration.inference.default_model
    provider_id = configuration.inference.default_provider

    if model_id and provider_id:
        return f"{provider_id}/{model_id}"

    msg = "No default model configured for rlsapi v1 inference"
    logger.error(msg)
    raise HTTPException(
        status_code=503,
        detail={"response": "Service configuration error", "cause": msg},
    )
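
# Illustrative note, not part of this PR's diff: with a hypothetical configuration
# of default_provider = "openai" and default_model = "gpt-4-turbo" (example values
# only), the helper above returns "openai/gpt-4-turbo". If either value, or the
# whole inference section, is missing, it raises a 503 HTTPException rather than
# guessing a model.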


async def retrieve_simple_response(question: str) -> str:
    """Retrieve a simple response from the LLM for a stateless query.

    Creates a temporary agent, sends a single turn with the user's question,
    and returns the LLM response text. No conversation persistence or tools.

    Args:
        question: The combined user input (question + context).

    Returns:
        The LLM-generated response text.

    Raises:
        APIConnectionError: If the Llama Stack service is unreachable.
        HTTPException: 503 if no model is configured.
    """
    client = AsyncLlamaStackClientHolder().get_client()
    model_id = _get_default_model_id()

    logger.debug("Using model %s for rlsapi v1 inference", model_id)

    agent, session_id, _ = await get_temp_agent(
        client, model_id, constants.DEFAULT_SYSTEM_PROMPT
    )

    response = await agent.create_turn(
        messages=[UserMessage(role="user", content=question).model_dump()],
        session_id=session_id,
        stream=False,
    )
    response = cast(Turn, response)

    if getattr(response, "output_message", None) is None:
        return ""

    if getattr(response.output_message, "content", None) is None:
        return ""

    return content_to_str(response.output_message.content)
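
# Illustrative sketch, not part of this PR's diff: one way to exercise
# retrieve_simple_response() from a standalone async context. It assumes
# AsyncLlamaStackClientHolder().get_client() can return an initialized client
# (typically set up at application startup) and that a default provider/model
# pair is configured, just as the endpoint requires. The question text is a
# hypothetical example.
import asyncio


async def _demo_retrieve_simple_response() -> None:
    # Stateless, single-turn query: no tools and no persisted conversation.
    answer = await retrieve_simple_response("How do I list failed systemd units?")
    print(answer)


if __name__ == "__main__":
    asyncio.run(_demo_retrieve_simple_response())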


@router.post("/infer", responses=infer_responses)
@authorize(Action.RLSAPI_V1_INFER)
async def infer_endpoint(
@@ -55,6 +139,9 @@ async def infer_endpoint(

    Returns:
        RlsapiV1InferResponse containing the generated response text and request ID.

    Raises:
        HTTPException: 503 if the LLM service is unavailable.
    """
    # Authentication enforced by get_auth_dependency(), authorization by @authorize decorator.
    _ = auth
@@ -66,14 +153,28 @@

    # Combine all input sources (question, stdin, attachments, terminal)
    input_source = infer_request.get_input_source()
    logger.debug("Combined input source length: %d", len(input_source))

    # NOTE(major): Placeholder until we wire up the LLM integration.
    response_text = (
        "Inference endpoint is functional. "
        "LLM integration will be added in a subsequent update."
    )
    logger.debug(
        "Request %s: Combined input source length: %d", request_id, len(input_source)
    )

    try:
        response_text = await retrieve_simple_response(input_source)
    except APIConnectionError as e:
        logger.error(
            "Unable to connect to Llama Stack for request %s: %s", request_id, e
        )
        response = ServiceUnavailableResponse(
            backend_name="Llama Stack",
            cause=str(e),
        )
        raise HTTPException(**response.model_dump()) from e

    if not response_text:
        logger.warning("Empty response from LLM for request %s", request_id)
        response_text = constants.UNABLE_TO_PROCESS_RESPONSE

    logger.info("Completed rlsapi v1 /infer request %s", request_id)

    return RlsapiV1InferResponse(
        data=RlsapiV1InferData(
            text=response_text,
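
# Illustrative note, not part of this PR's diff: the rough shape of a successful
# call to this endpoint once the LLM integration above lands. The mount path and
# the request/response field names below are assumptions based on the code above,
# not values confirmed by this diff.
#
#   POST /infer
#   {"question": "Why did my service fail to start?"}
#
#   200 OK
#   {"data": {"text": "<LLM-generated answer>", "request_id": "<uuid>"}, ...}
#
# A Llama Stack connection failure maps to a 503 ServiceUnavailableResponse, and
# an empty LLM reply is replaced with constants.UNABLE_TO_PROCESS_RESPONSE.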