Skip to content

Commit accf3a9

Browse files
authored
Merge pull request #325 from ldjebran/fixes-llama-stack-model-id
Fixes llama-stack model-id
2 parents 3cded1c + 7d49636 commit accf3a9

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

src/app/endpoints/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,12 @@ def select_model_and_provider_id(
216216
},
217217
) from e
218218

219+
llama_stack_model_id = f"{provider_id}/{model_id}"
219220
# Validate that the model_id and provider_id are in the available models
220221
logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
221222
if not any(
222-
m.identifier == model_id and m.provider_id == provider_id for m in models
223+
m.identifier == llama_stack_model_id and m.provider_id == provider_id
224+
for m in models
223225
):
224226
message = f"Model {model_id} from provider {provider_id} not found in available models"
225227
logger.error(message)
@@ -231,7 +233,7 @@ def select_model_and_provider_id(
231233
},
232234
)
233235

234-
return model_id, provider_id
236+
return llama_stack_model_id, provider_id
235237

236238

237239
def _is_inout_shield(shield: Shield) -> bool:

tests/unit/app/endpoints/test_query.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,16 @@ def test_select_model_and_provider_id_from_request(mocker):
191191
)
192192

193193
model_list = [
194-
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
195-
mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
196194
mocker.Mock(
197-
identifier="default_model", model_type="llm", provider_id="default_provider"
195+
identifier="provider1/model1", model_type="llm", provider_id="provider1"
196+
),
197+
mocker.Mock(
198+
identifier="provider2/model2", model_type="llm", provider_id="provider2"
199+
),
200+
mocker.Mock(
201+
identifier="default_provider/default_model",
202+
model_type="llm",
203+
provider_id="default_provider",
198204
),
199205
]
200206

@@ -206,7 +212,7 @@ def test_select_model_and_provider_id_from_request(mocker):
206212
# Assert the model and provider from request take precedence from the configuration one
207213
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
208214

209-
assert model_id == "model2"
215+
assert model_id == "provider2/model2"
210216
assert provider_id == "provider2"
211217

212218

@@ -222,9 +228,13 @@ def test_select_model_and_provider_id_from_configuration(mocker):
222228
)
223229

224230
model_list = [
225-
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
226231
mocker.Mock(
227-
identifier="default_model", model_type="llm", provider_id="default_provider"
232+
identifier="provider1/model1", model_type="llm", provider_id="provider1"
233+
),
234+
mocker.Mock(
235+
identifier="default_provider/default_model",
236+
model_type="llm",
237+
provider_id="default_provider",
228238
),
229239
]
230240

@@ -236,7 +246,7 @@ def test_select_model_and_provider_id_from_configuration(mocker):
236246
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
237247

238248
# Assert that the default model and provider from the configuration are returned
239-
assert model_id == "default_model"
249+
assert model_id == "default_provider/default_model"
240250
assert provider_id == "default_provider"
241251

242252

0 commit comments

Comments
 (0)