diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml
index 200bad48..96425376 100644
--- a/lightspeed-stack.yaml
+++ b/lightspeed-stack.yaml
@@ -1,4 +1,5 @@
 name: foo bar baz
 llama_stack:
+  use_as_library_client: false
   url: http://localhost:8321
   api_key: xyzzy
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index da4164db..fc1f5264 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -3,11 +3,13 @@
 import logging
 
 from typing import Any
 
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client import LlamaStackClient
 from fastapi import APIRouter, Request
 
 from configuration import configuration
+from models.config import LLamaStackConfiguration
 from models.responses import QueryResponse
 
 logger = logging.getLogger(__name__)
@@ -26,6 +28,4 @@
 def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
-    client = LlamaStackClient(
-        base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
-    )
+    client = get_llama_stack_client(llama_stack_config)
@@ -47,3 +47,16 @@
         ],
     )
     return QueryResponse(query=query, response=str(response.completion_message.content))
+
+
+def get_llama_stack_client(
+    llama_stack_config: LLamaStackConfiguration,
+) -> LlamaStackClient:
+    if llama_stack_config.use_as_library_client:
+        # Run Llama Stack in-process instead of calling a remote server
+        client = LlamaStackAsLibraryClient("ollama")
+        client.initialize()
+        return client
+    return LlamaStackClient(
+        base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
+    )
diff --git a/src/models/config.py b/src/models/config.py
index 606094a9..6ffb8f51 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -1,13 +1,31 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 from typing import Optional
+from typing_extensions import Self
 
 
 class LLamaStackConfiguration(BaseModel):
     """Llama stack configuration."""
 
-    url: str
+    url: Optional[str] = None
     api_key: Optional[str] = None
+    use_as_library_client: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def check_llama_stack_model(self) -> Self:
+        """Ensure either a URL or library client mode is configured."""
+        if self.url is None:
+            if self.use_as_library_client is None:
+                raise ValueError(
+                    "LLama stack URL is not specified and library client mode is not specified"
+                )
+            if self.use_as_library_client is False:
+                raise ValueError(
+                    "LLama stack URL is not specified and library client mode is not enabled"
+                )
+        if self.use_as_library_client is None:
+            self.use_as_library_client = False
+        return self
 
 
 class Configuration(BaseModel):
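
A minimal sketch of how the new LLamaStackConfiguration validator resolves the configuration shapes this change allows (illustrative inputs, not taken from the repository; note that pydantic wraps the ValueError raised inside the validator, and the wrapping ValidationError is itself a ValueError subclass):

    from models.config import LLamaStackConfiguration

    # Server mode: a URL is given and use_as_library_client is left unset,
    # so the validator defaults it to False.
    cfg = LLamaStackConfiguration(url="http://localhost:8321", api_key="xyzzy")
    assert cfg.use_as_library_client is False

    # Library mode: the URL may be omitted once use_as_library_client is True.
    cfg = LLamaStackConfiguration(use_as_library_client=True)
    assert cfg.url is None

    # Neither a URL nor library client mode: the validator rejects the model.
    try:
        LLamaStackConfiguration()
    except ValueError as err:
        print(err)  # "...URL is not specified and library client mode is not specified"

In YAML terms: llama_stack.url may be dropped only when use_as_library_client is set to true; with use_as_library_client: false, as in the lightspeed-stack.yaml hunk above, url stays required.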