Skip to content

Commit 2cc5e93

Browse files
authored
Merge pull request #187 from manstis/LCORE-323
LCORE-323: LlamaStackClient should be a singleton
2 parents d678ede + dacd970 commit 2cc5e93

File tree

14 files changed

+156
-127
lines changed

14 files changed

+156
-127
lines changed

src/app/endpoints/health.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,7 @@
1111
from llama_stack.providers.datatypes import HealthStatus
1212

1313
from fastapi import APIRouter, status, Response
14-
from client import get_llama_stack_client
15-
from configuration import configuration
14+
from client import LlamaStackClientHolder
1615
from models.responses import (
1716
LivenessResponse,
1817
ReadinessResponse,
@@ -30,9 +29,7 @@ def get_providers_health_statuses() -> list[ProviderHealthStatus]:
3029
List of provider health statuses.
3130
"""
3231
try:
33-
llama_stack_config = configuration.llama_stack_configuration
34-
35-
client = get_llama_stack_client(llama_stack_config)
32+
client = LlamaStackClientHolder().get_client()
3633

3734
providers = client.providers.list()
3835
logger.debug("Found %d providers", len(providers))

src/app/endpoints/models.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from llama_stack_client import APIConnectionError
77
from fastapi import APIRouter, HTTPException, Request, status
88

9-
from client import get_llama_stack_client
9+
from client import LlamaStackClientHolder
1010
from configuration import configuration
1111
from models.responses import ModelsResponse
1212
from utils.endpoints import check_configuration_loaded
@@ -52,11 +52,12 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse:
5252

5353
try:
5454
# try to get Llama Stack client
55-
client = get_llama_stack_client(llama_stack_configuration)
55+
client = LlamaStackClientHolder().get_client()
5656
# retrieve models
5757
models = client.models.list()
5858
m = [dict(m) for m in models]
5959
return ModelsResponse(models=m)
60+
6061
# connection to Llama Stack server
6162
except APIConnectionError as e:
6263
logger.error("Unable to connect to Llama Stack: %s", e)

src/app/endpoints/query.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
from fastapi import APIRouter, HTTPException, status, Depends
2323

24-
from client import get_llama_stack_client
24+
from client import LlamaStackClientHolder
2525
from configuration import configuration
2626
from models.responses import QueryResponse
2727
from models.requests import QueryRequest, Attachment
@@ -104,7 +104,7 @@ def query_endpoint_handler(
104104

105105
try:
106106
# try to get Llama Stack client
107-
client = get_llama_stack_client(llama_stack_config)
107+
client = LlamaStackClientHolder().get_client()
108108
model_id = select_model_id(client.models.list(), query_request)
109109
response, conversation_id = retrieve_response(
110110
client,
@@ -130,6 +130,7 @@ def query_endpoint_handler(
130130
)
131131

132132
return QueryResponse(conversation_id=conversation_id, response=response)
133+
133134
# connection to Llama Stack server
134135
except APIConnectionError as e:
135136
logger.error("Unable to connect to Llama Stack: %s", e)

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from fastapi.responses import StreamingResponse
1818

1919
from auth import get_auth_dependency
20-
from client import get_async_llama_stack_client
20+
from client import AsyncLlamaStackClientHolder
2121
from configuration import configuration
2222
from models.requests import QueryRequest
2323
from utils.endpoints import check_configuration_loaded, get_system_prompt
@@ -197,7 +197,7 @@ async def streaming_query_endpoint_handler(
197197

198198
try:
199199
# try to get Llama Stack client
200-
client = await get_async_llama_stack_client(llama_stack_config)
200+
client = AsyncLlamaStackClientHolder().get_client()
201201
model_id = select_model_id(await client.models.list(), query_request)
202202
response, conversation_id = await retrieve_response(
203203
client,

src/client.py

Lines changed: 71 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -2,55 +2,86 @@
22

33
import logging
44

5+
from typing import Optional
6+
57
from llama_stack.distribution.library_client import (
68
AsyncLlamaStackAsLibraryClient, # type: ignore
79
LlamaStackAsLibraryClient, # type: ignore
810
)
911
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient # type: ignore
1012
from models.config import LLamaStackConfiguration
13+
from utils.types import Singleton
14+
1115

1216
logger = logging.getLogger(__name__)
1317

1418

15-
def get_llama_stack_client(
16-
llama_stack_config: LLamaStackConfiguration,
17-
) -> LlamaStackClient:
18-
"""Retrieve Llama stack client according to configuration."""
19-
if llama_stack_config.use_as_library_client is True:
20-
if llama_stack_config.library_client_config_path is not None:
21-
logger.info("Using Llama stack as library client")
22-
client = LlamaStackAsLibraryClient(
23-
llama_stack_config.library_client_config_path
19+
class LlamaStackClientHolder(metaclass=Singleton):
20+
"""Container for an initialised LlamaStackClient."""
21+
22+
_lsc: Optional[LlamaStackClient] = None
23+
24+
def load(self, llama_stack_config: LLamaStackConfiguration) -> None:
25+
"""Retrieve Llama stack client according to configuration."""
26+
if llama_stack_config.use_as_library_client is True:
27+
if llama_stack_config.library_client_config_path is not None:
28+
logger.info("Using Llama stack as library client")
29+
client = LlamaStackAsLibraryClient(
30+
llama_stack_config.library_client_config_path
31+
)
32+
client.initialize()
33+
self._lsc = client
34+
else:
35+
msg = "Configuration problem: library_client_config_path option is not set"
36+
logger.error(msg)
37+
# tisnik: use custom exception there - with cause etc.
38+
raise ValueError(msg)
39+
40+
else:
41+
logger.info("Using Llama stack running as a service")
42+
self._lsc = LlamaStackClient(
43+
base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
44+
)
45+
46+
def get_client(self) -> LlamaStackClient:
47+
"""Return an initialised LlamaStackClient."""
48+
if not self._lsc:
49+
raise RuntimeError(
50+
"LlamaStackClient has not been initialised. Ensure 'load(..)' has been called."
2451
)
25-
client.initialize()
26-
return client
27-
msg = "Configuration problem: library_client_config_path option is not set"
28-
logger.error(msg)
29-
# tisnik: use custom exception there - with cause etc.
30-
raise Exception(msg) # pylint: disable=broad-exception-raised
31-
logger.info("Using Llama stack running as a service")
32-
return LlamaStackClient(
33-
base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
34-
)
35-
36-
37-
async def get_async_llama_stack_client(
38-
llama_stack_config: LLamaStackConfiguration,
39-
) -> AsyncLlamaStackClient:
40-
"""Retrieve Async Llama stack client according to configuration."""
41-
if llama_stack_config.use_as_library_client is True:
42-
if llama_stack_config.library_client_config_path is not None:
43-
logger.info("Using Llama stack as library client")
44-
client = AsyncLlamaStackAsLibraryClient(
45-
llama_stack_config.library_client_config_path
52+
return self._lsc
53+
54+
55+
class AsyncLlamaStackClientHolder(metaclass=Singleton):
56+
"""Container for an initialised AsyncLlamaStackClient."""
57+
58+
_lsc: Optional[AsyncLlamaStackClient] = None
59+
60+
async def load(self, llama_stack_config: LLamaStackConfiguration) -> None:
61+
"""Retrieve Async Llama stack client according to configuration."""
62+
if llama_stack_config.use_as_library_client is True:
63+
if llama_stack_config.library_client_config_path is not None:
64+
logger.info("Using Llama stack as library client")
65+
client = AsyncLlamaStackAsLibraryClient(
66+
llama_stack_config.library_client_config_path
67+
)
68+
await client.initialize()
69+
self._lsc = client
70+
else:
71+
msg = "Configuration problem: library_client_config_path option is not set"
72+
logger.error(msg)
73+
# tisnik: use custom exception there - with cause etc.
74+
raise ValueError(msg)
75+
else:
76+
logger.info("Using Llama stack running as a service")
77+
self._lsc = AsyncLlamaStackClient(
78+
base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
79+
)
80+
81+
def get_client(self) -> AsyncLlamaStackClient:
82+
"""Return an initialised AsyncLlamaStackClient."""
83+
if not self._lsc:
84+
raise RuntimeError(
85+
"AsyncLlamaStackClient has not been initialised. Ensure 'load(..)' has been called."
4686
)
47-
await client.initialize()
48-
return client
49-
msg = "Configuration problem: library_client_config_path option is not set"
50-
logger.error(msg)
51-
# tisnik: use custom exception there - with cause etc.
52-
raise Exception(msg) # pylint: disable=broad-exception-raised
53-
logger.info("Using Llama stack running as a service")
54-
return AsyncLlamaStackClient(
55-
base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
56-
)
87+
return self._lsc

src/lightspeed_stack.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,13 @@
55
"""
66

77
from argparse import ArgumentParser
8+
import asyncio
89
import logging
9-
1010
from rich.logging import RichHandler
1111

1212
from runners.uvicorn import start_uvicorn
1313
from configuration import configuration
14-
14+
from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder
1515

1616
FORMAT = "%(message)s"
1717
logging.basicConfig(
@@ -61,6 +61,12 @@ def main() -> None:
6161
logger.info(
6262
"Llama stack configuration: %s", configuration.llama_stack_configuration
6363
)
64+
logger.info("Creating LlamaStackClient")
65+
LlamaStackClientHolder().load(configuration.configuration.llama_stack)
66+
logger.info("Creating AsyncLlamaStackClient")
67+
asyncio.run(
68+
AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack)
69+
)
6470

6571
if args.dump_configuration:
6672
configuration.configuration.dump()

src/utils/common.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,13 @@
33
from typing import Any, List, cast
44
from logging import Logger
55

6-
from llama_stack_client import LlamaStackClient
6+
from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
77

88
from llama_stack.distribution.library_client import (
9-
LlamaStackAsLibraryClient,
109
AsyncLlamaStackAsLibraryClient,
1110
)
1211

13-
from client import get_llama_stack_client
12+
from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder
1413
from models.config import Configuration, ModelContextProtocolServer
1514

1615

@@ -39,24 +38,19 @@ async def register_mcp_servers_async(
3938

4039
if configuration.llama_stack.use_as_library_client:
4140
# Library client - use async interface
42-
# config.py validation ensures library_client_config_path is not None
43-
# when use_as_library_client is True
44-
config_path = cast(str, configuration.llama_stack.library_client_config_path)
45-
client = LlamaStackAsLibraryClient(config_path)
46-
await client.async_client.initialize()
47-
48-
await _register_mcp_toolgroups_async(
49-
client.async_client, configuration.mcp_servers, logger
41+
client = cast(
42+
AsyncLlamaStackAsLibraryClient, AsyncLlamaStackClientHolder().get_client()
5043
)
44+
await client.initialize()
45+
await _register_mcp_toolgroups_async(client, configuration.mcp_servers, logger)
5146
else:
5247
# Service client - use sync interface
53-
client = get_llama_stack_client(configuration.llama_stack)
54-
48+
client = LlamaStackClientHolder().get_client()
5549
_register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger)
5650

5751

5852
async def _register_mcp_toolgroups_async(
59-
client: AsyncLlamaStackAsLibraryClient,
53+
client: AsyncLlamaStackClient,
6054
mcp_servers: List[ModelContextProtocolServer],
6155
logger: Logger,
6256
) -> None:

src/utils/types.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
"""Common types for the project."""
2+
3+
4+
class Singleton(type):
5+
"""Metaclass for Singleton support."""
6+
7+
_instances = {} # type: ignore
8+
9+
def __call__(cls, *args, **kwargs): # type: ignore
10+
"""Ensure a single instance is created."""
11+
if cls not in cls._instances:
12+
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
13+
return cls._instances[cls]

tests/unit/app/endpoints/test_health.py

Lines changed: 4 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -98,14 +98,11 @@ class TestGetProvidersHealthStatuses:
9898
def test_get_providers_health_statuses(self, mocker):
9999
"""Test get_providers_health_statuses with healthy providers."""
100100
# Mock the imports
101-
mock_get_llama_stack_client = mocker.patch(
102-
"app.endpoints.health.get_llama_stack_client"
103-
)
104-
mock_configuration = mocker.patch("app.endpoints.health.configuration")
101+
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
105102

106103
# Mock the client and its methods
107104
mock_client = mocker.Mock()
108-
mock_get_llama_stack_client.return_value = mock_client
105+
mock_lsc.return_value = mock_client
109106

110107
# Mock providers.list() to return providers with health
111108
mock_provider_1 = mocker.Mock()
@@ -136,9 +133,6 @@ def test_get_providers_health_statuses(self, mocker):
136133
]
137134

138135
# Mock configuration
139-
mock_llama_stack_config = mocker.Mock()
140-
mock_configuration.llama_stack_configuration = mock_llama_stack_config
141-
142136
result = get_providers_health_statuses()
143137

144138
assert len(result) == 3
@@ -155,17 +149,10 @@ def test_get_providers_health_statuses(self, mocker):
155149
def test_get_providers_health_statuses_connection_error(self, mocker):
156150
"""Test get_providers_health_statuses when connection fails."""
157151
# Mock the imports
158-
mock_get_llama_stack_client = mocker.patch(
159-
"app.endpoints.health.get_llama_stack_client"
160-
)
161-
mock_configuration = mocker.patch("app.endpoints.health.configuration")
162-
163-
# Mock configuration
164-
mock_llama_stack_config = mocker.Mock()
165-
mock_configuration.llama_stack_configuration = mock_llama_stack_config
152+
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
166153

167154
# Mock get_llama_stack_client to raise an exception
168-
mock_get_llama_stack_client.side_effect = Exception("Connection error")
155+
mock_lsc.side_effect = Exception("Connection error")
169156

170157
result = get_providers_health_statuses()
171158

tests/unit/app/endpoints/test_models.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -124,11 +124,10 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
124124
# Mock the LlamaStack client
125125
mock_client = Mock()
126126
mock_client.models.list.return_value = []
127-
128-
# Mock the LlamaStack client (shouldn't be called directly)
129-
mocker.patch(
130-
"app.endpoints.models.get_llama_stack_client", return_value=mock_client
131-
)
127+
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
128+
mock_lsc.return_value = mock_client
129+
mock_config = mocker.Mock()
130+
mocker.patch("app.endpoints.models.configuration", mock_config)
132131

133132
request = None
134133
response = models_endpoint_handler(request)

0 commit comments

Comments
 (0)