Skip to content

Commit 67f23df

Browse files
seanzhougoogle authored and copybara-github committed
feat: Allow user specify embedding model for file retrieval
And use Gemini embedding model as default model if no embedding model is specified. PiperOrigin-RevId: 801161505
1 parent 0c87907 commit 67f23df

File tree

3 files changed

+206
-11
lines changed

3 files changed

+206
-11
lines changed

pyproject.toml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,16 @@ docs = [
127127

128128
# Optional extensions
129129
extensions = [
130-
"anthropic>=0.43.0", # For anthropic model support
131-
"beautifulsoup4>=3.2.2", # For load_web_page tool.
132-
"crewai[tools];python_version>='3.10'", # For CrewaiTool
133-
"docker>=7.0.0", # For ContainerCodeExecutor
134-
"langgraph>=0.2.60", # For LangGraphAgent
135-
"litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
136-
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
137-
"lxml>=5.3.0", # For load_web_page tool.
138-
"toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
130+
"anthropic>=0.43.0", # For anthropic model support
131+
"beautifulsoup4>=3.2.2", # For load_web_page tool.
132+
"crewai[tools];python_version>='3.10'", # For CrewaiTool
133+
"docker>=7.0.0", # For ContainerCodeExecutor
134+
"langgraph>=0.2.60", # For LangGraphAgent
135+
"litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
136+
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
137+
"llama-index-embeddings-google-genai>=0.3.0",  # For files retrieval using LlamaIndex.
138+
"lxml>=5.3.0", # For load_web_page tool.
139+
"toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
139140
]
140141

141142

src/google/adk/tools/retrieval/files_retrieval.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,64 @@
1717
from __future__ import annotations
1818

1919
import logging
20+
from typing import Optional
2021

2122
from llama_index.core import SimpleDirectoryReader
2223
from llama_index.core import VectorStoreIndex
24+
from llama_index.core.base.embeddings.base import BaseEmbedding
2325

2426
from .llama_index_retrieval import LlamaIndexRetrieval
2527

2628
logger = logging.getLogger("google_adk." + __name__)
2729

2830

31+
def _get_default_embedding_model() -> BaseEmbedding:
32+
"""Get the default Google Gemini embedding model.
33+
34+
Returns:
35+
GoogleGenAIEmbedding instance configured with text-embedding-004 model.
36+
37+
Raises:
38+
ImportError: If llama-index-embeddings-google-genai package is not installed.
39+
"""
40+
try:
41+
from llama_index.embeddings.google_genai import GoogleGenAIEmbedding
42+
43+
return GoogleGenAIEmbedding(model_name="text-embedding-004")
44+
except ImportError as e:
45+
raise ImportError(
46+
"llama-index-embeddings-google-genai package not found. "
47+
"Please run: pip install llama-index-embeddings-google-genai"
48+
) from e
49+
50+
2951
class FilesRetrieval(LlamaIndexRetrieval):
3052

31-
def __init__(self, *, name: str, description: str, input_dir: str):
53+
def __init__(
54+
self,
55+
*,
56+
name: str,
57+
description: str,
58+
input_dir: str,
59+
embedding_model: Optional[BaseEmbedding] = None,
60+
):
61+
"""Initialize FilesRetrieval with optional embedding model.
3262
63+
Args:
64+
name: Name of the tool.
65+
description: Description of the tool.
66+
input_dir: Directory path containing files to index.
67+
embedding_model: Optional custom embedding model. If None, defaults to
68+
Google's text-embedding-004 model.
69+
"""
3370
self.input_dir = input_dir
3471

72+
if embedding_model is None:
73+
embedding_model = _get_default_embedding_model()
74+
3575
logger.info("Loading data from %s", input_dir)
3676
retriever = VectorStoreIndex.from_documents(
37-
SimpleDirectoryReader(input_dir).load_data()
77+
SimpleDirectoryReader(input_dir).load_data(),
78+
embed_model=embedding_model,
3879
).as_retriever()
3980
super().__init__(name=name, description=description, retriever=retriever)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for FilesRetrieval tool."""
16+
17+
import sys
18+
import unittest.mock as mock
19+
20+
from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model
21+
from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
22+
from llama_index.core.base.embeddings.base import BaseEmbedding
23+
import pytest
24+
25+
26+
class MockEmbedding(BaseEmbedding):
27+
"""Mock embedding model for testing."""
28+
29+
def _get_query_embedding(self, query):
30+
return [0.1] * 384
31+
32+
def _get_text_embedding(self, text):
33+
return [0.1] * 384
34+
35+
async def _aget_query_embedding(self, query):
36+
return [0.1] * 384
37+
38+
async def _aget_text_embedding(self, text):
39+
return [0.1] * 384
40+
41+
42+
class TestFilesRetrieval:
43+
44+
def test_files_retrieval_with_custom_embedding(self, tmp_path):
45+
"""Test FilesRetrieval with custom embedding model."""
46+
# Create test file
47+
test_file = tmp_path / "test.txt"
48+
test_file.write_text("This is a test document for retrieval testing.")
49+
50+
custom_embedding = MockEmbedding()
51+
retrieval = FilesRetrieval(
52+
name="test_retrieval",
53+
description="Test retrieval tool",
54+
input_dir=str(tmp_path),
55+
embedding_model=custom_embedding,
56+
)
57+
58+
assert retrieval.name == "test_retrieval"
59+
assert retrieval.input_dir == str(tmp_path)
60+
assert retrieval.retriever is not None
61+
62+
@mock.patch(
63+
"google.adk.tools.retrieval.files_retrieval._get_default_embedding_model"
64+
)
65+
def test_files_retrieval_uses_default_embedding(
66+
self, mock_get_default_embedding, tmp_path
67+
):
68+
"""Test FilesRetrieval uses default embedding when none provided."""
69+
# Create test file
70+
test_file = tmp_path / "test.txt"
71+
test_file.write_text("This is a test document for retrieval testing.")
72+
73+
mock_embedding = MockEmbedding()
74+
mock_get_default_embedding.return_value = mock_embedding
75+
76+
retrieval = FilesRetrieval(
77+
name="test_retrieval",
78+
description="Test retrieval tool",
79+
input_dir=str(tmp_path),
80+
)
81+
82+
mock_get_default_embedding.assert_called_once()
83+
assert retrieval.name == "test_retrieval"
84+
assert retrieval.input_dir == str(tmp_path)
85+
86+
def test_get_default_embedding_model_import_error(self):
87+
"""Test _get_default_embedding_model handles ImportError correctly."""
88+
# Simulate the package not being installed by making import fail
89+
import builtins
90+
91+
original_import = builtins.__import__
92+
93+
def mock_import(name, *args, **kwargs):
94+
if name == "llama_index.embeddings.google_genai":
95+
raise ImportError(
96+
"No module named 'llama_index.embeddings.google_genai'"
97+
)
98+
return original_import(name, *args, **kwargs)
99+
100+
with mock.patch("builtins.__import__", side_effect=mock_import):
101+
with pytest.raises(ImportError) as exc_info:
102+
_get_default_embedding_model()
103+
104+
# The exception should be re-raised as our custom ImportError with helpful message
105+
assert "llama-index-embeddings-google-genai package not found" in str(
106+
exc_info.value
107+
)
108+
assert "pip install llama-index-embeddings-google-genai" in str(
109+
exc_info.value
110+
)
111+
112+
def test_get_default_embedding_model_success(self):
113+
"""Test _get_default_embedding_model returns Google embedding when available."""
114+
# Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available
115+
if sys.version_info < (3, 10):
116+
pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+")
117+
118+
# Mock the module creation to avoid import issues
119+
mock_module = mock.MagicMock()
120+
mock_embedding_instance = MockEmbedding()
121+
mock_module.GoogleGenAIEmbedding.return_value = mock_embedding_instance
122+
123+
with mock.patch.dict(
124+
"sys.modules", {"llama_index.embeddings.google_genai": mock_module}
125+
):
126+
result = _get_default_embedding_model()
127+
128+
mock_module.GoogleGenAIEmbedding.assert_called_once_with(
129+
model_name="text-embedding-004"
130+
)
131+
assert result == mock_embedding_instance
132+
133+
def test_backward_compatibility(self, tmp_path):
134+
"""Test that existing code without embedding_model parameter still works."""
135+
# Create test file
136+
test_file = tmp_path / "test.txt"
137+
test_file.write_text("This is a test document for retrieval testing.")
138+
139+
with mock.patch(
140+
"google.adk.tools.retrieval.files_retrieval._get_default_embedding_model"
141+
) as mock_get_default:
142+
mock_get_default.return_value = MockEmbedding()
143+
144+
# This should work exactly like before - no embedding_model parameter
145+
retrieval = FilesRetrieval(
146+
name="test_retrieval",
147+
description="Test retrieval tool",
148+
input_dir=str(tmp_path),
149+
)
150+
151+
assert retrieval.name == "test_retrieval"
152+
assert retrieval.input_dir == str(tmp_path)
153+
mock_get_default.assert_called_once()

0 commit comments

Comments
 (0)