Commit 6f9d622
fix(api): update embeddings signature so inputs and outputs list align (#1161)
See Issue #922. The change is slightly backwards incompatible, but no callsite (in our client codebases or stack-apps) ever passes a depth-2 `List[List[InterleavedContentItem]]`, which is now disallowed.

## Test Plan

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
    --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
    --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
    --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

Also ran `tests/client-sdk/inference/test_embeddings.py`.
1 parent cfa752f commit 6f9d622
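To make the incompatibility concrete, here is a minimal sketch of the call shapes before and after this change (`inference` stands in for any object implementing the Inference API and is illustrative; `TextContentItem` is the text flavor of `InterleavedContentItem`):

```python
from llama_stack.apis.common.content_types import TextContentItem


async def demo(inference, model_id: str) -> None:
    # Still valid: a flat list of strings.
    await inference.embeddings(model_id=model_id, contents=["doc one", "doc two"])

    # Still valid: a flat list of InterleavedContentItem objects.
    await inference.embeddings(
        model_id=model_id,
        contents=[TextContentItem(text="doc one"), TextContentItem(text="doc two")],
    )

    # Disallowed after this change: a depth-2 list. It previously type-checked
    # because InterleavedContent itself includes List[InterleavedContentItem].
    # await inference.embeddings(model_id=model_id, contents=[[TextContentItem(text="a")]])
```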

File tree

17 files changed (+85, -41)

docs/_static/llama-stack-spec.html

Lines changed: 15 additions & 5 deletions
```diff
@@ -4929,11 +4929,21 @@
           "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
         },
         "contents": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/InterleavedContent"
-          },
-          "description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
+          "oneOf": [
+            {
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            },
+            {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/InterleavedContentItem"
+              }
+            }
+          ],
+          "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
         }
       },
       "additionalProperties": false,
```

docs/_static/llama-stack-spec.yaml

Lines changed: 10 additions & 6 deletions
```diff
@@ -3224,13 +3224,17 @@ components:
         The identifier of the model to use. The model must be an embedding model
         registered with Llama Stack and available via the /models endpoint.
       contents:
-        type: array
-        items:
-          $ref: '#/components/schemas/InterleavedContent'
+        oneOf:
+          - type: array
+            items:
+              type: string
+          - type: array
+            items:
+              $ref: '#/components/schemas/InterleavedContentItem'
       description: >-
-        List of contents to generate embeddings for. Note that content can be
-        multimodal. The behavior depends on the model and provider. Some models
-        may only support text.
+        List of contents to generate embeddings for. Each content can be a string
+        or an InterleavedContentItem (and hence can be multimodal). The behavior
+        depends on the model and provider. Some models may only support text.
     additionalProperties: false
     required:
       - model_id
```

llama_stack/apis/inference/inference.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -20,7 +20,7 @@
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ async def chat_completion(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.
 
         :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
```
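Since every implementer of `Inference.embeddings` now receives this union, a provider can normalize it up front. A minimal sketch, assuming an `interleaved_content_as_str` helper in `llama_stack.providers.utils.inference.prompt_adapter` (treat the helper and its module path as assumptions):

```python
from typing import List

from llama_stack.apis.common.content_types import InterleavedContentItem
# Assumed helper location; it renders a content item down to plain text.
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str


def contents_as_strings(contents: List[str] | List[InterleavedContentItem]) -> List[str]:
    # Each element is either already a string or a single content item;
    # depth-2 lists can no longer reach this point under the new signature.
    return [c if isinstance(c, str) else interleaved_content_as_str(c) for c in contents]
```

A text-only provider can then pass the flattened strings straight to its embedding endpoint.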

llama_stack/distribution/routers/routers.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -6,7 +6,7 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     BenchmarkConfig,
@@ -214,7 +214,7 @@ async def completion(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:
```

llama_stack/providers/inline/inference/vllm/vllm.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -23,6 +23,7 @@
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -230,5 +231,5 @@ async def _generate_and_convert_to_openai_compat():
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
-    async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
         raise NotImplementedError()
```

llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -9,7 +9,7 @@
 
 from botocore.client import BaseClient
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -162,7 +162,7 @@ async def _get_params_for_chat_completion(self, request: ChatCompletionRequest)
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
```

llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 
 from cerebras.cloud.sdk import AsyncCerebras
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -172,6 +172,6 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

llama_stack/providers/remote/inference/databricks/databricks.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -8,7 +8,7 @@
 
 from openai import OpenAI
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -130,7 +130,7 @@ def _get_params(self, request: ChatCompletionRequest) -> dict:
 
     async def embeddings(
         self,
-        model: str,
-        contents: List[InterleavedContent],
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

llama_stack/providers/remote/inference/fireworks/fireworks.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 
 from fireworks.client import Fireworks
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -232,7 +232,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
```

llama_stack/providers/remote/inference/groq/groq.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,6 +19,7 @@
     EmbeddingsResponse,
     Inference,
     InterleavedContent,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -140,7 +141,7 @@ async def chat_completion(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```
