Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion llama_stack/providers/remote/inference/gemini/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from openai import NOT_GIVEN

from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import GeminiConfig
Expand All @@ -14,8 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):

provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
}

def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"

async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Override embeddings method to handle Gemini's missing usage statistics.
Gemini's embedding API doesn't return usage information, so we provide default values.
"""
# Prepare request parameters
request_params = {
"model": await self._get_provider_model_id(params.model),
"input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
"user": params.user if params.user is not None else NOT_GIVEN,
}

# Add extra_body if present
extra_body = params.model_extra
if extra_body:
request_params["extra_body"] = extra_body

# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(**request_params)

data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)

# Gemini doesn't return usage statistics - use default values
if hasattr(response, "usage") and response.usage:
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
else:
usage = OpenAIEmbeddingUsage(
prompt_tokens=0,
total_tokens=0,
)

return OpenAIEmbeddingsResponse(
data=data,
model=params.model,
usage=usage,
)
Loading