16 changes: 16 additions & 0 deletions llama_stack/apis/vector_io/vector_io.py
@@ -91,6 +91,22 @@ def chunk_id(self) -> str:
 
         return generate_chunk_id(str(uuid.uuid4()), str(self.content))
 
+    @property
+    def document_id(self) -> str | None:
+        """Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
+        # Check metadata first (takes precedence)
+        doc_id = self.metadata.get("document_id")
+        if doc_id is not None:
+            if not isinstance(doc_id, str):
+                raise TypeError(f"metadata['document_id'] must be a string, got {type(doc_id).__name__}: {doc_id!r}")
+            return doc_id
+
+        # Fall back to chunk_metadata if available (Pydantic ensures type safety)
+        if self.chunk_metadata is not None:
+            return self.chunk_metadata.document_id
+
+        return None
+
 
 @json_schema_type
 class QueryChunksResponse(BaseModel):
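A minimal usage sketch of the new property (the import path and constructor keywords follow the tests further down; putting both metadata and chunk_metadata on one Chunk is an assumption for illustration):

from llama_stack.apis.vector_io import Chunk, ChunkMetadata

# metadata takes precedence over chunk_metadata
chunk = Chunk(
    content="hello",
    metadata={"document_id": "doc-meta"},
    chunk_metadata=ChunkMetadata(document_id="doc-chunk"),
)
assert chunk.document_id == "doc-meta"

# falls back to chunk_metadata when metadata carries no document_id
chunk = Chunk(content="hello", chunk_metadata=ChunkMetadata(document_id="doc-chunk"))
assert chunk.document_id == "doc-chunk"

# returns None instead of raising when the ID is absent everywhere
assert Chunk(content="hello", metadata={"source": "s"}).document_id is None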
4 changes: 3 additions & 1 deletion llama_stack/core/routers/vector_io.py
@@ -101,8 +101,10 @@ async def insert_chunks(
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
+        doc_ids = [chunk.document_id for chunk in chunks[:3]]
         logger.debug(
-            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
+            f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
         )
         provider = await self.routing_table.get_provider_impl(vector_db_id)
         return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
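For illustration, a sketch of the message the reworked call now emits (the db name and chunks are hypothetical, and print stands in for logger.debug); chunks without an ID show up as None in the log instead of crashing the router:

from llama_stack.apis.vector_io import Chunk

chunks = [
    Chunk(content="a", metadata={"document_id": "doc-1"}),
    Chunk(content="b", metadata={"source": "s"}),  # no document_id -> None
]
doc_ids = [chunk.document_id for chunk in chunks[:3]]
vector_db_id, ttl_seconds = "db1", None  # hypothetical values
print(
    f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
    f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
)
# -> VectorIORouter.insert_chunks: db1, 2 chunks, ttl_seconds=None, chunk_ids=['doc-1', None]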
2 changes: 1 addition & 1 deletion llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -279,7 +279,7 @@ async def query(
         return RAGQueryResult(
             content=picked,
             metadata={
-                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "document_ids": [c.document_id for c in chunks[: len(picked)]],
                 "chunks": [c.content for c in chunks[: len(picked)]],
                 "scores": scores[: len(picked)],
                 "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
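This one-line change matters because dict indexing raises on a missing key while the property degrades to None; a quick sketch of the difference, using a hypothetical chunk:

from llama_stack.apis.vector_io import Chunk

c = Chunk(content="x", metadata={"source": "s"})  # no document_id anywhere

try:
    c.metadata["document_id"]  # old lookup: raises KeyError
except KeyError:
    print("old lookup raised KeyError")

print(c.document_id)  # new lookup: prints None, so query() keeps going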
@@ -113,6 +113,37 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
         await vector_io_adapter.insert_chunks("db_not_exist", [])
 
 
+async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
+    """Ensure no KeyError when document_id is missing or in different places."""
+    from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+
+    fake_index = AsyncMock()
+    vector_io_adapter.cache["db1"] = fake_index
+
+    # Various document_id scenarios that shouldn't crash
+    chunks = [
+        Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
+        Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
+        Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
+    ]
+
+    # Should work without KeyError
+    await vector_io_adapter.insert_chunks("db1", chunks)
+    fake_index.insert_chunks.assert_awaited_once()
+
+
+async def test_document_id_with_invalid_type_raises_error():
+    """Ensure TypeError is raised when document_id is not a string."""
+    from llama_stack.apis.vector_io import Chunk
+
+    # Integer document_id should raise TypeError
+    chunk = Chunk(content="test", metadata={"document_id": 12345})
+    with pytest.raises(TypeError) as exc_info:
+        _ = chunk.document_id
+    assert "metadata['document_id'] must be a string" in str(exc_info.value)
+    assert "got int" in str(exc_info.value)
+
+
 async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
     expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
     fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))